File: test_search_client.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (323 lines) | stat: -rw-r--r-- 15,060 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import pytest
from unittest import mock
from azure.core.paging import ItemPaged
from azure.core.credentials import AzureKeyCredential

from azure.search.documents._generated.models import (
    SearchDocumentsResult,
    SearchResult,
)
from azure.search.documents._search_client import SearchPageIterator

from azure.search.documents import (
    IndexDocumentsBatch,
    SearchClient,
    RequestEntityTooLargeError,
    ApiVersion,
)
from azure.search.documents._utils import odata, get_answer_query

# Shared API-key credential passed to SearchClient by the tests in this module;
# its key ("test_api_key") is what the header assertions expect under "api-key".
CREDENTIAL = AzureKeyCredential(key="test_api_key")

# Maps each SearchClient convenience method name to the action type the
# tests expect on every action of the batch that method builds.
CRUD_METHOD_MAP = {
    "upload_documents": "upload",
    "delete_documents": "delete",
    "merge_documents": "merge",
    "merge_or_upload_documents": "mergeOrUpload",
}

# Method names in declaration order; used to parametrize the CRUD tests.
CRUD_METHOD_NAMES = list(CRUD_METHOD_MAP)


class Test_odata:
    """Checks for the ``odata`` template-formatting helper."""

    def test_const(self):
        """A template without placeholders is returned unchanged."""
        template = "no escapes"
        assert odata(template) == template

    def test_numbers(self):
        """Numeric values are substituted without quotes."""
        rendered = odata("foo eq {foo}", foo=10)
        assert rendered == "foo eq 10"

    def test_string(self):
        """String values are wrapped in single quotes."""
        rendered = odata("foo eq {foo}", foo="a string")
        assert rendered == "foo eq 'a string'"

    def test_mixed(self):
        """String and numeric placeholders can be combined in one template."""
        assert (
            odata("foo eq {foo} and bar le {bar}", foo="a string", bar=10)
            == "foo eq 'a string' and bar le 10"
        )

    def test_escape_single_quote(self):
        """Each single quote inside a string value is doubled."""
        rendered = odata("foo eq {foo}", foo="a '' str'ing")
        assert rendered == "foo eq 'a '''' str''ing'"

    def test_prevent_double_quoting(self):
        """A placeholder already quoted in the template is not quoted again."""
        rendered = odata("foo eq '{foo}'", foo="a string")
        assert rendered == "foo eq 'a string'"


class TestAnswerQuery:
    """Checks for the ``get_answer_query`` string builder."""

    def test_no_args(self):
        """Calling with no arguments yields ``None``."""
        built = get_answer_query()
        assert built is None

    def test_query_answer(self):
        """A bare query answer is passed through as-is."""
        built = get_answer_query("query")
        assert built == "query"

    def test_query_answer_count(self):
        """A second positional argument is appended as ``count-N``."""
        built = get_answer_query("query", 5)
        assert built == "query|count-5"

    def test_query_answer_threshold(self):
        """``query_answer_threshold`` alone is appended as ``threshold-X``."""
        built = get_answer_query("query", query_answer_threshold=0.5)
        assert built == "query|threshold-0.5"

    def test_query_answer_count_threshold(self):
        """Count and threshold together are joined with a comma."""
        built = get_answer_query("query", 5, 0.5)
        assert built == "query|count-5,threshold-0.5"


