File: test_models.py

package info (click to toggle)
python-azure 20250829+git-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 756,824 kB
  • sloc: python: 6,224,989; ansic: 804; javascript: 287; makefile: 198; sh: 195; xml: 109
file content (169 lines) | stat: -rw-r--r-- 7,153 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import pytest
import azure.ai.agents.models as _models


class TestModels:
    """Unit tests for ``CodeInterpreterTool`` from ``azure.ai.agents.models``.

    The tool is configured with either file IDs or vector-store data sources;
    supplying both raises ``ValueError``. These tests cover construction,
    addition and removal for both configurations and inspect the resulting
    ``resources.code_interpreter`` payload.
    """

    # Shared inputs. No test mutates these lists directly; each test builds its
    # own CodeInterpreterTool from them.
    _FILES = ["file1", "file2"]
    # Asset identifier of a data source that is NOT part of _DATA_SOURCES.
    _NEW_IDENTIFIER = "azureai://789"
    # Class-level names are not visible inside a class-body comprehension, so the
    # asset type is reached through the module-level ``_models`` import.
    _DATA_SOURCES = [
        _models.VectorStoreDataSource(
            asset_identifier=identifier,
            asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET,
        )
        for identifier in ("azureai://123", "azureai://456")
    ]

    def _assert_data_sources_are_the_same(self, one, other):
        """Assert that ``one`` and ``other`` hold data sources with the same asset identifiers.

        Order is ignored and only ``asset_identifier`` is compared.
        """
        assert sorted(ds.asset_identifier for ds in one) == sorted(ds.asset_identifier for ds in other)

    @pytest.mark.parametrize(
        ("files", "ds"),
        [
            (_FILES, None),
            (_FILES, []),
            (None, _DATA_SOURCES),
            ([], _DATA_SOURCES),
            ([], []),
            ([], None),
            (None, []),
            (None, None),
        ],
    )
    def test_create_code_interpreter_tool(self, files, ds):
        """Test the ToolResources built for each combination of file IDs and data sources.

        No parameter set supplies both non-empty inputs, so exactly one branch applies.
        """
        resource = _models.CodeInterpreterTool(file_ids=files, data_sources=ds).resources.code_interpreter
        if not files and not ds:
            # Nothing configured: there must be no code interpreter resource at all.
            assert resource is None
        elif files:
            assert sorted(resource.file_ids) == sorted(files)
            assert resource.data_sources is None
        else:
            assert resource.file_ids is None
            self._assert_data_sources_are_the_same(resource.data_sources, ds)

    def test_assert_code_interpreter_tool_raises(self):
        """Test that supplying both file IDs and data sources raises ValueError."""
        with pytest.raises(ValueError) as excinfo:
            _models.CodeInterpreterTool(file_ids=self._FILES, data_sources=self._DATA_SOURCES)
        assert excinfo.value.args[0] == _models.CodeInterpreterTool._INVALID_CONFIGURATION

    @pytest.mark.parametrize(
        ("file_id", "expected"),
        [
            # Re-adding a known file ID leaves the file IDs unchanged.
            (_FILES[0], _FILES),
            # An unknown file ID is included alongside the existing ones.
            ("file3", _FILES + ["file3"]),
        ],
    )
    def test_add_files(self, file_id, expected):
        """Test adding a file ID to a code interpreter tool configured with file IDs."""
        tool = _models.CodeInterpreterTool(file_ids=self._FILES)
        tool.add_file(file_id)
        resource = tool.resources.code_interpreter
        assert resource is not None
        assert sorted(resource.file_ids) == sorted(expected)

    @pytest.mark.parametrize(
        ("ds", "expected"),
        [
            # Re-adding a known data source leaves the data sources unchanged.
            (_DATA_SOURCES[0], _DATA_SOURCES),
            # A new data source is included alongside the existing ones.
            (
                _models.VectorStoreDataSource(
                    asset_identifier=_NEW_IDENTIFIER, asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
                ),
                _DATA_SOURCES
                + [
                    _models.VectorStoreDataSource(
                        asset_identifier=_NEW_IDENTIFIER, asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
                    )
                ],
            ),
        ],
    )
    def test_add_data_sources(self, ds, expected):
        """Test adding a data source to a code interpreter tool configured with data sources."""
        tool = _models.CodeInterpreterTool(data_sources=self._DATA_SOURCES)
        tool.add_data_source(ds)
        resource = tool.resources.code_interpreter
        assert resource is not None
        self._assert_data_sources_are_the_same(resource.data_sources, expected)

    def test_add_files_raises(self):
        """Test that adding a file to a tool that already has data sources raises ValueError."""
        tool = _models.CodeInterpreterTool(data_sources=self._DATA_SOURCES)
        with pytest.raises(ValueError) as excinfo:
            tool.add_file("123")
        assert excinfo.value.args[0] == _models.CodeInterpreterTool._INVALID_CONFIGURATION

    def test_add_data_source_raises(self):
        """Test that adding a data source to a tool that already has file IDs raises ValueError."""
        tool = _models.CodeInterpreterTool(file_ids=self._FILES)
        new_source = _models.VectorStoreDataSource(
            asset_identifier=self._NEW_IDENTIFIER, asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
        )
        with pytest.raises(ValueError) as excinfo:
            tool.add_data_source(new_source)
        assert excinfo.value.args[0] == _models.CodeInterpreterTool._INVALID_CONFIGURATION

    @pytest.mark.parametrize(
        ("file_id", "expected"),
        [
            # Removing a known file ID drops only that ID.
            (_FILES[0], [_FILES[1]]),
            # Removing an unknown file ID leaves the file IDs unchanged.
            ("file3", _FILES),
        ],
    )
    # NOTE: the "fie" typo in this name is kept so existing pytest IDs / -k selections stay valid.
    def test_remove_fie_id(self, file_id, expected):
        """Test removal of a single file ID."""
        tool = _models.CodeInterpreterTool(file_ids=self._FILES)
        tool.remove_file(file_id)
        resource = tool.resources.code_interpreter
        assert resource is not None
        assert sorted(resource.file_ids) == sorted(expected)

    def test_remove_all_ids(self):
        """Test that removing every file ID leaves no code interpreter resource."""
        tool = _models.CodeInterpreterTool(file_ids=self._FILES)
        for file_id in self._FILES:
            tool.remove_file(file_id)
        assert tool.resources.code_interpreter is None

    @pytest.mark.parametrize(
        ("ds", "expected"),
        [
            # Removing a known data source drops only that source.
            (_DATA_SOURCES[0], [_DATA_SOURCES[1]]),
            # Removing an unknown data source leaves the data sources unchanged.
            (
                _models.VectorStoreDataSource(
                    asset_identifier=_NEW_IDENTIFIER, asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
                ),
                _DATA_SOURCES,
            ),
        ],
    )
    # NOTE: the "remode" typo in this name is kept so existing pytest IDs / -k selections stay valid.
    def test_remode_data_source(self, ds, expected):
        """Test removal of a single data source by its asset identifier."""
        tool = _models.CodeInterpreterTool(data_sources=self._DATA_SOURCES)
        tool.remove_data_source(ds.asset_identifier)
        resource = tool.resources.code_interpreter
        assert resource is not None
        self._assert_data_sources_are_the_same(resource.data_sources, expected)

    def test_remove_all_data_sources(self):
        """Test that removing every data source leaves no code interpreter resource."""
        tool = _models.CodeInterpreterTool(data_sources=self._DATA_SOURCES)
        for identifier in [ds.asset_identifier for ds in self._DATA_SOURCES]:
            tool.remove_data_source(identifier)
        assert tool.resources.code_interpreter is None