File: test_callback_chat_target.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (133 lines) | stat: -rw-r--r-- 4,800 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Unit tests for callback_chat_target module.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import asyncio

from pyrit.common import initialize_pyrit, IN_MEMORY

from azure.ai.evaluation.red_team._callback_chat_target import _CallbackChatTarget

# Initialize PyRIT once, at import time, with the IN_MEMORY database type so it
# is configured before any fixture below constructs a _CallbackChatTarget.
# NOTE(review): presumably required because the target relies on PyRIT's memory
# backend (test_send_prompt_async patches chat_target._memory) — confirm.
initialize_pyrit(memory_db_type=IN_MEMORY)

@pytest.fixture(scope="function")
def mock_callback():
    """Provide an async callback mock that replies with a canned chat transcript.

    Awaiting the mock yields a dict holding a user prompt and an assistant reply
    under ``"messages"``, together with ``stream``, ``session_state`` and
    ``context`` entries.
    """
    transcript = [
        {"role": "user", "content": "test prompt"},
        {"role": "assistant", "content": "test response"},
    ]
    canned_reply = {
        "messages": transcript,
        "stream": False,
        "session_state": None,
        "context": {},
    }
    return AsyncMock(return_value=canned_reply)


@pytest.fixture(scope="function")
def chat_target(mock_callback):
    """Build the _CallbackChatTarget under test, wired to the ``mock_callback`` fixture."""
    target = _CallbackChatTarget(callback=mock_callback)
    return target


@pytest.fixture(scope="function")
def mock_request():
    """Create a mocked request object that mimics PromptRequestResponse from pyrit.

    The mock holds exactly one request piece in conversation ``"test-id"`` with
    text value ``"test prompt"``. Its ``to_chat_message()`` returns a mock with
    ``role="user"`` and ``content="test prompt"``. The request has no response
    pieces.
    """
    request_piece = MagicMock()
    request_piece.conversation_id = "test-id"
    request_piece.converted_value = "test prompt"
    request_piece.converted_value_data_type = "text"
    request_piece.to_chat_message.return_value = MagicMock(
        role="user", content="test prompt"
    )

    request = MagicMock()
    request.request_pieces = [request_piece]
    request.response_pieces = []

    # Mock the constructor pattern used by _CallbackChatTarget: from_response
    # hands back this same request mock.
    # NOTE(review): unclear whether the target still calls from_response — confirm.
    request.from_response = MagicMock(return_value=request)
    return request


@pytest.mark.unittest
class TestCallbackChatTargetInitialization:
    """Checks how _CallbackChatTarget stores its constructor arguments."""

    def test_init(self, mock_callback):
        """Verify the stored callback and the ``stream`` flag, both default and explicit."""
        default_target = _CallbackChatTarget(callback=mock_callback)
        assert default_target._callback == mock_callback
        # Without a stream argument the target must not stream.
        assert default_target._stream is False

        # An explicit stream=True must be reflected on the instance.
        streaming_target = _CallbackChatTarget(callback=mock_callback, stream=True)
        assert streaming_target._stream is True


@pytest.mark.unittest
class TestCallbackChatTargetPrompts:
    """Test _CallbackChatTarget prompt handling."""

    @pytest.mark.asyncio
    async def test_send_prompt_async(self, chat_target, mock_request, mock_callback):
        """Test send_prompt_async method.

        Memory and ``construct_response_from_request`` are patched, so only the
        target's orchestration is exercised. The test expects the target to:
        - read history for the request's conversation,
        - invoke the callback once with the expected keyword arguments,
        - return the object built by ``construct_response_from_request``.
        """
        with patch.object(chat_target, "_memory") as mock_memory, \
            patch("azure.ai.evaluation.red_team._callback_chat_target.construct_response_from_request") as mock_construct:
            # Setup memory mock: no prior conversation history.
            mock_memory.get_chat_messages_with_conversation_id.return_value = []

            # Setup construct_response mock; its return value is what the
            # target is expected to hand back to the caller.
            mock_construct.return_value = mock_request

            # Call the method
            response = await chat_target.send_prompt_async(prompt_request=mock_request)

            # The returned value was previously never checked: make sure the
            # target returns the response it constructed.
            mock_construct.assert_called_once()
            assert response is mock_construct.return_value

            # Check that callback was called with correct parameters
            mock_callback.assert_called_once()
            call_kwargs = mock_callback.call_args.kwargs
            assert call_kwargs["stream"] is False
            assert call_kwargs["session_state"] is None
            assert call_kwargs["context"] is None

            # Check memory usage
            mock_memory.get_chat_messages_with_conversation_id.assert_called_once_with(
                conversation_id="test-id"
            )

    def test_validate_request_multiple_pieces(self, chat_target):
        """Test _validate_request with multiple request pieces."""
        mock_req = MagicMock()
        mock_req.request_pieces = [MagicMock(), MagicMock()]  # Two pieces

        # ``match`` is applied with re.search; the literal has no regex metacharacters.
        with pytest.raises(ValueError, match="only supports a single prompt request piece"):
            chat_target._validate_request(prompt_request=mock_req)

    def test_validate_request_non_text_type(self, chat_target):
        """Test _validate_request with a single piece whose data type is not text."""
        mock_req = MagicMock()
        mock_piece = MagicMock()
        mock_piece.converted_value_data_type = "image"  # Not text
        mock_req.request_pieces = [mock_piece]

        # ``match`` is applied with re.search; the literal has no regex metacharacters.
        with pytest.raises(ValueError, match="only supports text prompt input"):
            chat_target._validate_request(prompt_request=mock_req)


@pytest.mark.unittest
class TestCallbackChatTargetFeatures:
    """Capability flags reported by _CallbackChatTarget."""

    def test_is_json_response_supported(self, chat_target):
        """The target must report that JSON-formatted responses are unsupported."""
        supported = chat_target.is_json_response_supported()
        assert supported is False