1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
|
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.
# cspell:ignore rerank reranker reranking
import json
import unittest
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
import pytest
from azure.identity import DefaultAzureCredential
import test_config
@pytest.mark.semanticReranker
class TestSemanticReranker(unittest.TestCase):
    """Check the container's ``semantic_rerank`` call for several document shapes.

    The tests need a reachable Azure Cosmos account: ``setUpClass`` builds a
    client from ``test_config.TestConfig.host`` and ``DefaultAzureCredential``
    and creates the database and container the tests run against.
    """

    client: cosmos_client.CosmosClient = None
    config = test_config.TestConfig
    host = config.host
    TEST_DATABASE_ID = config.TEST_DATABASE_ID
    TEST_CONTAINER_ID = config.TEST_SINGLE_PARTITION_CONTAINER_ID
    TEST_CONTAINER_PARTITION_KEY = config.TEST_CONTAINER_PARTITION_KEY

    # Query sent with every rerank request, and the sentence every test
    # expects to find in the first returned score entry.
    _RERANKING_CONTEXT = "What is the capital of France?"
    _EXPECTED_TOP_TEXT = "Paris is the capital of France."

    @classmethod
    def setUpClass(cls):
        """Create the client, database and container shared by all tests.

        Raises:
            Exception: If ``host`` is still the ``[YOUR_ENDPOINT_HERE]``
                placeholder, i.e. no account endpoint was configured.
        """
        if cls.host == '[YOUR_ENDPOINT_HERE]':
            raise Exception(
                "You must specify your Azure Cosmos account values for "
                "'host' at the top of this class to run the "
                "tests.")
        credential = DefaultAzureCredential()
        cls.client = cosmos_client.CosmosClient(cls.host, credential=credential)
        cls.test_db = cls.client.create_database_if_not_exists(cls.TEST_DATABASE_ID)
        cls.test_container = cls.test_db.create_container_if_not_exists(
            cls.TEST_CONTAINER_ID,
            cls.TEST_CONTAINER_PARTITION_KEY)

    @classmethod
    def tearDownClass(cls):
        """Delete the test container and database on a best-effort basis.

        Each deletion is attempted independently so that a failure to delete
        the container does not prevent the database from being deleted.
        """
        cleanup_steps = (
            (cls.test_db.delete_container, cls.TEST_CONTAINER_ID),
            (cls.client.delete_database, cls.TEST_DATABASE_ID),
        )
        for delete, resource_id in cleanup_steps:
            try:
                delete(resource_id)
            except exceptions.CosmosHttpResponseError:
                # Cleanup is best-effort: a service-side failure here must
                # not turn an otherwise passing test run into an error.
                pass

    def _rerank(self, documents, **extra_options):
        """Call ``semantic_rerank`` with the options common to every test.

        Args:
            documents: The list of documents passed to the service as-is.
            **extra_options: Additional entries merged into
                ``semantic_reranking_options`` after the common ones.

        Returns:
            Whatever ``semantic_rerank`` returns for the request.
        """
        options = {
            "return_documents": True,
            "top_k": 10,
            "batch_size": 32,
            "sort": True,
        }
        options.update(extra_options)
        return self.test_container.semantic_rerank(
            reranking_context=self._RERANKING_CONTEXT,
            documents=documents,
            semantic_reranking_options=options,
        )

    def test_semantic_reranker(self):
        """Plain-string documents: one score per document, Paris entry first."""
        documents = self._get_documents(document_type="string")
        results = self._rerank(documents)
        assert len(results["Scores"]) == len(documents)
        assert results["Scores"][0]["document"] == self._EXPECTED_TOP_TEXT

    def test_semantic_reranker_json_documents(self):
        """JSON documents ranked on the top-level ``text`` path."""
        documents = self._get_documents(document_type="json")
        results = self._rerank(
            [json.dumps(item) for item in documents],
            document_type="json",
            target_paths="text",
        )
        assert len(results["Scores"]) == len(documents)
        # Documents were sent as JSON strings, so decode before inspecting.
        returned_document = json.loads(results["Scores"][0]["document"])
        assert returned_document["text"] == self._EXPECTED_TOP_TEXT

    def test_semantic_reranker_nested_json_documents(self):
        """JSON documents ranked on the nested ``info.text`` path."""
        documents = self._get_documents(document_type="nested_json")
        results = self._rerank(
            [json.dumps(item) for item in documents],
            document_type="json",
            target_paths="info.text",
        )
        assert len(results["Scores"]) == len(documents)
        # Documents were sent as JSON strings, so decode before inspecting.
        returned_document = json.loads(results["Scores"][0]["document"])
        assert returned_document["info"]["text"] == self._EXPECTED_TOP_TEXT

    def _get_documents(self, document_type: str):
        """Return the three capital-city fixtures in the requested shape.

        Args:
            document_type: ``"string"`` for bare sentences, ``"json"`` for
                ``{"id", "text"}`` dicts, or ``"nested_json"`` for
                ``{"id", "info": {"text"}}`` dicts.

        Returns:
            A new list of three documents; ids are ``"1"`` to ``"3"``.

        Raises:
            ValueError: If ``document_type`` is not one of the three above.
        """
        sentences = [
            "Berlin is the capital of Germany.",
            "Paris is the capital of France.",
            "Madrid is the capital of Spain.",
        ]
        if document_type == "string":
            return sentences
        if document_type == "json":
            return [
                {"id": str(number), "text": sentence}
                for number, sentence in enumerate(sentences, start=1)
            ]
        if document_type == "nested_json":
            return [
                {"id": str(number), "info": {"text": sentence}}
                for number, sentence in enumerate(sentences, start=1)
            ]
        raise ValueError("Unsupported document type")
|