1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
|
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.
# cspell:ignore rerank reranker reranking
import json
import unittest
import asyncio
from azure.cosmos.aio import CosmosClient
import azure.cosmos.exceptions as exceptions
import pytest
from azure.identity.aio import DefaultAzureCredential
import test_config
@pytest.mark.semanticReranker
class TestSemanticRerankerAsync(unittest.TestCase):
    """Test to check async semantic reranker behavior.

    Every test creates the configured database and container, calls
    ``semantic_rerank`` on the container, checks the response, and then
    deletes the container and database again. The async work is driven
    from synchronous ``unittest`` test methods through ``asyncio.run``.
    """

    client: CosmosClient = None
    config = test_config.TestConfig
    host = config.host
    TEST_DATABASE_ID = config.TEST_DATABASE_ID
    TEST_CONTAINER_ID = config.TEST_SINGLE_PARTITION_CONTAINER_ID
    TEST_CONTAINER_PARTITION_KEY = config.TEST_CONTAINER_PARTITION_KEY

    # Query shared by all tests and the document text expected to rank first.
    _RERANKING_CONTEXT = "What is the capital of France?"
    _EXPECTED_TOP_TEXT = "Paris is the capital of France."

    @classmethod
    def setUpClass(cls):
        """Fail fast when the account endpoint placeholder was never replaced."""
        if cls.host == '[YOUR_ENDPOINT_HERE]':
            raise Exception(
                "You must specify your Azure Cosmos account values for "
                "'host' at the top of this class to run the "
                "tests.")

    async def asyncSetUp(self):
        """Async setup for each test.

        Creates the credential and client, then ensures the test database
        and container exist. The credential is kept on the instance so that
        ``asyncTearDown`` can close it.
        """
        # Reset per-test state first so teardown can tell how far setup got.
        self._credential = None
        self.test_db = None
        self.test_container = None

        self._credential = DefaultAzureCredential()
        self.client = CosmosClient(self.host, self._credential, connection_verify=False)
        self.test_db = await self.client.create_database_if_not_exists(self.TEST_DATABASE_ID)
        self.test_container = await self.test_db.create_container_if_not_exists(
            self.TEST_CONTAINER_ID,
            self.TEST_CONTAINER_PARTITION_KEY
        )

    async def asyncTearDown(self):
        """Async teardown for each test.

        Deletes the container and database on a best-effort basis, then
        always closes the client and the credential. Safe to call even when
        ``asyncSetUp`` failed partway through.
        """
        test_db = getattr(self, "test_db", None)
        try:
            if test_db is not None:
                await test_db.delete_container(self.TEST_CONTAINER_ID)
                await self.client.delete_database(self.TEST_DATABASE_ID)
        except exceptions.CosmosHttpResponseError:
            # Cleanup is best-effort: a service-side failure here must not
            # mask the outcome of the test itself.
            pass
        finally:
            try:
                if self.client is not None:
                    await self.client.close()
            finally:
                # The async credential holds its own resources and must be
                # closed separately from the client.
                credential = getattr(self, "_credential", None)
                if credential is not None:
                    await credential.close()

    def _run_with_container(self, test_body):
        """Run the coroutine function *test_body* between setup and teardown.

        Setup runs inside the ``try`` so that a failure partway through
        still reaches teardown and releases the client and credential.
        """
        async def runner():
            try:
                await self.asyncSetUp()
                await test_body()
            finally:
                await self.asyncTearDown()

        asyncio.run(runner())

    async def _rerank(self, documents, **extra_options):
        """Call ``semantic_rerank`` with the shared query and base options.

        :param documents: Documents passed through as the ``documents`` argument.
        :param extra_options: Additional reranking options appended after
            the base options.
        :returns: The value returned by ``semantic_rerank``.
        """
        options = {
            "return_documents": True,
            "top_k": 10,
            "batch_size": 32,
            "sort": True,
            **extra_options,
        }
        return await self.test_container.semantic_rerank(
            reranking_context=self._RERANKING_CONTEXT,
            documents=documents,
            semantic_reranking_options=options,
        )

    def test_semantic_reranker_async(self):
        """Test async semantic reranking functionality."""
        async def body():
            documents = self._get_documents(document_type="string")
            results = await self._rerank(documents)
            assert len(results["Scores"]) == len(documents)
            assert results["Scores"][0]["document"] == self._EXPECTED_TOP_TEXT

        self._run_with_container(body)

    def test_semantic_reranker_async_json_documents(self):
        """Test reranking of JSON-serialized documents scored on the ``text`` path."""
        async def body():
            documents = self._get_documents(document_type="json")
            results = await self._rerank(
                [json.dumps(item) for item in documents],
                document_type="json",
                target_paths="text",
            )
            assert len(results["Scores"]) == len(documents)
            returned_document = json.loads(results["Scores"][0]["document"])
            assert returned_document["text"] == self._EXPECTED_TOP_TEXT

        self._run_with_container(body)

    def test_semantic_reranker_async_nested_json_documents(self):
        """Test reranking of JSON documents scored on the nested ``info.text`` path."""
        async def body():
            documents = self._get_documents(document_type="nested_json")
            results = await self._rerank(
                [json.dumps(item) for item in documents],
                document_type="json",
                target_paths="info.text",
            )
            assert len(results["Scores"]) == len(documents)
            returned_document = json.loads(results["Scores"][0]["document"])
            assert returned_document["info"]["text"] == self._EXPECTED_TOP_TEXT

        self._run_with_container(body)

    def _get_documents(self, document_type: str):
        """Return the three-document fixture in the requested shape.

        :param document_type: ``"string"`` for plain sentences, ``"json"``
            for dicts with a top-level ``text`` field, or ``"nested_json"``
            for dicts with the text under ``info.text``.
        :returns: A new list of three documents.
        :raises ValueError: If *document_type* is not one of the above.
        """
        if document_type == "string":
            return [
                "Berlin is the capital of Germany.",
                "Paris is the capital of France.",
                "Madrid is the capital of Spain."
            ]
        elif document_type == "json":
            return [
                {"id": "1", "text": "Berlin is the capital of Germany."},
                {"id": "2", "text": "Paris is the capital of France."},
                {"id": "3", "text": "Madrid is the capital of Spain."}
            ]
        elif document_type == "nested_json":
            return [
                {"id": "1", "info": {"text": "Berlin is the capital of Germany."}},
                {"id": "2", "info": {"text": "Paris is the capital of France."}},
                {"id": "3", "info": {"text": "Madrid is the capital of Spain."}}
            ]
        else:
            raise ValueError("Unsupported document type")
|