class TestSearchClient:
    """Unit tests for ``SearchClient`` with the generated operations patched out.

    Each test that touches the service patches the corresponding
    ``DocumentsOperations`` method and inspects the arguments the client
    forwarded to it.
    """

    def test_init(self):
        """A new client carries the API key and the JSON ``Accept`` header."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        assert client._headers == {
            "api-key": "test_api_key",
            "Accept": "application/json;odata.metadata=none",
        }

    def test_credential_roll(self):
        """Updating the credential's key is reflected in the client's headers."""
        credential = AzureKeyCredential(key="old_api_key")
        client = SearchClient("endpoint", "index name", credential)
        assert client._headers == {
            "api-key": "old_api_key",
            "Accept": "application/json;odata.metadata=none",
        }
        credential.update("new_api_key")
        assert client._headers == {
            "api-key": "new_api_key",
            "Accept": "application/json;odata.metadata=none",
        }

    def test_headers_merge(self):
        """Merging returns a new dict holding both caller and client headers."""
        credential = AzureKeyCredential(key="test_api_key")
        client = SearchClient("endpoint", "index name", credential)
        orig = {"foo": "bar"}
        result = client._merge_client_headers(orig)
        # The caller's dict must not be handed back (and thereby mutated).
        assert result is not orig
        assert result == {
            "api-key": "test_api_key",
            "Accept": "application/json;odata.metadata=none",
            "foo": "bar",
        }

    def test_repr(self):
        """``repr`` shows the endpoint and index name."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        assert repr(client) == "<SearchClient [endpoint={}, index={}]>".format(repr("endpoint"), repr("index name"))

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.count")
    def test_get_document_count(self, mock_count):
        """``get_document_count`` forwards only the client headers."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        client.get_document_count()
        assert mock_count.called
        assert mock_count.call_args[0] == ()
        assert len(mock_count.call_args[1]) == 1
        assert mock_count.call_args[1]["headers"] == client._headers

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.get")
    def test_get_document(self, mock_get):
        """``get_document`` forwards the key, selected fields and headers."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        client.get_document("some_key")
        assert mock_get.called
        assert mock_get.call_args[0] == ()
        assert len(mock_get.call_args[1]) == 3
        assert mock_get.call_args[1]["headers"] == client._headers
        assert mock_get.call_args[1]["key"] == "some_key"
        assert mock_get.call_args[1]["selected_fields"] is None

        # reset_mock() clears the recorded calls so the assertions below prove
        # that the *second* call happened; a plain `.reset()` would merely
        # create a child mock and leave `called` set from the first call.
        mock_get.reset_mock()

        client.get_document("some_key", selected_fields="foo")
        assert mock_get.called
        assert mock_get.call_args[0] == ()
        assert len(mock_get.call_args[1]) == 3
        assert mock_get.call_args[1]["headers"] == client._headers
        assert mock_get.call_args[1]["key"] == "some_key"
        assert mock_get.call_args[1]["selected_fields"] == "foo"

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.search_post")
    def test_search_query_argument(self, mock_search_post):
        """``search`` is lazy and sends the search text once iteration starts."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        result = client.search(search_text="search text")
        assert isinstance(result, ItemPaged)
        assert result._page_iterator_class is SearchPageIterator
        search_result = SearchDocumentsResult()
        search_result.results = [SearchResult(additional_properties={"key": "val"})]
        mock_search_post.return_value = search_result
        # No request is issued until the first item is pulled from the pager.
        assert not mock_search_post.called
        next(result)
        assert mock_search_post.called
        assert mock_search_post.call_args[0] == ()
        assert mock_search_post.call_args[1]["search_request"].search_text == "search text"

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.suggest_post")
    def test_suggest_query_argument(self, mock_suggest_post):
        """``suggest`` forwards the search text and client headers."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        client.suggest(search_text="search text", suggester_name="sg")
        assert mock_suggest_post.called
        assert mock_suggest_post.call_args[0] == ()
        assert mock_suggest_post.call_args[1]["headers"] == client._headers
        assert mock_suggest_post.call_args[1]["suggest_request"].search_text == "search text"

    def test_suggest_bad_argument(self):
        """Calling ``suggest`` with a single positional argument raises ``TypeError``."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        # Only the exception type is checked. A message assertion placed after
        # the raising call inside the `raises` block can never execute, so it
        # was removed instead of being left as dead code.
        # NOTE(review): confirm the current error wording before adding `match=`.
        with pytest.raises(TypeError):
            client.suggest("bad_query")

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.search_post")
    def test_get_count_reset_continuation_token(self, mock_search_post):
        """``get_count`` clears the first page iterator's continuation token."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        result = client.search(search_text="search text")
        assert isinstance(result, ItemPaged)
        assert result._page_iterator_class is SearchPageIterator
        search_result = SearchDocumentsResult()
        search_result.results = [SearchResult(additional_properties={"key": "val"})]
        mock_search_post.return_value = search_result
        # Pull one item so the first page iterator instance exists.
        next(result)
        result._first_page_iterator_instance.continuation_token = "fake token"
        result.get_count()
        assert not result._first_page_iterator_instance.continuation_token

    @mock.patch(
        "azure.search.documents._generated.operations._documents_operations.DocumentsOperations.autocomplete_post"
    )
    def test_autocomplete_query_argument(self, mock_autocomplete_post):
        """``autocomplete`` forwards the search text and client headers."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        client.autocomplete(search_text="search text", suggester_name="sg")
        assert mock_autocomplete_post.called
        assert mock_autocomplete_post.call_args[0] == ()
        assert mock_autocomplete_post.call_args[1]["headers"] == client._headers
        assert mock_autocomplete_post.call_args[1]["autocomplete_request"].search_text == "search text"

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.count")
    def test_get_document_count_v2020_06_30(self, mock_count):
        """Same as ``test_get_document_count`` with ``api_version`` pinned to V2020_06_30."""
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        client.get_document_count()
        assert mock_count.called
        assert mock_count.call_args[0] == ()
        assert len(mock_count.call_args[1]) == 1
        assert mock_count.call_args[1]["headers"] == client._headers

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.get")
    def test_get_document_v2020_06_30(self, mock_get):
        """Same as ``test_get_document`` with ``api_version`` pinned to V2020_06_30."""
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        client.get_document("some_key")
        assert mock_get.called
        assert mock_get.call_args[0] == ()
        assert len(mock_get.call_args[1]) == 3
        assert mock_get.call_args[1]["headers"] == client._headers
        assert mock_get.call_args[1]["key"] == "some_key"
        assert mock_get.call_args[1]["selected_fields"] is None

        # reset_mock() clears the recorded calls so the assertions below prove
        # that the *second* call happened; a plain `.reset()` would merely
        # create a child mock and leave `called` set from the first call.
        mock_get.reset_mock()

        client.get_document("some_key", selected_fields="foo")
        assert mock_get.called
        assert mock_get.call_args[0] == ()
        assert len(mock_get.call_args[1]) == 3
        assert mock_get.call_args[1]["headers"] == client._headers
        assert mock_get.call_args[1]["key"] == "some_key"
        assert mock_get.call_args[1]["selected_fields"] == "foo"

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.search_post")
    def test_search_query_argument_v2020_06_30(self, mock_search_post):
        """Same as ``test_search_query_argument`` with ``api_version`` pinned to V2020_06_30."""
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        result = client.search(search_text="search text")
        assert isinstance(result, ItemPaged)
        assert result._page_iterator_class is SearchPageIterator
        search_result = SearchDocumentsResult()
        search_result.results = [SearchResult(additional_properties={"key": "val"})]
        mock_search_post.return_value = search_result
        # No request is issued until the first item is pulled from the pager.
        assert not mock_search_post.called
        next(result)
        assert mock_search_post.called
        assert mock_search_post.call_args[0] == ()
        assert mock_search_post.call_args[1]["search_request"].search_text == "search text"

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.suggest_post")
    def test_suggest_query_argument_v2020_06_30(self, mock_suggest_post):
        """Same as ``test_suggest_query_argument`` with ``api_version`` pinned to V2020_06_30."""
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        client.suggest(search_text="search text", suggester_name="sg")
        assert mock_suggest_post.called
        assert mock_suggest_post.call_args[0] == ()
        assert mock_suggest_post.call_args[1]["headers"] == client._headers
        assert mock_suggest_post.call_args[1]["suggest_request"].search_text == "search text"

    @mock.patch(
        "azure.search.documents._generated.operations._documents_operations.DocumentsOperations.autocomplete_post"
    )
    def test_autocomplete_query_argument_v2020_06_30(self, mock_autocomplete_post):
        """Same as ``test_autocomplete_query_argument`` with ``api_version`` pinned to V2020_06_30."""
        client = SearchClient("endpoint", "index name", CREDENTIAL, api_version=ApiVersion.V2020_06_30)
        client.autocomplete(search_text="search text", suggester_name="sg")
        assert mock_autocomplete_post.called
        assert mock_autocomplete_post.call_args[0] == ()
        assert mock_autocomplete_post.call_args[1]["headers"] == client._headers
        assert mock_autocomplete_post.call_args[1]["autocomplete_request"].search_text == "search text"

    def test_autocomplete_bad_argument(self):
        """Calling ``autocomplete`` with a single positional argument raises ``TypeError``."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)
        # Only the exception type is checked. A message assertion placed after
        # the raising call inside the `raises` block can never execute, so it
        # was removed instead of being left as dead code.
        # NOTE(review): confirm the current error wording before adding `match=`.
        with pytest.raises(TypeError):
            client.autocomplete("bad_query")

    @pytest.mark.parametrize("arg", [[], ["doc1"], ["doc1", "doc2"]], ids=lambda x: str(len(x)) + " docs")
    @pytest.mark.parametrize("method_name", CRUD_METHOD_NAMES)
    def test_add_method(self, arg, method_name):
        """Each CRUD helper builds a batch of the matching action type and delegates to ``index_documents``."""
        with mock.patch.object(SearchClient, "index_documents", return_value=None) as mock_index_documents:
            client = SearchClient("endpoint", "index name", CREDENTIAL)

            method = getattr(client, method_name)
            method(arg, extra="foo")

            assert mock_index_documents.called
            # The batch is the sole positional argument to index_documents.
            assert len(mock_index_documents.call_args[0]) == 1
            batch = mock_index_documents.call_args[0][0]
            assert isinstance(batch, IndexDocumentsBatch)
            assert all(action.action_type == CRUD_METHOD_MAP[method_name] for action in batch.actions)
            assert [action.additional_properties for action in batch.actions] == arg
            assert mock_index_documents.call_args[1]["headers"] == client._headers
            assert mock_index_documents.call_args[1]["extra"] == "foo"

    @mock.patch("azure.search.documents._generated.operations._documents_operations.DocumentsOperations.index")
    def test_index_documents(self, mock_index):
        """Batch helpers tag actions correctly and ``index_documents`` forwards headers and extra kwargs."""
        client = SearchClient("endpoint", "index name", CREDENTIAL)

        batch = IndexDocumentsBatch()
        actions = batch.add_upload_actions("upload1")
        assert len(actions) == 1
        assert all(action.action_type == "upload" for action in actions)
        actions = batch.add_delete_actions("delete1", "delete2")
        assert len(actions) == 2
        assert all(action.action_type == "delete" for action in actions)
        actions = batch.add_merge_actions(["merge1", "merge2", "merge3"])
        assert all(action.action_type == "merge" for action in actions)
        actions = batch.add_merge_or_upload_actions("merge_or_upload1")
        assert all(action.action_type == "mergeOrUpload" for action in actions)

        client.index_documents(batch, extra="foo")
        assert mock_index.called
        assert mock_index.call_args[0] == ()
        assert mock_index.call_args[1]["headers"] == client._headers
        assert mock_index.call_args[1]["extra"] == "foo"

    def test_request_too_large_error(self):
        """A ``RequestEntityTooLargeError`` from ``_index_documents_actions`` propagates out of ``index_documents``."""
        with mock.patch.object(
            SearchClient,
            "_index_documents_actions",
            side_effect=RequestEntityTooLargeError("Error"),
        ):
            client = SearchClient("endpoint", "index name", CREDENTIAL)
            batch = IndexDocumentsBatch()
            batch.add_upload_actions("upload1")
            with pytest.raises(RequestEntityTooLargeError):
                client.index_documents(batch, extra="foo")