1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
"""
Unit tests for callback_chat_target module.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import asyncio
from pyrit.common import initialize_pyrit, IN_MEMORY
from azure.ai.evaluation.red_team._callback_chat_target import _CallbackChatTarget
# Configure PyRIT with its IN_MEMORY database backend at module import time,
# i.e. before any fixture below constructs a _CallbackChatTarget.
# NOTE(review): presumably the target picks up this global memory setting when
# it is created, which is why this must run first — confirm against pyrit.
initialize_pyrit(memory_db_type=IN_MEMORY)
@pytest.fixture(scope="function")
def mock_callback():
    """Provide an async mock callback that resolves to a canned chat payload.

    The payload holds a single user/assistant exchange, has streaming
    disabled, no session state, and an empty context dict.
    """
    conversation = [
        {"role": "user", "content": "test prompt"},
        {"role": "assistant", "content": "test response"},
    ]
    canned_reply = {
        "messages": conversation,
        "stream": False,
        "session_state": None,
        "context": {},
    }
    return AsyncMock(return_value=canned_reply)
@pytest.fixture(scope="function")
def chat_target(mock_callback):
    """Build the _CallbackChatTarget under test, wired to the mock callback."""
    target_under_test = _CallbackChatTarget(callback=mock_callback)
    return target_under_test
@pytest.fixture(scope="function")
def mock_request():
    """Create a mocked request object that mimics PromptRequestResponse from pyrit.

    The request carries exactly one text piece (conversation id ``"test-id"``,
    converted value ``"test prompt"``) whose ``to_chat_message()`` returns a
    user message, and it has no response pieces.
    """
    request_piece = MagicMock()
    request_piece.conversation_id = "test-id"
    request_piece.converted_value = "test prompt"
    request_piece.converted_value_data_type = "text"
    request_piece.to_chat_message.return_value = MagicMock(
        role="user", content="test prompt"
    )

    request = MagicMock()
    request.request_pieces = [request_piece]
    request.response_pieces = []
    # Make from_response hand back this same request instead of a fresh
    # auto-generated MagicMock. NOTE(review): kept from the original; verify
    # whether _CallbackChatTarget actually calls from_response.
    request.from_response = MagicMock(return_value=request)
    return request
@pytest.mark.unittest
class TestCallbackChatTargetInitialization:
    """Test the initialization of _CallbackChatTarget."""

    def test_init(self, mock_callback):
        """Check the callback is stored and ``stream`` defaults to False but can be enabled."""
        target = _CallbackChatTarget(callback=mock_callback)
        # Identity, not equality: the target must keep the exact callback object.
        assert target._callback is mock_callback
        assert target._stream is False

        # Test with stream=True
        target_with_stream = _CallbackChatTarget(callback=mock_callback, stream=True)
        assert target_with_stream._stream is True
@pytest.mark.unittest
class TestCallbackChatTargetPrompts:
    """Test _CallbackChatTarget prompt handling."""

    @pytest.mark.asyncio
    async def test_send_prompt_async(self, chat_target, mock_request, mock_callback):
        """Test send_prompt_async method."""
        with patch.object(chat_target, "_memory") as mock_memory, patch(
            "azure.ai.evaluation.red_team._callback_chat_target.construct_response_from_request"
        ) as mock_construct:
            # No prior chat history for this conversation id.
            mock_memory.get_chat_messages_with_conversation_id.return_value = []
            # Whatever construct_response_from_request builds should be returned.
            mock_construct.return_value = mock_request

            response = await chat_target.send_prompt_async(prompt_request=mock_request)

            # The response built by construct_response_from_request is returned as-is.
            assert response is mock_request

            # Check that callback was called with correct parameters
            mock_callback.assert_called_once()
            call_kwargs = mock_callback.call_args.kwargs
            assert call_kwargs["stream"] is False
            assert call_kwargs["session_state"] is None
            assert call_kwargs["context"] is None

            # History is looked up by the request piece's conversation id.
            mock_memory.get_chat_messages_with_conversation_id.assert_called_once_with(
                conversation_id="test-id"
            )

    def test_validate_request_multiple_pieces(self, chat_target):
        """Test _validate_request with multiple request pieces."""
        mock_req = MagicMock()
        mock_req.request_pieces = [MagicMock(), MagicMock()]  # Two pieces
        with pytest.raises(ValueError, match="only supports a single prompt request piece"):
            chat_target._validate_request(prompt_request=mock_req)

    def test_validate_request_non_text_type(self, chat_target):
        """Test _validate_request with non-text data type."""
        mock_piece = MagicMock()
        mock_piece.converted_value_data_type = "image"  # Not text
        mock_req = MagicMock()
        mock_req.request_pieces = [mock_piece]
        with pytest.raises(ValueError, match="only supports text prompt input"):
            chat_target._validate_request(prompt_request=mock_req)
@pytest.mark.unittest
class TestCallbackChatTargetFeatures:
    """Feature-support checks for _CallbackChatTarget."""

    def test_is_json_response_supported(self, chat_target):
        """JSON response mode must be reported as unsupported."""
        json_supported = chat_target.is_json_response_supported()
        assert json_supported is False
|