1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import pytest
import azure.ai.agents.models as _models
class TestModels:
"""Unit tests for models."""
_FILES = ["file1", "file2"]
_DATA_SOURCES = [
_models.VectorStoreDataSource(
asset_identifier="azureai:/123", asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
),
_models.VectorStoreDataSource(
asset_identifier="azureai:/456", asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
),
]
def _assert_data_sources_are_the_same(self, one, other):
"""Convenience method to compare lists."""
one_list = [ds.asset_identifier for ds in one]
other_list = [ds.asset_identifier for ds in other]
assert sorted(one_list) == sorted(other_list)
@pytest.mark.parametrize(
"files,ds",
[
(_FILES, None),
(_FILES, []),
(None, _DATA_SOURCES),
([], _DATA_SOURCES),
([], []),
([], None),
(None, []),
(None, None),
],
)
def test_create_code_interpreter_tool(self, files, ds):
"""Test The created ToolResources."""
code_interpreter = _models.CodeInterpreterTool(file_ids=files, data_sources=ds)
tool_resources = code_interpreter.resources
if not files and not ds:
assert tool_resources.code_interpreter is None
if files:
assert sorted(tool_resources.code_interpreter.file_ids) == sorted(files)
assert tool_resources.code_interpreter.data_sources is None
if ds:
assert tool_resources.code_interpreter.file_ids is None
self._assert_data_sources_are_the_same(tool_resources.code_interpreter.data_sources, ds)
def test_assert_code_interpreter_tool_raises(self):
"""Test that if both file_ids and data sources are provided the error is raised."""
with pytest.raises(ValueError) as e:
_models.CodeInterpreterTool(file_ids=TestModels._FILES, data_sources=TestModels._DATA_SOURCES)
assert _models.CodeInterpreterTool._INVALID_CONFIGURATION == e.value.args[0]
@pytest.mark.parametrize(
"file_id,expected",
[
(_FILES[0], _FILES),
("file3", _FILES + ["file3"]),
],
)
def test_add_files(self, file_id, expected):
"""Test addition of files to code interpreter tool."""
code_interpreter = _models.CodeInterpreterTool(file_ids=TestModels._FILES)
code_interpreter.add_file(file_id)
tool_resources = code_interpreter.resources
assert tool_resources.code_interpreter is not None
assert sorted(tool_resources.code_interpreter.file_ids) == sorted(expected)
@pytest.mark.parametrize(
"ds,expected",
[
(_DATA_SOURCES[0], _DATA_SOURCES),
(
_models.VectorStoreDataSource(
asset_identifier="azureai://789", asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
),
_DATA_SOURCES
+ [
_models.VectorStoreDataSource(
asset_identifier="azureai://789", asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
)
],
),
],
)
def test_add_data_sources(self, ds, expected):
"""Test addition of data sources."""
code_interpreter = _models.CodeInterpreterTool(data_sources=TestModels._DATA_SOURCES)
code_interpreter.add_data_source(ds)
tool_resources = code_interpreter.resources
assert tool_resources.code_interpreter is not None
self._assert_data_sources_are_the_same(tool_resources.code_interpreter.data_sources, expected)
def test_add_files_raises(self):
"""Test that addition of file to the CodeInterpreter with existing data sources raises the exception."""
code_interpreter = _models.CodeInterpreterTool(data_sources=TestModels._DATA_SOURCES)
with pytest.raises(ValueError) as e:
code_interpreter.add_file("123")
assert _models.CodeInterpreterTool._INVALID_CONFIGURATION == e.value.args[0]
def test_add_data_source_raises(self):
"""Test that addition of a data source to CodeInterpreter with file IDs raises the exception."""
code_interpreter = _models.CodeInterpreterTool(file_ids=TestModels._FILES)
with pytest.raises(ValueError) as e:
code_interpreter.add_data_source(
_models.VectorStoreDataSource(
asset_identifier="azureai://789", asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
)
)
assert _models.CodeInterpreterTool._INVALID_CONFIGURATION == e.value.args[0]
@pytest.mark.parametrize(
"file_id,expected",
[
(_FILES[0], [_FILES[1]]),
("file3", _FILES),
],
)
def test_remove_fie_id(self, file_id, expected):
"""Test removal of a file ID."""
code_interpreter = _models.CodeInterpreterTool(file_ids=TestModels._FILES)
code_interpreter.remove_file(file_id)
tool_resources = code_interpreter.resources
assert tool_resources.code_interpreter is not None
assert sorted(tool_resources.code_interpreter.file_ids) == sorted(expected)
def test_remove_all_ids(
self,
):
"""Test removal of all file IDs."""
code_interpreter = _models.CodeInterpreterTool(file_ids=TestModels._FILES)
for file_id in TestModels._FILES:
code_interpreter.remove_file(file_id)
tool_resources = code_interpreter.resources
assert tool_resources.code_interpreter is None
@pytest.mark.parametrize(
"ds,expected",
[
(_DATA_SOURCES[0], [_DATA_SOURCES[1]]),
(
_models.VectorStoreDataSource(
asset_identifier="azureai://789", asset_type=_models.VectorStoreDataSourceAssetType.URI_ASSET
),
_DATA_SOURCES,
),
],
)
def test_remode_data_source(self, ds, expected):
"""Test removal of a data source."""
code_interpreter = _models.CodeInterpreterTool(data_sources=TestModels._DATA_SOURCES)
code_interpreter.remove_data_source(ds.asset_identifier)
tool_resources = code_interpreter.resources
assert tool_resources.code_interpreter is not None
self._assert_data_sources_are_the_same(tool_resources.code_interpreter.data_sources, expected)
def test_remove_all_data_sources(self):
"""Test removal of all data sources."""
code_interpreter = _models.CodeInterpreterTool(data_sources=TestModels._DATA_SOURCES)
for ds in TestModels._DATA_SOURCES:
code_interpreter.remove_data_source(ds.asset_identifier)
tool_resources = code_interpreter.resources
assert tool_resources.code_interpreter is None
|