File: test_simulator.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (123 lines) | stat: -rw-r--r-- 5,606 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# flake8: noqa: F401
# flake8: noqa: F841

import asyncio
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

from azure.ai.evaluation._exceptions import EvaluationException
from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialSimulator


@pytest.fixture()
def async_callback():
    """Supply an async target that echoes back whatever single argument it is awaited with."""

    async def _echo(x):
        return x

    # No teardown is needed, so a plain return is equivalent to yielding here.
    return _echo


@pytest.mark.unittest
class TestSimulator:
    """Unit tests for constructing and validating ``AdversarialSimulator``.

    Service-facing helpers (service discovery URL, content-harm template
    collections and, where relevant, service dependency checks and the async
    simulation entry point) are replaced with ``unittest.mock.patch`` mocks.
    """

    # Scenarios the construction tests iterate over.
    _VALID_SCENARIOS = (
        AdversarialScenario.ADVERSARIAL_CONVERSATION,
        AdversarialScenario.ADVERSARIAL_QA,
        AdversarialScenario.ADVERSARIAL_SUMMARIZATION,
        AdversarialScenario.ADVERSARIAL_SEARCH,
        AdversarialScenario.ADVERSARIAL_REWRITE,
        AdversarialScenario.ADVERSARIAL_CONTENT_GEN_UNGROUNDED,
        AdversarialScenario.ADVERSARIAL_CONTENT_GEN_GROUNDED,
    )

    @staticmethod
    def _make_azure_ai_project():
        """Return a fresh fake Azure AI project scope dict (new object per call)."""
        return {
            "subscription_id": "test_subscription",
            "resource_group_name": "test_resource_group",
            "project_name": "test_workspace",
        }

    @staticmethod
    def _configure_construction_mocks(
        mock_ensure_service_dependencies,
        mock_simulate_async,
        mock_get_content_harm_template_collections,
        mock_get_service_discovery_url,
    ):
        """Give each patched simulator dependency a canned return value."""
        mock_get_service_discovery_url.return_value = "http://some.url/discovery/"
        mock_simulate_async.return_value = MagicMock()
        mock_get_content_harm_template_collections.return_value = ["t1", "t2", "t3", "t4", "t5", "t6", "t7"]
        mock_ensure_service_dependencies.return_value = True

    def _assert_constructible_for_each_scenario(self, credential):
        """Build one simulator per valid scenario and check each is callable.

        :param credential: Value passed straight through as ``credential=``.
        """
        azure_ai_project = self._make_azure_ai_project()
        for scenario in self._VALID_SCENARIOS:
            simulator = AdversarialSimulator(azure_ai_project=azure_ai_project, credential=credential)
            assert callable(simulator), f"simulator is not callable for {scenario}"
            # TODO: only construction is verified; the simulator is not yet invoked with
            # ``scenario`` (e.g. max_conversation_turns=1, max_simulation_results=3).

    # NOTE: stacked ``@patch`` decorators inject mocks bottom-up, so the first mock
    # parameter corresponds to the LOWEST decorator. The parameter order below mirrors that.
    @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url")
    @patch(
        "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections"
    )
    @patch("azure.ai.evaluation.simulator.AdversarialSimulator._simulate_async")
    @patch("azure.ai.evaluation.simulator.AdversarialSimulator._ensure_service_dependencies")
    def test_initialization_with_all_valid_scenarios(
        self,
        mock_ensure_service_dependencies,
        mock_simulate_async,
        mock_get_content_harm_template_collections,
        mock_get_service_discovery_url,
        azure_cred,
    ):
        """A simulator can be constructed with the ``azure_cred`` fixture for every valid scenario."""
        self._configure_construction_mocks(
            mock_ensure_service_dependencies,
            mock_simulate_async,
            mock_get_content_harm_template_collections,
            mock_get_service_discovery_url,
        )
        self._assert_constructible_for_each_scenario(azure_cred)

    @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url")
    @patch(
        "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections"
    )
    def test_simulator_raises_validation_error_with_unsupported_scenario(
        self, _get_content_harm_template_collections, _get_service_discovery_url, azure_cred, async_callback
    ):
        """Running the simulator with an unknown scenario string raises ``EvaluationException``."""
        _get_content_harm_template_collections.return_value = []
        _get_service_discovery_url.return_value = "some-url"

        simulator = AdversarialSimulator(azure_ai_project=self._make_azure_ai_project(), credential=azure_cred)
        with pytest.raises(EvaluationException):
            asyncio.run(
                simulator(
                    scenario="unknown-scenario",
                    max_conversation_turns=1,
                    max_simulation_results=3,
                    target=async_callback,
                )
            )

    # NOTE: mocks are injected bottom-up; see the note on the first construction test.
    @patch("azure.ai.evaluation.simulator._model_tools._rai_client.RAIClient._get_service_discovery_url")
    @patch(
        "azure.ai.evaluation.simulator._model_tools.AdversarialTemplateHandler._get_content_harm_template_collections"
    )
    @patch("azure.ai.evaluation.simulator.AdversarialSimulator._simulate_async")
    @patch("azure.ai.evaluation.simulator.AdversarialSimulator._ensure_service_dependencies")
    def test_initialization_parity_with_evals(
        self,
        mock_ensure_service_dependencies,
        mock_simulate_async,
        mock_get_content_harm_template_collections,
        mock_get_service_discovery_url,
    ):
        """Construction also succeeds when a plain string is passed as the credential."""
        self._configure_construction_mocks(
            mock_ensure_service_dependencies,
            mock_simulate_async,
            mock_get_content_harm_template_collections,
            mock_get_service_discovery_url,
        )
        self._assert_constructible_for_each_scenario("test_credential")