1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
|
"""
Tests for the strategy_utils module.
"""
import pytest
from unittest.mock import MagicMock, patch
from typing import Dict, List, Callable
from azure.ai.evaluation.red_team._utils.strategy_utils import (
strategy_converter_map,
get_converter_for_strategy,
get_chat_target,
get_orchestrators_for_attack_strategies
)
from azure.ai.evaluation.red_team._attack_strategy import AttackStrategy
from azure.ai.evaluation.red_team._callback_chat_target import _CallbackChatTarget
from pyrit.prompt_converter import (
PromptConverter, Base64Converter, FlipConverter, MorseConverter
)
from pyrit.prompt_target import PromptChatTarget, OpenAIChatTarget
@pytest.mark.unittest
class TestStrategyConverterMap:
    """Tests for ``strategy_converter_map``."""

    def test_strategy_converter_map(self):
        """Every AttackStrategy is a key, and key strategies map to the expected converters."""
        mapping = strategy_converter_map()

        # The map must cover every member of the AttackStrategy enum.
        for attack in AttackStrategy:
            assert attack in mapping, f"Missing converter for {attack}"

        # Baseline is mapped to no converter at all.
        assert mapping[AttackStrategy.Baseline] is None

        # Single-converter strategies map to an instance of a specific converter class.
        single_converters = {
            AttackStrategy.Base64: Base64Converter,
            AttackStrategy.Flip: FlipConverter,
            AttackStrategy.Morse: MorseConverter,
        }
        for attack, converter_cls in single_converters.items():
            assert isinstance(mapping[attack], converter_cls)

        # Grouped strategies map to lists of converters; EASY holds exactly three.
        easy_group = mapping[AttackStrategy.EASY]
        assert isinstance(easy_group, list)
        assert len(easy_group) == 3
        for group in (AttackStrategy.MODERATE, AttackStrategy.DIFFICULT):
            assert isinstance(mapping[group], list)
@pytest.mark.unittest
class TestConverterForStrategy:
    """Tests for ``get_converter_for_strategy``."""

    def test_get_converter_for_strategy_single(self):
        """A single strategy yields one converter instance, or None when it has none."""
        assert isinstance(get_converter_for_strategy(AttackStrategy.Base64), Base64Converter)

        # Baseline has no associated converter.
        assert get_converter_for_strategy(AttackStrategy.Baseline) is None

    def test_get_converter_for_strategy_list(self):
        """A list of strategies yields a list of converters in the same order."""
        expected_types = [Base64Converter, FlipConverter]
        result = get_converter_for_strategy([AttackStrategy.Base64, AttackStrategy.Flip])

        assert isinstance(result, list)
        assert len(result) == len(expected_types)
        for converter, expected_type in zip(result, expected_types):
            assert isinstance(converter, expected_type)
