File: test_semantic_reranker_async.py

package info (click to toggle)
python-azure 20251014%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 766,472 kB
  • sloc: python: 6,314,744; ansic: 804; javascript: 287; makefile: 198; sh: 198; xml: 109
file content (148 lines) | stat: -rw-r--r-- 6,028 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.
# cspell:ignore rerank reranker reranking
import json
import unittest
import asyncio

from azure.cosmos.aio import CosmosClient
import azure.cosmos.exceptions as exceptions
import pytest
from azure.identity.aio import DefaultAzureCredential

import test_config


@pytest.mark.semanticReranker
class TestSemanticRerankerAsync(unittest.TestCase):
    """Live-account tests for the async container ``semantic_rerank`` call.

    The class derives from plain ``unittest.TestCase``, which does not invoke
    ``asyncSetUp``/``asyncTearDown`` by itself. Each test therefore hands its
    coroutine to ``_run_async``, which runs setup, the test body and teardown
    inside a single ``asyncio.run`` call.
    """
    client: CosmosClient = None
    # Held so teardown can close it; the original created the credential as a
    # local in asyncSetUp and never released it.
    credential: DefaultAzureCredential = None
    # Both stay None until asyncSetUp creates them, which lets asyncTearDown
    # cope with a setup that failed part-way through.
    test_db = None
    test_container = None
    config = test_config.TestConfig
    host = config.host
    TEST_DATABASE_ID = config.TEST_DATABASE_ID
    TEST_CONTAINER_ID = config.TEST_SINGLE_PARTITION_CONTAINER_ID
    TEST_CONTAINER_PARTITION_KEY = config.TEST_CONTAINER_PARTITION_KEY

    @classmethod
    def setUpClass(cls):
        """Refuse to run while ``host`` still holds the placeholder endpoint.

        :raises Exception: if ``host`` equals ``'[YOUR_ENDPOINT_HERE]'``.
        """
        if cls.host == '[YOUR_ENDPOINT_HERE]':
            raise Exception(
                "You must specify your Azure Cosmos account values for "
                "'host' at the top of this class to run the "
                "tests.")

    async def asyncSetUp(self):
        """Create the credential, client, test database and test container."""
        self.credential = DefaultAzureCredential()
        self.client = CosmosClient(self.host, self.credential, connection_verify=False)
        self.test_db = await self.client.create_database_if_not_exists(self.TEST_DATABASE_ID)
        self.test_container = await self.test_db.create_container_if_not_exists(
            self.TEST_CONTAINER_ID,
            self.TEST_CONTAINER_PARTITION_KEY
        )

    async def asyncTearDown(self):
        """Delete the test container and database, then close client and credential.

        Safe to call after a partially failed ``asyncSetUp``: anything that
        was never created is skipped. Cosmos HTTP errors raised while deleting
        are ignored (cleanup is best-effort), and the client and credential
        are closed regardless.
        """
        try:
            if self.test_db is not None:
                await self.test_db.delete_container(self.TEST_CONTAINER_ID)
                # NOTE(review): this drops TEST_DATABASE_ID itself; confirm no
                # other suite relies on that database surviving this test.
                await self.client.delete_database(self.TEST_DATABASE_ID)
        except exceptions.CosmosHttpResponseError:
            pass
        finally:
            try:
                if self.client is not None:
                    await self.client.close()
            finally:
                # Closed in a nested finally so a failing client.close() does
                # not leave the credential open.
                if self.credential is not None:
                    await self.credential.close()

    def _run_async(self, scenario):
        """Run ``scenario`` between ``asyncSetUp`` and ``asyncTearDown``.

        :param scenario: zero-argument coroutine function holding the test body.

        Setup runs inside the ``try`` so that teardown still executes, and
        closes whatever was opened, when setup itself raises.
        """
        async def runner():
            try:
                await self.asyncSetUp()
                await scenario()
            finally:
                await self.asyncTearDown()

        asyncio.run(runner())

    def test_semantic_reranker_async(self):
        """Rerank plain-string documents and expect the Paris sentence first."""
        async def scenario():
            documents = self._get_documents(document_type="string")
            results = await self.test_container.semantic_rerank(
                reranking_context="What is the capital of France?",
                documents=documents,
                semantic_reranking_options={
                    "return_documents": True,
                    "top_k": 10,
                    "batch_size": 32,
                    "sort": True
                }
            )
            self.assertEqual(len(results["Scores"]), len(documents))
            self.assertEqual(
                results["Scores"][0]["document"],
                "Paris is the capital of France."
            )

        self._run_async(scenario)

    def test_semantic_reranker_async_json_documents(self):
        """Rerank JSON-encoded documents using the top-level ``text`` field."""
        async def scenario():
            documents = self._get_documents(document_type="json")
            results = await self.test_container.semantic_rerank(
                reranking_context="What is the capital of France?",
                # Each dict is sent as its JSON string form.
                documents=[json.dumps(item) for item in documents],
                semantic_reranking_options={
                    "return_documents": True,
                    "top_k": 10,
                    "batch_size": 32,
                    "sort": True,
                    "document_type": "json",
                    "target_paths": "text",
                }
            )

            self.assertEqual(len(results["Scores"]), len(documents))
            returned_document = json.loads(results["Scores"][0]["document"])
            self.assertEqual(
                returned_document["text"],
                "Paris is the capital of France."
            )

        self._run_async(scenario)

    def test_semantic_reranker_async_nested_json_documents(self):
        """Rerank JSON-encoded documents using the nested ``info.text`` field."""
        async def scenario():
            documents = self._get_documents(document_type="nested_json")
            results = await self.test_container.semantic_rerank(
                reranking_context="What is the capital of France?",
                # Each dict is sent as its JSON string form.
                documents=[json.dumps(item) for item in documents],
                semantic_reranking_options={
                    "return_documents": True,
                    "top_k": 10,
                    "batch_size": 32,
                    "sort": True,
                    "document_type": "json",
                    "target_paths": "info.text",
                }
            )

            self.assertEqual(len(results["Scores"]), len(documents))
            returned_document = json.loads(results["Scores"][0]["document"])
            self.assertEqual(
                returned_document["info"]["text"],
                "Paris is the capital of France."
            )

        self._run_async(scenario)

    def _get_documents(self, document_type: str) -> list:
        """Return the three capital-city fixture documents in the requested shape.

        :param document_type: ``"string"`` for plain sentences, ``"json"`` for
            dicts with ``id`` and ``text``, or ``"nested_json"`` for dicts with
            ``id`` and ``info.text``.
        :returns: a new list of three documents (Berlin, Paris, Madrid).
        :raises ValueError: if ``document_type`` is none of the above.
        """
        if document_type == "string":
            return [
                "Berlin is the capital of Germany.",
                "Paris is the capital of France.",
                "Madrid is the capital of Spain."
            ]
        if document_type == "json":
            return [
                {"id": "1", "text": "Berlin is the capital of Germany."},
                {"id": "2", "text": "Paris is the capital of France."},
                {"id": "3", "text": "Madrid is the capital of Spain."}
            ]
        if document_type == "nested_json":
            return [
                {"id": "1", "info": {"text": "Berlin is the capital of Germany."}},
                {"id": "2", "info": {"text": "Paris is the capital of France."}},
                {"id": "3", "info": {"text": "Madrid is the capital of Spain."}}
            ]
        raise ValueError("Unsupported document type")