File: test_image_embeddings_client_async.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (219 lines) | stat: -rw-r--r-- 10,862 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import json
import azure.ai.inference as sdk
import azure.ai.inference.aio as async_sdk

from model_inference_test_base import (
    ModelClientTestBase,
    ServicePreparerImageEmbeddings,
)

from devtools_testutils.aio import recorded_by_proxy_async
from azure.core.exceptions import ServiceRequestError
from azure.core.credentials import AzureKeyCredential


# The test class name needs to start with "Test" to get collected by pytest
class TestImageEmbeddingsClientAsync(ModelClientTestBase):
    """Async tests for ``ImageEmbeddingsClient``.

    Contains offline regression tests that validate the outgoing JSON request payload
    (no service response required), and recorded happy-path tests that call the
    image embeddings service.
    """

    # **********************************************************************************
    #
    #         IMAGE EMBEDDINGS REGRESSION TESTS - NO SERVICE RESPONSE REQUIRED
    #
    # **********************************************************************************

    async def _embed_and_validate_request_payload(self, client, **embed_kwargs):
        """Send an image embeddings request that must fail, then validate its payload.

        The client is expected to point at an endpoint that does not exist, so every
        ``embed`` call should raise ``ServiceRequestError``. ``self.request_callback`` is
        passed as ``raw_request_hook`` so the outgoing request can still be validated.

        :param client: The async image embeddings client under test. The caller owns it
            and is responsible for closing it.
        :param embed_kwargs: Extra keyword arguments forwarded to ``client.embed``
            (e.g. embedding settings that override the client's defaults).
        :raises AssertionError: If ``embed`` unexpectedly succeeds.
        """
        image_embedding_input = ModelClientTestBase._get_image_embeddings_input()
        # The same request is sent twice on one client. Presumably this checks that the
        # first call does not alter the client's settings (TODO confirm intent).
        for _ in range(2):
            try:
                await client.embed(
                    input=[image_embedding_input],
                    raw_request_hook=self.request_callback,
                    **embed_kwargs,
                )
            except ServiceRequestError:
                # Expected: the endpoint does not resolve, so the send fails.
                self._validate_image_embeddings_json_request_payload()
            else:
                # Raise explicitly rather than `assert False`, which is stripped under -O.
                raise AssertionError("Expected ServiceRequestError from a non-existent endpoint")

    # Regression test. Send a request that includes all supported types of input objects. Make sure the resulting
    # JSON payload that goes up to the service (including headers) is the correct one after hand-inspection.
    # NOTE: The preparer decorator is required even though this test does not use its environment
    # variables; the test errors out without it.
    @ServicePreparerImageEmbeddings()
    async def test_async_image_embeddings_request_payload(self, **kwargs):
        async with async_sdk.ImageEmbeddingsClient(
            endpoint="http://does.not.exist",
            credential=AzureKeyCredential("key-value"),
            headers={"some_header": "some_header_value"},
            user_agent="MyAppId",
        ) as client:
            await self._embed_and_validate_request_payload(
                client,
                dimensions=2048,
                encoding_format=sdk.models.EmbeddingEncodingFormat.UBINARY,
                input_type=sdk.models.EmbeddingInputType.QUERY,
                model_extras={
                    "key1": 1,
                    "key2": True,
                    "key3": "Some value",
                    "key4": [1, 2, 3],
                    "key5": {"key6": 2, "key7": False, "key8": "Some other value", "key9": [4, 5, 6, 7]},
                },
                model="some-model-id",
            )

    # Regression test. Send a request that includes all supported types of input objects, with embedding settings
    # specified in the constructor. Make sure the resulting JSON payload that goes up to the service
    # is the correct one after hand-inspection.
    # NOTE: The preparer decorator is required even though this test does not use its environment
    # variables; the test errors out without it.
    @ServicePreparerImageEmbeddings()
    async def test_async_image_embeddings_request_payload_with_defaults(self, **kwargs):
        async with async_sdk.ImageEmbeddingsClient(
            endpoint="http://does.not.exist",
            credential=AzureKeyCredential("key-value"),
            headers={"some_header": "some_header_value"},
            user_agent="MyAppId",
            dimensions=2048,
            encoding_format=sdk.models.EmbeddingEncodingFormat.UBINARY,
            input_type=sdk.models.EmbeddingInputType.QUERY,
            model_extras={
                "key1": 1,
                "key2": True,
                "key3": "Some value",
                "key4": [1, 2, 3],
                "key5": {"key6": 2, "key7": False, "key8": "Some other value", "key9": [4, 5, 6, 7]},
            },
            model="some-model-id",
        ) as client:
            # No per-call settings: the payload must come entirely from the constructor defaults.
            await self._embed_and_validate_request_payload(client)

    # Regression test. Send a request that includes all supported types of input objects, with embeddings settings
    # specified in the constructor and all of them overwritten in the 'embed' call.
    # Make sure the resulting JSON payload that goes up to the service is the correct one after hand-inspection.
    # NOTE: The preparer decorator is required even though this test does not use its environment
    # variables; the test errors out without it.
    @ServicePreparerImageEmbeddings()
    async def test_async_image_embeddings_request_payload_with_defaults_and_overrides(self, **kwargs):
        async with async_sdk.ImageEmbeddingsClient(
            endpoint="http://does.not.exist",
            credential=AzureKeyCredential("key-value"),
            headers={"some_header": "some_header_value"},
            user_agent="MyAppId",
            dimensions=1024,
            encoding_format=sdk.models.EmbeddingEncodingFormat.UINT8,
            input_type=sdk.models.EmbeddingInputType.DOCUMENT,
            model_extras={
                "hey1": 2,
                "key2": False,
                "key3": "Some other value",
                "key9": "Yet another value",
            },
            model="some-other-model-id",
        ) as client:
            # The per-call settings below must fully replace the constructor defaults above.
            # The same payload validator as the other regression tests is used.
            await self._embed_and_validate_request_payload(
                client,
                dimensions=2048,
                encoding_format=sdk.models.EmbeddingEncodingFormat.UBINARY,
                input_type=sdk.models.EmbeddingInputType.QUERY,
                model_extras={
                    "key1": 1,
                    "key2": True,
                    "key3": "Some value",
                    "key4": [1, 2, 3],
                    "key5": {"key6": 2, "key7": False, "key8": "Some other value", "key9": [4, 5, 6, 7]},
                },
                model="some-model-id",
            )

    # **********************************************************************************
    #
    #                      HAPPY PATH SERVICE TESTS - IMAGE EMBEDDINGS
    #
    # **********************************************************************************

    # TODO: At the moment the /info route shows  "model_type": "embedding", so load_client
    # will return an EmbeddingsClient instead of ImageEmbeddingsClient. How can we resolve this?
    # This Cohere model (cohere-embed-v2-english) supports both text embeddings and image embeddings.
    @ServicePreparerImageEmbeddings()
    @recorded_by_proxy_async
    async def test_async_load_image_embeddings_client(self, **kwargs):
        client = await self._load_async_image_embeddings_client(**kwargs)
        try:
            assert isinstance(client, async_sdk.EmbeddingsClient)
            # Model info is expected to be populated already by the load step.
            assert client._model_info  # pylint: disable=protected-access
            response1 = await client.get_model_info()
            self._print_model_info_result(response1)
            self._validate_model_info_result(response1, "embedding")  # TODO: What should this be?
        finally:
            # Close even when an assertion fails, so the client is not leaked.
            await client.close()

    # TODO: At the moment the /info route shows  "model_type": "embedding", so load_client
    # will return an EmbeddingsClient instead of ImageEmbeddingsClient. How can we resolve this?
    # This Cohere model (cohere-embed-v2-english) supports both text embeddings and image embeddings.
    @ServicePreparerImageEmbeddings()
    @recorded_by_proxy_async
    async def test_async_get_model_info_on_image_embeddings_client(self, **kwargs):
        async with self._create_async_image_embeddings_client(**kwargs) as client:
            # A freshly created client has no cached model info until get_model_info() is called.
            assert not client._model_info  # pylint: disable=protected-access

            response1 = await client.get_model_info()
            assert client._model_info  # pylint: disable=protected-access

            self._print_model_info_result(response1)
            self._validate_model_info_result(response1, "embedding")  # TODO: what should this be?

            # Get the model info again. No network calls should be made here,
            # as the response is cached in the client.
            response2 = await client.get_model_info()
            self._print_model_info_result(response2)
            assert response1 == response2

    @ServicePreparerImageEmbeddings()
    @recorded_by_proxy_async
    async def test_async_image_embeddings_with_entra_id_auth(self, **kwargs):
        async with self._create_async_image_embeddings_client(key_auth=False, **kwargs) as client:
            image_embedding_input = ModelClientTestBase._get_image_embeddings_input(False)

            # Request image embeddings with default service format (list of floats)
            response1 = await client.embed(input=[image_embedding_input])
            self._print_embeddings_result(response1)
            self._validate_image_embeddings_result(response1)
            # The string form of the result must match its pretty-printed JSON dict form.
            assert json.dumps(response1.as_dict(), indent=2) == response1.__str__()

    @ServicePreparerImageEmbeddings()
    @recorded_by_proxy_async
    async def test_async_image_embeddings(self, **kwargs):
        async with self._create_async_image_embeddings_client(**kwargs) as client:
            image_embedding_input = ModelClientTestBase._get_image_embeddings_input(False)

            # Request image embeddings with default service format (list of floats)
            response1 = await client.embed(input=[image_embedding_input])
            self._print_embeddings_result(response1)
            self._validate_image_embeddings_result(response1)
            # The string form of the result must match its pretty-printed JSON dict form.
            assert json.dumps(response1.as_dict(), indent=2) == response1.__str__()

            # Request embeddings as base64 encoded strings
            response2 = await client.embed(
                input=[image_embedding_input], encoding_format=sdk.models.EmbeddingEncodingFormat.BASE64
            )
            self._print_embeddings_result(response2, sdk.models.EmbeddingEncodingFormat.BASE64)
            self._validate_image_embeddings_result(response2, sdk.models.EmbeddingEncodingFormat.BASE64)