1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
|
"""
Unit tests for red_team.utils.formatting_utils module.
"""
import pytest
import math
import json
from unittest.mock import patch, MagicMock, mock_open
from azure.ai.evaluation.red_team._utils.formatting_utils import (
message_to_dict, get_strategy_name, get_flattened_attack_strategies,
get_attack_success, format_scorecard, is_none_or_nan, list_mean_nan_safe
)
from azure.ai.evaluation.red_team._attack_strategy import AttackStrategy
from pyrit.models import ChatMessage
@pytest.fixture(scope="function")
def mock_chat_message():
    """Build a ``ChatMessage``-spec'd mock for the conversion tests.

    Returns:
        MagicMock: a mock created with ``spec=ChatMessage`` whose ``role`` is
        ``"user"`` and whose ``content`` is ``"test content"``.
    """
    chat_message = MagicMock(spec=ChatMessage)
    # Set both attributes in one call instead of one assignment per attribute.
    chat_message.configure_mock(role="user", content="test content")
    return chat_message
@pytest.mark.unittest
class TestMessageToDict:
    """Checks for the ``message_to_dict`` helper."""

    def test_message_to_dict(self, mock_chat_message):
        """A chat message is converted to a dict exposing its role and content."""
        converted = message_to_dict(mock_chat_message)

        # The type check comes first so the key lookups below are meaningful.
        assert isinstance(converted, dict)

        role = converted["role"]
        content = converted["content"]
        assert role == "user"
        assert content == "test content"
@pytest.mark.unittest
class TestStrategyNameFunctions:
    """Checks for the ``get_strategy_name`` helper."""

    def test_get_strategy_name_single(self):
        """A lone strategy is named after the string form of its enum value."""
        single = AttackStrategy.Base64

        assert get_strategy_name(single) == str(single.value)

    def test_get_strategy_name_list(self):
        """A list of strategies is named by joining their values with underscores."""
        first, second = AttackStrategy.Base64, AttackStrategy.Flip
        joined_name = f"{first.value}_{second.value}"

        assert get_strategy_name([first, second]) == joined_name
@pytest.mark.unittest
class TestAttackStrategyFunctions:
    """Test ``get_flattened_attack_strategies`` for plain, grouped and duplicate inputs."""

    def test_get_flattened_attack_strategies_simple(self):
        """Flattening two concrete strategies keeps both and adds Baseline."""
        strategies = [AttackStrategy.Base64, AttackStrategy.Flip]
        result = get_flattened_attack_strategies(strategies)

        # Baseline is expected alongside the two requested strategies.
        assert AttackStrategy.Baseline in result
        assert AttackStrategy.Base64 in result
        assert AttackStrategy.Flip in result
        assert len(result) == 3  # Both original strategies + Baseline

    def test_get_flattened_attack_strategies_easy(self):
        """The EASY group is replaced by concrete strategies plus Baseline."""
        strategies = [AttackStrategy.EASY]
        result = get_flattened_attack_strategies(strategies)

        # EASY should expand into specific strategies plus Baseline.
        assert AttackStrategy.Base64 in result
        assert AttackStrategy.Flip in result
        assert AttackStrategy.Morse in result
        assert AttackStrategy.Baseline in result
        assert AttackStrategy.EASY not in result  # EASY should be replaced

    def test_get_flattened_attack_strategies_moderate(self):
        """The MODERATE group is replaced by concrete strategies plus Baseline."""
        strategies = [AttackStrategy.MODERATE]
        result = get_flattened_attack_strategies(strategies)

        # MODERATE should expand into specific strategies plus Baseline.
        assert AttackStrategy.Tense in result
        assert AttackStrategy.Baseline in result
        assert AttackStrategy.MODERATE not in result  # MODERATE should be replaced

    def test_get_flattened_attack_strategies_difficult(self):
        """The DIFFICULT group is replaced by composed strategies plus Baseline."""
        strategies = [AttackStrategy.DIFFICULT]
        result = get_flattened_attack_strategies(strategies)

        # DIFFICULT should expand into composed strategies plus Baseline.
        assert AttackStrategy.Baseline in result
        assert AttackStrategy.DIFFICULT not in result  # DIFFICULT should be replaced

        # A composed strategy appears as a list entry in the result; at least
        # one such entry is required (this matches the ``>= 1`` assertion).
        composed_count = sum(1 for strategy in result if isinstance(strategy, list))
        assert composed_count >= 1

    def test_get_flattened_attack_strategies_duplicates(self):
        """Passing the same strategy twice yields it only once."""
        strategies = [AttackStrategy.Base64, AttackStrategy.Base64]
        result = get_flattened_attack_strategies(strategies)

        # Count by equality so the check does not depend on the container type.
        base64_count = sum(
            1 for strategy in result if strategy == AttackStrategy.Base64
        )
        assert base64_count == 1
