File: test_sagemakerruntime.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (143 lines) | stat: -rw-r--r-- 4,884 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import json

import boto3
import requests

from moto import mock_aws, settings
from moto.s3.utils import bucket_and_name_from_url

# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html


@mock_aws
def test_invoke_endpoint__default_results():
    """Invoke an endpoint without registering any custom results first.

    The response is expected to carry the body ``b"body"`` and the custom
    attributes ``"custom_attributes"``.
    """
    runtime = boto3.client("sagemaker-runtime", region_name="ap-southeast-1")
    request_kwargs = {
        "EndpointName": "asdf",
        "Body": "qwer",
        "Accept": "sth",
        "TargetModel": "tm",
    }
    response = runtime.invoke_endpoint(**request_kwargs)

    payload = response["Body"].read()
    assert payload == b"body"
    assert response["CustomAttributes"] == "custom_attributes"


@mock_aws
def test_invoke_endpoint():
    """Register two custom results and check how ``invoke_endpoint`` serves them.

    The results are posted to the ``endpoint-results`` URL of the Moto API
    (``localhost:5000`` in server mode, ``motoapi.amazonaws.com`` otherwise).
    The test then expects:

    * the first call to return the first registered body,
    * an identical second call to return that same body again,
    * a call with different arguments to return the second registered body.
    """
    client = boto3.client("sagemaker-runtime", region_name="us-east-1")
    base_url = (
        "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
    )

    sagemaker_result = {
        "results": [
            {
                "Body": "first body",
                "ContentType": "text/xml",
                "InvokedProductionVariant": "prod",
                "CustomAttributes": "my_attr",
            },
            {"Body": "second body"},
        ]
    }
    setup_resp = requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
        json=sagemaker_result,
    )
    # Fail fast if the results could not be registered; otherwise the
    # assertions below would fail with a misleading body mismatch.
    setup_resp.raise_for_status()

    # Return the first item from the list
    body = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert body["Body"].read() == b"first body"

    # Same input -> same output
    body = client.invoke_endpoint(EndpointName="asdf", Body="qwer")
    assert body["Body"].read() == b"first body"

    # Different input -> second item
    body = client.invoke_endpoint(
        EndpointName="asdf", Body="qwer", Accept="sth", TargetModel="tm"
    )
    assert body["Body"].read() == b"second body"


@mock_aws
def test_invoke_endpoint_async():
    """Register async results and check what ``invoke_endpoint_async`` writes to S3.

    Two results are posted to the ``async-endpoint-results`` URL of the Moto
    API: a JSON document and an entry flagged with ``is_failure``. The test
    then expects:

    * repeating the same call to return the same output and failure locations,
    * the object at the first ``OutputLocation`` to hold the first JSON document,
    * a call with different arguments to echo its ``InferenceId`` and to store
      the failure text at its ``FailureLocation``.
    """
    client = boto3.client("sagemaker-runtime", region_name="us-east-1")
    base_url = (
        "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
    )

    sagemaker_result = {
        "results": [
            {"data": json.dumps({"first": "output"})},
            {
                "is_failure": True,
                "data": "second failure",
            },
        ]
    }
    setup_resp = requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/async-endpoint-results",
        json=sagemaker_result,
    )
    # Fail fast if the results could not be registered; otherwise the
    # assertions below would fail with a misleading mismatch.
    setup_resp.raise_for_status()

    # Return the first item from the list
    body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
    first_output_location = body["OutputLocation"]
    first_failure_location = body["FailureLocation"]

    # Same input -> same output
    body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
    assert body["OutputLocation"] == first_output_location
    assert body["FailureLocation"] == first_failure_location

    s3 = boto3.client("s3", region_name="us-east-1")
    bucket_name, obj = bucket_and_name_from_url(first_output_location)
    s3_object = s3.get_object(Bucket=bucket_name, Key=obj)
    first_output = json.loads(s3_object["Body"].read().decode("utf-8"))
    assert first_output == {"first": "output"}

    # Different input -> second item
    body = client.invoke_endpoint_async(
        EndpointName="asdf", InputLocation="asf", InferenceId="sth"
    )
    second_failure_location = body["FailureLocation"]
    assert body["InferenceId"] == "sth"

    bucket_name, obj = bucket_and_name_from_url(second_failure_location)
    s3_object = s3.get_object(Bucket=bucket_name, Key=obj)
    failure_text = s3_object["Body"].read().decode("utf-8")
    assert failure_text == "second failure"


@mock_aws
def test_invoke_endpoint_async_should_read_sync_queue_if_async_not_configured():
    """Check that async invocations use results registered for sync calls.

    Only the synchronous ``endpoint-results`` URL is populated here; nothing
    is posted to the async one. The test then expects:

    * repeating the same ``invoke_endpoint_async`` call to return the same
      output and failure locations,
    * the object at ``OutputLocation`` to be JSON whose ``"Body"`` key holds
      the registered text.
    """
    client = boto3.client("sagemaker-runtime", region_name="us-east-1")
    base_url = (
        "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com"
    )

    sagemaker_result = {
        "results": [
            {"Body": "support sync queue for backward compatibility"},
        ]
    }
    setup_resp = requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
        json=sagemaker_result,
    )
    # Fail fast if the results could not be registered; otherwise the
    # assertions below would fail with a misleading mismatch.
    setup_resp.raise_for_status()

    # Return the first item from the list
    body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
    first_output_location = body["OutputLocation"]
    first_failure_location = body["FailureLocation"]

    # Same input -> same output
    body = client.invoke_endpoint_async(EndpointName="asdf", InputLocation="qwer")
    assert body["OutputLocation"] == first_output_location
    assert body["FailureLocation"] == first_failure_location

    s3 = boto3.client("s3", region_name="us-east-1")
    bucket_name, obj = bucket_and_name_from_url(first_output_location)
    s3_object = s3.get_object(Bucket=bucket_name, Key=obj)
    stored_output = json.loads(s3_object["Body"].read().decode("utf-8"))
    assert stored_output["Body"] == "support sync queue for backward compatibility"