@pytest.mark.unittest
class TestChatTargetFunctions:
    """Test chat target related functions (``get_chat_target``)."""

    # Module paths patched so that no real chat target is ever constructed.
    _OPENAI_TARGET = "azure.ai.evaluation.red_team._utils.strategy_utils.OpenAIChatTarget"
    _CALLBACK_TARGET = "azure.ai.evaluation.red_team._utils.strategy_utils._CallbackChatTarget"
    # API version these tests expect get_chat_target to pass to OpenAIChatTarget.
    _API_VERSION = "2024-06-01"

    @patch(_OPENAI_TARGET)
    def test_get_chat_target_prompt_chat_target(self, mock_openai_chat_target):
        """An existing PromptChatTarget is returned as-is."""
        mock_target = MagicMock(spec=PromptChatTarget)

        result = get_chat_target(mock_target)

        # Identity, not mere equality: the very same object must come back.
        assert result is mock_target
        # Verify that we don't create a new target
        mock_openai_chat_target.assert_not_called()

    @patch(_OPENAI_TARGET)
    def test_get_chat_target_azure_openai(self, mock_openai_chat_target):
        """Azure OpenAI configs build an OpenAIChatTarget with key or AAD auth."""
        mock_instance = MagicMock()
        mock_openai_chat_target.return_value = mock_instance

        # Test with API key
        config = {
            "azure_deployment": "gpt-35-turbo",
            "azure_endpoint": "https://example.openai.azure.com",
            "api_key": "test-api-key",
        }
        result = get_chat_target(config)
        mock_openai_chat_target.assert_called_once_with(
            model_name="gpt-35-turbo",
            endpoint="https://example.openai.azure.com",
            api_key="test-api-key",
            api_version=self._API_VERSION,
        )
        assert result is mock_instance

        # reset_mock() clears call history but keeps return_value, so the
        # next construction still yields mock_instance.
        mock_openai_chat_target.reset_mock()

        # Test with AAD auth (no "api_key" in the config)
        config = {
            "azure_deployment": "gpt-35-turbo",
            "azure_endpoint": "https://example.openai.azure.com",
        }
        result = get_chat_target(config)
        mock_openai_chat_target.assert_called_once_with(
            model_name="gpt-35-turbo",
            endpoint="https://example.openai.azure.com",
            use_aad_auth=True,
            api_version=self._API_VERSION,
        )
        assert result is mock_instance

    @patch(_OPENAI_TARGET)
    def test_get_chat_target_openai(self, mock_openai_chat_target):
        """OpenAI configs build an OpenAIChatTarget, using base_url as the endpoint when given."""
        mock_instance = MagicMock()
        mock_openai_chat_target.return_value = mock_instance

        # Without base_url the endpoint is expected to be None.
        config = {
            "model": "gpt-4",
            "api_key": "test-api-key",
        }
        result = get_chat_target(config)
        mock_openai_chat_target.assert_called_once_with(
            model_name="gpt-4",
            endpoint=None,
            api_key="test-api-key",
            api_version=self._API_VERSION,
        )
        assert result is mock_instance

        # Test with base_url
        mock_openai_chat_target.reset_mock()
        config = {
            "model": "gpt-4",
            "api_key": "test-api-key",
            "base_url": "https://example.com/api",
        }
        result = get_chat_target(config)
        mock_openai_chat_target.assert_called_once_with(
            model_name="gpt-4",
            endpoint="https://example.com/api",
            api_key="test-api-key",
            api_version=self._API_VERSION,
        )
        assert result is mock_instance

    @patch(_CALLBACK_TARGET)
    def test_get_chat_target_callback_function(self, mock_callback_chat_target):
        """A callback with the full (messages, stream, session_state, context) signature is passed through."""
        mock_instance = MagicMock()
        mock_callback_chat_target.return_value = mock_instance

        def callback_fn(messages, stream, session_state, context):
            return {"role": "assistant", "content": "test"}

        result = get_chat_target(callback_fn)

        mock_callback_chat_target.assert_called_once_with(callback=callback_fn)
        assert result is mock_instance

    @patch(_CALLBACK_TARGET)
    def test_get_chat_target_simple_function(self, mock_callback_chat_target):
        """Test getting chat target from a simple function without proper signature."""
        mock_instance = MagicMock()
        mock_callback_chat_target.return_value = mock_instance

        def simple_fn(query):
            return "test response"

        result = get_chat_target(simple_fn)

        # Verify that _CallbackChatTarget was called with a function that has been adapted.
        # NOTE(review): only the call count is checked; the exact adapted callable
        # is an implementation detail not pinned here.
        mock_callback_chat_target.assert_called_once()
        assert result is mock_instance
@pytest.mark.unittest
class TestOrchestratorFunctions:
    """Tests for orchestrator selection helpers."""

    def test_get_orchestrators_for_attack_strategies(self):
        """Base64 + Flip yield exactly one callable orchestrator that returns None."""
        result = get_orchestrators_for_attack_strategies([AttackStrategy.Base64, AttackStrategy.Flip])

        assert isinstance(result, list)
        assert len(result) == 1
        orchestrator = result[0]
        assert callable(orchestrator)

        # Invoke the placeholder with a chat target, a prompt list, a converter,
        # and two label strings; it is expected to return None.
        outcome = orchestrator(
            MagicMock(),
            ["test prompt"],
            MagicMock(),
            "test-strategy",
            "test-risk",
        )
        assert outcome is None
|