File: test_sagemaker_models.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (140 lines) | stat: -rw-r--r-- 4,663 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import re

import boto3
import pytest
from botocore.exceptions import ClientError

from moto import mock_aws
from moto.sagemaker.models import VpcConfig

# Region the mocked SageMaker client is created in.
TEST_REGION_NAME = "us-east-1"
# Default model name used by MySageMakerModel and the tests below.
TEST_MODEL_NAME = "MyModelName"
# Placeholder ARN passed as ExecutionRoleArn when models are created.
TEST_ARN = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"


@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    """Yield a SageMaker client whose calls are served by moto's mock."""
    with mock_aws():
        client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        yield client


class MySageMakerModel:
    """Test helper that registers a SageMaker model through ``create_model``.

    Every constructor argument is optional; falsy values fall back to the
    module-level test defaults, so most tests simply call
    ``MySageMakerModel().save(client)``.
    """

    def __init__(self, name=None, arn=None, container=None, vpc_config=None):
        self.name = name or TEST_MODEL_NAME
        # Passed to create_model as ExecutionRoleArn.
        self.arn = arn or TEST_ARN
        # Forwarded as PrimaryContainer by save() when non-empty.
        self.container = container or {}
        # Helper-specific shorthand keys; save() turns them into a VpcConfig.
        self.vpc_config = vpc_config or {"sg-groups": ["sg-123"], "subnets": ["123"]}

    def save(self, sagemaker_client):
        """Create this model via ``sagemaker_client`` and return the response.

        Args:
            sagemaker_client: a boto3 SageMaker client (mocked in these tests).

        Returns:
            The raw ``create_model`` response dict (contains ``ModelArn``).

        A non-empty ``container`` mapping is sent as ``PrimaryContainer``.
        The default empty mapping omits that parameter entirely.
        """
        vpc_config = VpcConfig(
            self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets")
        )
        request = {
            "ModelName": self.name,
            "ExecutionRoleArn": self.arn,
            "VpcConfig": vpc_config.response_object,
        }
        if self.container:
            request["PrimaryContainer"] = self.container
        return sagemaker_client.create_model(**request)


def test_describe_model(sagemaker_client):
    """describe_model returns the model registered by the helper."""
    MySageMakerModel().save(sagemaker_client)
    described = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME)
    assert described.get("ModelName") == TEST_MODEL_NAME


def test_describe_model_not_found(sagemaker_client):
    """Describing a model that was never created raises a ClientError."""
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.describe_model(ModelName="unknown")
    message = exc_info.value.response["Error"]["Message"]
    assert "Could not find model" in message


def test_create_model(sagemaker_client):
    """create_model returns an ARN ending in ``model/<ModelName>``."""
    vpc = VpcConfig(["sg-foobar"], ["subnet-xxx"])
    response = sagemaker_client.create_model(
        ModelName=TEST_MODEL_NAME,
        ExecutionRoleArn=TEST_ARN,
        VpcConfig=vpc.response_object,
    )
    arn_pattern = rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$"
    assert re.match(arn_pattern, response["ModelArn"])


def test_delete_model(sagemaker_client):
    """Deleting the only saved model leaves list_models empty."""
    MySageMakerModel().save(sagemaker_client)

    def model_count():
        return len(sagemaker_client.list_models()["Models"])

    assert model_count() == 1
    sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME)
    assert model_count() == 0


def test_delete_model_not_found(sagemaker_client):
    """Deleting an unknown model raises a ClientError with code "404"."""
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.delete_model(ModelName="blah")
    error = exc_info.value.response["Error"]
    assert error["Code"] == "404"


def test_list_models(sagemaker_client):
    """list_models reports the single saved model with its name and ARN."""
    MySageMakerModel().save(sagemaker_client)
    summaries = sagemaker_client.list_models()["Models"]
    assert len(summaries) == 1

    first = summaries[0]
    assert first["ModelName"] == TEST_MODEL_NAME
    arn_pattern = rf"^arn:aws:sagemaker:.*:.*:model/{TEST_MODEL_NAME}$"
    assert re.match(arn_pattern, first["ModelArn"])


def test_list_models_multiple(sagemaker_client):
    """list_models returns every saved model, identified by name."""
    # Model name -> ExecutionRoleArn used when saving it.
    expected = {
        "blah": "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar",
        "blah2": "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2",
    }
    for name, role_arn in expected.items():
        MySageMakerModel(name=name, arn=role_arn).save(sagemaker_client)

    models = sagemaker_client.list_models()["Models"]
    assert len(models) == 2
    # A count alone would pass with duplicated or wrong entries; check names too.
    assert sorted(model["ModelName"] for model in models) == sorted(expected)


def test_list_models_none(sagemaker_client):
    """With nothing created, list_models returns no models."""
    listed = sagemaker_client.list_models()["Models"]
    assert not listed


def test_add_tags_to_model(sagemaker_client):
    """Tags added to a model are reported back by list_tags."""
    arn = MySageMakerModel().save(sagemaker_client)["ModelArn"]
    tags = [{"Key": "myKey", "Value": "myValue"}]

    added = sagemaker_client.add_tags(ResourceArn=arn, Tags=tags)
    assert added["ResponseMetadata"]["HTTPStatusCode"] == 200

    listed = sagemaker_client.list_tags(ResourceArn=arn)
    assert listed["Tags"] == tags


def test_delete_tags_from_model(sagemaker_client):
    """Deleting every tag key leaves the model with an empty tag list."""
    arn = MySageMakerModel().save(sagemaker_client)["ModelArn"]
    tags = [{"Key": "myKey", "Value": "myValue"}]

    added = sagemaker_client.add_tags(ResourceArn=arn, Tags=tags)
    assert added["ResponseMetadata"]["HTTPStatusCode"] == 200

    deleted = sagemaker_client.delete_tags(
        ResourceArn=arn, TagKeys=[tag["Key"] for tag in tags]
    )
    assert deleted["ResponseMetadata"]["HTTPStatusCode"] == 200

    listed = sagemaker_client.list_tags(ResourceArn=arn)
    assert listed["Tags"] == []