File: test_client_tracing_async.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (105 lines) | stat: -rw-r--r-- 4,302 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import azure.ai.inference as sdk
from azure.ai.inference.tracing import AIInferenceInstrumentor

from model_inference_test_base import (
    ModelClientTestBase,
    ServicePreparerChatCompletions,
)

from azure.core.settings import settings
from devtools_testutils.aio import recorded_by_proxy_async
from memory_trace_exporter import MemoryTraceExporter
from gen_ai_trace_verifier import GenAiTraceVerifier
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

# Name of the environment variable that the tests in this module overwrite.
CONTENT_TRACING_ENV_VARIABLE = "AZURE_TRACING_GEN_AI_CONTENT_RECORDING_ENABLED"
# Snapshot of that variable taken at import time (None when it is unset), kept so
# the value can be put back once the tests in this module have run.
content_tracing_initial_value = os.environ.get(CONTENT_TRACING_ENV_VARIABLE)


# The test class name needs to start with "Test" to get collected by pytest
class TestClientTracingAsync(ModelClientTestBase):
    """Tracing tests for the asynchronous chat completions client."""

    @classmethod
    def teardown_class(cls):
        """Put the content-recording environment variable back to its import-time state."""
        if content_tracing_initial_value is not None:
            os.environ[CONTENT_TRACING_ENV_VARIABLE] = content_tracing_initial_value
        else:
            # The variable was unset when this module was imported, so remove the
            # value written by the tests instead of leaking it to later test modules.
            os.environ.pop(CONTENT_TRACING_ENV_VARIABLE, None)

    # **********************************************************************************
    #
    #                            TRACING TESTS - CHAT COMPLETIONS
    #
    # **********************************************************************************

    def setup_memory_trace_exporter(self) -> "tuple[SimpleSpanProcessor, MemoryTraceExporter]":
        """Install a fresh tracer provider whose spans are sent to a ``MemoryTraceExporter``.

        Returns:
            A ``(span_processor, memory_exporter)`` pair: the processor so the caller
            can flush it, and the exporter so the caller can read the spans back.
        """
        # Setup Azure Core settings to use OpenTelemetry tracing
        settings.tracing_implementation = "OpenTelemetry"
        trace.set_tracer_provider(TracerProvider())
        _ = trace.get_tracer(__name__)
        memory_exporter = MemoryTraceExporter()
        span_processor = SimpleSpanProcessor(memory_exporter)
        trace.get_tracer_provider().add_span_processor(span_processor)
        return span_processor, memory_exporter

    def modify_env_var(self, name: str, new_value: str) -> "str | None":
        """Set environment variable ``name`` to ``new_value``.

        Returns:
            The value the variable had before the call, or ``None`` if it was unset.
        """
        current_value = os.getenv(name)
        os.environ[name] = new_value
        return current_value

    @ServicePreparerChatCompletions()
    @recorded_by_proxy_async
    async def test_chat_completion_async_tracing_content_recording_disabled(self, **kwargs):
        """Check the span and event emitted for one chat completion with content recording set to "False"."""
        # Make sure code is not instrumented due to a previous test exception
        try:
            AIInferenceInstrumentor().uninstrument()
        except RuntimeError:
            pass
        self.modify_env_var(CONTENT_TRACING_ENV_VARIABLE, "False")
        client = self._create_async_chat_client(**kwargs)
        try:
            model = kwargs.pop("azure_ai_chat_model").lower()
            processor, exporter = self.setup_memory_trace_exporter()
            AIInferenceInstrumentor().instrument()
            try:
                _ = await client.complete(
                    messages=[
                        sdk.models.SystemMessage(content="You are a helpful assistant."),
                        sdk.models.UserMessage(content="What is the capital of France?"),
                    ],
                )
                processor.force_flush()
                # The span may be named "chat <something>" or just "chat"; accept either.
                spans = exporter.get_spans_by_name_starts_with("chat ")
                if not spans:
                    spans = exporter.get_spans_by_name("chat")
                assert len(spans) == 1
                span = spans[0]
                # NOTE(review): "" and "+" look like wildcard markers understood by
                # GenAiTraceVerifier (presumably "any value" / "positive number") - confirm there.
                expected_attributes = [
                    ("gen_ai.operation.name", "chat"),
                    ("gen_ai.system", "az.ai.inference"),
                    ("gen_ai.request.model", "chat"),
                    ("server.address", ""),
                    ("gen_ai.response.id", ""),
                    ("gen_ai.response.model", model),
                    ("gen_ai.usage.input_tokens", "+"),
                    ("gen_ai.usage.output_tokens", "+"),
                    ("gen_ai.response.finish_reasons", ("stop",)),
                ]
                attributes_match = GenAiTraceVerifier().check_span_attributes(span, expected_attributes)
                assert attributes_match == True

                # The expected event content carries only finish_reason and index,
                # i.e. no message text.
                expected_events = [
                    {
                        "name": "gen_ai.choice",
                        "attributes": {
                            "gen_ai.system": "az.ai.inference",
                            "gen_ai.event.content": '{"finish_reason": "stop", "index": 0}',
                        },
                    }
                ]
                events_match = GenAiTraceVerifier().check_span_events(span, expected_events)
                assert events_match == True
            finally:
                # Always remove the instrumentation, even when an assertion above fails,
                # so a failure here cannot leave the SDK instrumented for later tests.
                AIInferenceInstrumentor().uninstrument()
        finally:
            # Release the async client's transport.
            # NOTE(review): assumes _create_async_chat_client returns an
            # azure.ai.inference.aio client exposing an awaitable close() - confirm.
            await client.close()