File: test_inference.py

package info (click to toggle)
python-azure 20250829%2Bgit-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 756,824 kB
  • sloc: python: 6,224,989; ansic: 804; javascript: 287; makefile: 198; sh: 195; xml: 109
file content (129 lines) | stat: -rw-r--r-- 5,703 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# pylint: disable=line-too-long,useless-suppression
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import pprint

import pytest
from azure.ai.projects import AIProjectClient
from test_base import TestBase, servicePreparer
from openai import OpenAI
from devtools_testutils import recorded_by_proxy, is_live_and_not_recording


# To run all tests in this class, use the following command in the \sdk\ai\azure-ai-projects folder:
# cls & pytest tests\test_inference.py -s
class TestInference(TestBase):
    """Live tests for the OpenAI client handed out by ``AIProjectClient.get_openai_client``.

    Every test sends the same question through both the chat-completions API and the
    responses API and checks that the reply contains the expected number. The tests are
    skipped unless ``is_live_and_not_recording()`` is true (see the ``skipif`` markers).
    """

    # Prompt sent to the model, and the spellings of the answer accepted in its reply.
    _QUESTION = "How many feet are in a mile?"
    _EXPECTED_ANSWERS = ("5280", "5,280")

    @classmethod
    def _assert_answer_in(cls, text, api_name: str) -> None:
        """Assert that ``text`` contains one of the accepted answers.

        :param text: Reply text returned by the service; may be ``None``.
        :param api_name: Label of the API that produced ``text``, used in failure messages.
        :raises AssertionError: If ``text`` is ``None`` or contains none of the accepted answers.
        """
        # A chat message's content is optional in the OpenAI SDK; fail with a readable
        # message instead of the TypeError that `"5280" in None` would raise.
        assert text is not None, f"{api_name} reply has no text content"
        assert any(
            answer in text for answer in cls._EXPECTED_ANSWERS
        ), f"{api_name} reply does not contain any of {cls._EXPECTED_ANSWERS}: {text!r}"

    @classmethod
    def _test_openai_client(cls, client: OpenAI, model_deployment_name: str):
        """Exercise chat completions and responses on ``client`` and validate both replies.

        :param client: Authenticated OpenAI client to call.
        :param model_deployment_name: Value passed as ``model`` to both API calls.
        :raises AssertionError: If either reply is missing or lacks the expected answer.
        """
        chat_completions = client.chat.completions.create(
            model=model_deployment_name,
            messages=[
                {
                    "role": "user",
                    "content": cls._QUESTION,
                },
            ],
        )

        print("Raw dump of chat completions object: ")
        pprint.pprint(chat_completions)
        # Guard the index below so an empty result is reported as a test failure.
        assert chat_completions.choices, "chat completions reply contains no choices"
        chat_text = chat_completions.choices[0].message.content
        print("Response message: ", chat_text)
        cls._assert_answer_in(chat_text, "chat completions")

        response = client.responses.create(
            model=model_deployment_name,
            input=cls._QUESTION,
        )

        print("Raw dump of responses object: ")
        pprint.pprint(response)
        print("Response message: ", response.output_text)
        cls._assert_answer_in(response.output_text, "responses")

    def _run_inference(self, endpoint, banner: str, connection_name_key=None) -> None:
        """Shared body of the inference tests.

        Creates an ``AIProjectClient`` for ``endpoint``, obtains an OpenAI client from it and
        runs :meth:`_test_openai_client` against that client.

        :param endpoint: Project endpoint passed to ``AIProjectClient``.
        :param banner: Message printed before the OpenAI client is requested.
        :param connection_name_key: Key in ``self.test_inference_params`` holding the connection
            name to pass to ``get_openai_client``; ``None`` means no ``connection_name`` argument
            is passed at all.
        """
        print("\n=====> Endpoint:", endpoint)

        # Only forward `connection_name` when a test asked for one, so the call without a
        # connection stays exactly `get_openai_client(api_version=...)`.
        client_kwargs = {}
        if connection_name_key is not None:
            client_kwargs["connection_name"] = self.test_inference_params[connection_name_key]
        model_deployment_name = self.test_inference_params["model_deployment_name"]
        api_version = self.test_inference_params["aoai_api_version"]

        with AIProjectClient(
            endpoint=endpoint,
            credential=self.get_credential(AIProjectClient, is_async=False),
        ) as project_client:

            print(banner)
            with project_client.get_openai_client(api_version=api_version, **client_kwargs) as client:
                self._test_openai_client(client, model_deployment_name)

    # To run this test, use the following command in the \sdk\ai\azure-ai-projects folder:
    # cls & pytest tests\test_inference.py::TestInference::test_inference -s
    @servicePreparer()
    @pytest.mark.skipif(
        condition=(not is_live_and_not_recording()),
        reason="Skipped because we cannot record network calls with AOAI client",
    )
    @recorded_by_proxy
    def test_inference(self, **kwargs):
        """Run the inference checks with an OpenAI client requested without a connection name."""
        self._run_inference(
            kwargs.pop("azure_ai_projects_tests_project_endpoint"),
            "[test_inference] Get an authenticated Azure OpenAI client for the parent AI Services resource, and perform a chat completion operation.",
        )

    # To run this test, use the following command in the \sdk\ai\azure-ai-projects folder:
    # cls & pytest tests\test_inference.py::TestInference::test_inference_on_api_key_auth_connection -s
    @servicePreparer()
    @pytest.mark.skipif(
        condition=(not is_live_and_not_recording()),
        reason="Skipped because we cannot record network calls with AOAI client",
    )
    @recorded_by_proxy
    def test_inference_on_api_key_auth_connection(self, **kwargs):
        """Run the inference checks through the connection named by ``connection_name_api_key_auth``."""
        self._run_inference(
            kwargs.pop("azure_ai_projects_tests_project_endpoint"),
            "[test_inference_on_api_key_auth_connection] Get an authenticated Azure OpenAI client for a connection AOAI service, and perform a chat completion operation.",
            connection_name_key="connection_name_api_key_auth",
        )

    # To run this test, use the following command in the \sdk\ai\azure-ai-projects folder:
    # cls & pytest tests\test_inference.py::TestInference::test_inference_on_entra_id_auth_connection -s
    @servicePreparer()
    @pytest.mark.skipif(
        condition=(not is_live_and_not_recording()),
        reason="Skipped because we cannot record network calls with AOAI client",
    )
    @recorded_by_proxy
    def test_inference_on_entra_id_auth_connection(self, **kwargs):
        """Run the inference checks through the connection named by ``connection_name_entra_id_auth``."""
        self._run_inference(
            kwargs.pop("azure_ai_projects_tests_project_endpoint"),
            "[test_inference_on_entra_id_auth_connection] Get an authenticated Azure OpenAI client for a connection AOAI service, and perform a chat completion operation.",
            connection_name_key="connection_name_entra_id_auth",
        )