@pytest.mark.unittest
class TestScorecardFormatting:
    """Test ``format_scorecard`` with empty, complete and partial scorecards."""

    def test_format_scorecard_empty(self):
        """An empty scorecard reports an overall ASR of 0%."""
        scan_result = {
            "scorecard": {
                "risk_category_summary": [],
                "joint_risk_attack_summary": [],
            }
        }

        result = format_scorecard(scan_result)

        assert "Overall ASR: 0%" in result

    def test_format_scorecard_with_data(self):
        """A fully populated row shows the overall ASR and every complexity ASR."""
        scan_result = {
            "scorecard": {
                "risk_category_summary": [{"overall_asr": 25.5}],
                "joint_risk_attack_summary": [
                    {
                        "risk_category": "violence",
                        "baseline_asr": 10.0,
                        "easy_complexity_asr": 20.0,
                        "moderate_complexity_asr": 30.0,
                        "difficult_complexity_asr": 40.0,
                    }
                ],
            },
            "studio_url": "https://example.com/studio",
        }

        result = format_scorecard(scan_result)

        assert "Overall ASR: 25.5%" in result
        assert "Violence" in result  # Should show capitalized risk category
        assert "10.0%" in result  # Baseline ASR
        assert "20.0%" in result  # Easy-complexity ASR
        assert "30.0%" in result  # Moderate-complexity ASR
        assert "40.0%" in result  # Difficult-complexity ASR

    def test_format_scorecard_partial_data(self):
        """Missing complexity levels are rendered as N/A next to the present ones."""
        # The three percentages are chosen so that none is a substring of
        # another; otherwise a substring check for one value could be satisfied
        # by a different cell (e.g. "5.0%" is contained in "15.0%").
        scan_result = {
            "scorecard": {
                "risk_category_summary": [{"overall_asr": 15.0}],
                "joint_risk_attack_summary": [
                    {
                        "risk_category": "hate_unfairness",
                        "baseline_asr": 7.5,
                        # Moderate and difficult complexity levels are omitted
                        "easy_complexity_asr": 22.5,
                    }
                ],
            }
        }

        result = format_scorecard(scan_result)

        assert "Overall ASR: 15.0%" in result
        assert "Hate-unfairness" in result  # Should show formatted risk category
        assert "7.5%" in result  # Baseline ASR
        assert "22.5%" in result  # Easy-complexity ASR
        assert "N/A" in result  # Should show N/A for missing complexities
@pytest.mark.unittest
class TestNumericalHelpers:
    """Checks for ``is_none_or_nan`` and ``list_mean_nan_safe``."""

    def test_is_none_or_nan_with_none(self):
        """``None`` is reported as missing."""
        outcome = is_none_or_nan(None)
        assert outcome is True

    def test_is_none_or_nan_with_nan(self):
        """A NaN float is reported as missing."""
        outcome = is_none_or_nan(math.nan)
        assert outcome is True

    def test_is_none_or_nan_with_valid_number(self):
        """Ordinary numbers, including zero, are not reported as missing."""
        for number in (42, 0, 3.14):
            assert is_none_or_nan(number) is False

    def test_list_mean_nan_safe_normal_list(self):
        """With no missing entries the plain mean is returned."""
        values = [1, 2, 3, 4]
        assert list_mean_nan_safe(values) == 2.5

    def test_list_mean_nan_safe_with_nones(self):
        """``None`` entries are left out of the mean."""
        values = [1, None, 3, None]
        assert list_mean_nan_safe(values) == 2.0  # Mean of the remaining [1, 3]

    def test_list_mean_nan_safe_with_nans(self):
        """NaN entries are left out of the mean."""
        values = [1, math.nan, 3, math.nan]
        assert list_mean_nan_safe(values) == 2.0  # Mean of the remaining [1, 3]

    def test_list_mean_nan_safe_empty_after_filtering(self):
        """When every entry is missing the result is 0.0."""
        values = [None, math.nan]
        assert list_mean_nan_safe(values) == 0.0  # Default when no valid values
|