1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
import json
import boto3
import requests
from moto import mock_aws, settings
from moto.s3.utils import bucket_and_name_from_url
# See our Development Tips on writing tests for hints on how to write good tests:
# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
@mock_aws
def test_invoke_endpoint__default_results():
    """Invoke an endpoint without queueing any results first.

    Nothing is posted to the moto API beforehand, so the assertions pin the
    body and custom attributes that come back by default.
    """
    runtime = boto3.client("sagemaker-runtime", region_name="ap-southeast-1")
    invoke_kwargs = {
        "EndpointName": "asdf",
        "Body": "qwer",
        "Accept": "sth",
        "TargetModel": "tm",
    }
    response = runtime.invoke_endpoint(**invoke_kwargs)

    payload = response["Body"].read()
    assert payload == b"body"
    assert response["CustomAttributes"] == "custom_attributes"
@mock_aws
def test_invoke_endpoint():
    """Queue two custom results and check how invoke_endpoint hands them out.

    The results are posted to the endpoint-results URL of the moto API.
    Repeating an identical call returns the same result, while a call with
    different arguments returns the next queued one.
    """
    runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

    # Server mode talks to a local moto server instead of the in-process API host
    if settings.TEST_SERVER_MODE:
        base_url = "localhost:5000"
    else:
        base_url = "motoapi.amazonaws.com"

    first_result = {
        "Body": "first body",
        "ContentType": "text/xml",
        "InvokedProductionVariant": "prod",
        "CustomAttributes": "my_attr",
    }
    second_result = {"Body": "second body"}
    requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
        json={"results": [first_result, second_result]},
    )

    # The first call receives the first queued result
    simple_call = {"EndpointName": "asdf", "Body": "qwer"}
    response = runtime.invoke_endpoint(**simple_call)
    assert response["Body"].read() == b"first body"

    # An identical call receives that same result again
    response = runtime.invoke_endpoint(**simple_call)
    assert response["Body"].read() == b"first body"

    # A call with different arguments receives the second queued result
    response = runtime.invoke_endpoint(**simple_call, Accept="sth", TargetModel="tm")
    assert response["Body"].read() == b"second body"
@mock_aws
def test_invoke_endpoint_async():
    """Queue async results and check the S3 objects invoke_endpoint_async points to.

    One regular result and one result flagged as a failure are posted to the
    async-endpoint-results URL of the moto API. The test then reads the
    objects at the returned output / failure locations and compares their
    contents with the queued data.
    """
    runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

    # Server mode talks to a local moto server instead of the in-process API host
    if settings.TEST_SERVER_MODE:
        base_url = "localhost:5000"
    else:
        base_url = "motoapi.amazonaws.com"

    success_result = {"data": json.dumps({"first": "output"})}
    failure_result = {"is_failure": True, "data": "second failure"}
    requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/async-endpoint-results",
        json={"results": [success_result, failure_result]},
    )

    # The first call receives the first queued result
    first_call = {"EndpointName": "asdf", "InputLocation": "qwer"}
    response = runtime.invoke_endpoint_async(**first_call)
    first_output_location = response["OutputLocation"]
    first_failure_location = response["FailureLocation"]

    # An identical call reports the same locations
    response = runtime.invoke_endpoint_async(**first_call)
    assert response["OutputLocation"] == first_output_location
    assert response["FailureLocation"] == first_failure_location

    # The object at the output location holds the queued JSON document
    s3 = boto3.client("s3", "us-east-1")
    bucket_name, key = bucket_and_name_from_url(first_output_location)
    stored = s3.get_object(Bucket=bucket_name, Key=key)
    raw_output = stored["Body"].read().decode("utf-8")
    assert json.loads(raw_output) == {"first": "output"}

    # A call with different arguments receives the second (failure) result
    response = runtime.invoke_endpoint_async(
        EndpointName="asdf", InputLocation="asf", InferenceId="sth"
    )
    second_failure_location = response["FailureLocation"]
    assert response["InferenceId"] == "sth"

    # The object at the failure location holds the raw failure text
    bucket_name, key = bucket_and_name_from_url(second_failure_location)
    stored = s3.get_object(Bucket=bucket_name, Key=key)
    failure_text = stored["Body"].read().decode("utf-8")
    assert failure_text == "second failure"
@mock_aws
def test_invoke_endpoint_async_should_read_sync_queue_if_async_not_configured():
    """Check that invoke_endpoint_async serves results queued for sync calls.

    Only the (synchronous) endpoint-results URL of the moto API is populated.
    The async call must still report an output location whose S3 object
    carries the queued result.
    """
    runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

    # Server mode talks to a local moto server instead of the in-process API host
    if settings.TEST_SERVER_MODE:
        base_url = "localhost:5000"
    else:
        base_url = "motoapi.amazonaws.com"

    queued_result = {"Body": "support sync queue for backward compatibility"}
    requests.post(
        f"http://{base_url}/moto-api/static/sagemaker/endpoint-results",
        json={"results": [queued_result]},
    )

    # The first call receives the first queued result
    call_kwargs = {"EndpointName": "asdf", "InputLocation": "qwer"}
    response = runtime.invoke_endpoint_async(**call_kwargs)
    first_output_location = response["OutputLocation"]
    first_failure_location = response["FailureLocation"]

    # An identical call reports the same locations
    response = runtime.invoke_endpoint_async(**call_kwargs)
    assert response["OutputLocation"] == first_output_location
    assert response["FailureLocation"] == first_failure_location

    # The object at the output location is a JSON document holding the queued body
    s3 = boto3.client("s3", "us-east-1")
    bucket_name, key = bucket_and_name_from_url(first_output_location)
    stored = s3.get_object(Bucket=bucket_name, Key=key)
    document = json.loads(stored["Body"].read().decode("utf-8"))
    assert document["Body"] == "support sync queue for backward compatibility"
|