File: test_semantic_reranker.py

package info (click to toggle)
python-azure 20251118+git-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 783,356 kB
  • sloc: python: 6,474,533; ansic: 804; javascript: 287; sh: 205; makefile: 198; xml: 109
file content (121 lines) | stat: -rw-r--r-- 4,867 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.
# cspell:ignore rerank reranker reranking
import json
import unittest

import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.exceptions as exceptions
import pytest
from azure.identity import DefaultAzureCredential

import test_config


@pytest.mark.semanticReranker
class TestSemanticReranker(unittest.TestCase):
    """Test to check semantic reranker behavior.

    Each test reranks the same three-sentence corpus against one query and
    checks that every document gets a score and that
    "Paris is the capital of France." is ranked first. The documents are sent
    as plain strings, as flat JSON, or as nested JSON. The tests run against a
    live Azure Cosmos account taken from ``test_config.TestConfig`` and
    authenticate with ``DefaultAzureCredential``.
    """
    client: cosmos_client.CosmosClient = None
    config = test_config.TestConfig
    host = config.host
    TEST_DATABASE_ID = config.TEST_DATABASE_ID
    TEST_CONTAINER_ID = config.TEST_SINGLE_PARTITION_CONTAINER_ID
    TEST_CONTAINER_PARTITION_KEY = config.TEST_CONTAINER_PARTITION_KEY

    # Query shared by every test, and the sentence each test expects at rank 0.
    RERANKING_CONTEXT = "What is the capital of France?"
    EXPECTED_TOP_TEXT = "Paris is the capital of France."

    @classmethod
    def setUpClass(cls):
        """Create (or reuse) the test database and container for the class."""
        # Guard against running with the unconfigured placeholder endpoint.
        if cls.host == '[YOUR_ENDPOINT_HERE]':
            raise Exception(
                "You must specify your Azure Cosmos account values for "
                "'host' at the top of this class to run the "
                "tests.")

        credential = DefaultAzureCredential()
        cls.client = cosmos_client.CosmosClient(cls.host, credential=credential)
        cls.test_db = cls.client.create_database_if_not_exists(cls.TEST_DATABASE_ID)
        cls.test_container = cls.test_db.create_container_if_not_exists(cls.TEST_CONTAINER_ID,
                                                                        cls.TEST_CONTAINER_PARTITION_KEY)

    @classmethod
    def tearDownClass(cls):
        """Best-effort removal of the test container and database.

        Each deletion is attempted on its own. A ``CosmosHttpResponseError``
        while deleting the container therefore no longer prevents the
        database from being deleted. Service errors are ignored on purpose
        because cleanup must not fail the test run.
        """
        try:
            cls.test_db.delete_container(cls.TEST_CONTAINER_ID)
        except exceptions.CosmosHttpResponseError:
            pass
        try:
            cls.client.delete_database(cls.TEST_DATABASE_ID)
        except exceptions.CosmosHttpResponseError:
            pass

    def test_semantic_reranker(self):
        """Plain-string documents are scored and the Paris sentence ranks first."""
        documents = self._get_documents(document_type="string")
        results = self._rerank(documents)

        assert len(results["Scores"]) == len(documents)
        assert results["Scores"][0]["document"] == self.EXPECTED_TOP_TEXT

    def test_semantic_reranker_json_documents(self):
        """Flat JSON documents, with ``target_paths`` set to the ``text`` field."""
        documents = self._get_documents(document_type="json")
        results = self._rerank(
            [json.dumps(item) for item in documents],
            document_type="json",
            target_paths="text",
        )

        assert len(results["Scores"]) == len(documents)
        # Returned documents come back as JSON strings; decode before inspecting.
        returned_document = json.loads(results["Scores"][0]["document"])
        assert returned_document["text"] == self.EXPECTED_TOP_TEXT

    def test_semantic_reranker_nested_json_documents(self):
        """Nested JSON documents, with ``target_paths`` as a dotted path (``info.text``)."""
        documents = self._get_documents(document_type="nested_json")
        results = self._rerank(
            [json.dumps(item) for item in documents],
            document_type="json",
            target_paths="info.text",
        )

        assert len(results["Scores"]) == len(documents)
        returned_document = json.loads(results["Scores"][0]["document"])
        assert returned_document["info"]["text"] == self.EXPECTED_TOP_TEXT

    def _rerank(self, documents, **extra_options):
        """Call ``semantic_rerank`` on the test container with the shared query.

        Args:
            documents: the documents to send, as strings (JSON documents must
                already be serialized with ``json.dumps``).
            **extra_options: additional ``semantic_reranking_options`` entries
                (e.g. ``document_type``, ``target_paths``). They are appended
                after the common options.

        Returns:
            Whatever ``semantic_rerank`` returns. The tests index it as
            ``results["Scores"][i]["document"]``.
        """
        # top_k (10) exceeds the corpus size (3), so every document is expected
        # back; sort=True is what makes index 0 the best-scoring document.
        options = {
            "return_documents": True,
            "top_k": 10,
            "batch_size": 32,
            "sort": True,
            **extra_options,
        }
        return self.test_container.semantic_rerank(
            reranking_context=self.RERANKING_CONTEXT,
            documents=documents,
            semantic_reranking_options=options,
        )

    def _get_documents(self, document_type: str):
        """Return the three-sentence test corpus in the requested shape.

        Args:
            document_type: ``"string"`` for bare sentences, ``"json"`` for
                ``{"id", "text"}`` dicts, or ``"nested_json"`` for
                ``{"id", "info": {"text"}}`` dicts. Ids are "1", "2" and "3",
                in corpus order.

        Raises:
            ValueError: if ``document_type`` is not one of the values above.
        """
        sentences = [
            "Berlin is the capital of Germany.",
            "Paris is the capital of France.",
            "Madrid is the capital of Spain.",
        ]
        if document_type == "string":
            return sentences
        if document_type == "json":
            return [{"id": str(idx), "text": text}
                    for idx, text in enumerate(sentences, start=1)]
        if document_type == "nested_json":
            return [{"id": str(idx), "info": {"text": text}}
                    for idx, text in enumerate(sentences, start=1)]
        raise ValueError(f"Unsupported document type: {document_type!r}")