1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
import re
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from moto.sagemaker.models import VpcConfig
TEST_REGION_NAME = "us-east-1"
TEST_ARN = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
TEST_MODEL_NAME = "MyModelName"
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
with mock_aws():
yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MySageMakerModel:
def __init__(self, name=None, arn=None, container=None, vpc_config=None):
self.name = name or TEST_MODEL_NAME
self.arn = arn or TEST_ARN
self.container = container or {}
self.vpc_config = vpc_config or {"sg-groups": ["sg-123"], "subnets": ["123"]}
def save(self, sagemaker_client):
vpc_config = VpcConfig(
self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets")
)
resp = sagemaker_client.create_model(
ModelName=self.name,
ExecutionRoleArn=self.arn,
VpcConfig=vpc_config.response_object,
)
return resp
def test_describe_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
model = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME)
assert model.get("ModelName") == TEST_MODEL_NAME
def test_describe_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.describe_model(ModelName="unknown")
assert "Could not find model" in err.value.response["Error"]["Message"]
def test_create_model(sagemaker_client):
vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"])
model = sagemaker_client.create_model(
ModelName=TEST_MODEL_NAME,
ExecutionRoleArn=TEST_ARN,
VpcConfig=vpc_config.response_object,
)
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$", model["ModelArn"]
)
def test_delete_model(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
assert len(sagemaker_client.list_models()["Models"]) == 1
sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME)
assert len(sagemaker_client.list_models()["Models"]) == 0
def test_delete_model_not_found(sagemaker_client):
with pytest.raises(ClientError) as err:
sagemaker_client.delete_model(ModelName="blah")
assert err.value.response["Error"]["Code"] == "404"
def test_list_models(sagemaker_client):
test_model = MySageMakerModel()
test_model.save(sagemaker_client)
models = sagemaker_client.list_models()
assert len(models["Models"]) == 1
assert models["Models"][0]["ModelName"] == TEST_MODEL_NAME
assert re.match(
rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$",
models["Models"][0]["ModelArn"],
)
def test_list_models_multiple(sagemaker_client):
name_model_1 = "blah"
arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
test_model_1 = MySageMakerModel(name=name_model_1, arn=arn_model_1)
test_model_1.save(sagemaker_client)
name_model_2 = "blah2"
arn_model_2 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2"
test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2)
test_model_2.save(sagemaker_client)
models = sagemaker_client.list_models()
assert len(models["Models"]) == 2
def test_list_models_none(sagemaker_client):
models = sagemaker_client.list_models()
assert len(models["Models"]) == 0
def test_add_tags_to_model(sagemaker_client):
model = MySageMakerModel().save(sagemaker_client)
resource_arn = model["ModelArn"]
tags = [
{"Key": "myKey", "Value": "myValue"},
]
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == tags
def test_delete_tags_from_model(sagemaker_client):
model = MySageMakerModel().save(sagemaker_client)
resource_arn = model["ModelArn"]
tags = [
{"Key": "myKey", "Value": "myValue"},
]
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
tag_keys = [tag["Key"] for tag in tags]
response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == []
|