File: responses.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (62 lines) | stat: -rw-r--r-- 2,444 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import base64
import json

from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse
from moto.moto_api._internal import mock_random as random

from .models import SageMakerRuntimeBackend, sagemakerruntime_backends


class SageMakerRuntimeResponse(BaseResponse):
    """Turn SageMaker Runtime HTTP requests into backend calls and wrap the
    backend's results into ``(status, headers, body)`` responses."""

    def __init__(self) -> None:
        super().__init__(service_name="sagemaker-runtime")

    @property
    def sagemakerruntime_backend(self) -> SageMakerRuntimeBackend:
        """Backend instance for the current account and region."""
        per_account = sagemakerruntime_backends[self.current_account]
        return per_account[self.region]

    def invoke_endpoint(self) -> TYPE_RESPONSE:
        """Handle a synchronous InvokeEndpoint request.

        A fingerprint of the request is built from every header whose name
        starts with ``x-amzn-sagemaker`` (case-insensitive), followed by the
        ``Accept`` header and the request body. It is JSON-serialised,
        UTF-8 encoded and base64-encoded, then handed to the backend as
        ``unique_repr`` together with the endpoint name.

        Returns status 200 with a ``Content-Type`` header. The
        production-variant and custom-attributes headers are added only when
        the backend returned a truthy value for them.
        """
        request_params = self._get_params()

        # Header order is preserved because the JSON text (and therefore the
        # key passed to the backend) depends on insertion order.
        fingerprint = {}
        for header_name, header_value in self.headers.items():
            if header_name.lower().startswith("x-amzn-sagemaker"):
                fingerprint[header_name] = header_value
        fingerprint["Accept"] = self.headers.get("Accept")
        fingerprint["Body"] = self.body

        endpoint_name = request_params.get("EndpointName")
        backend = self.sagemakerruntime_backend
        encoded_fingerprint = base64.b64encode(
            json.dumps(fingerprint).encode("utf-8")
        )
        body, content_type, variant, custom_attributes = backend.invoke_endpoint(
            endpoint_name=endpoint_name,  # type: ignore[arg-type]
            unique_repr=encoded_fingerprint,
        )

        response_headers = {"Content-Type": content_type}
        optional_headers = (
            ("x-Amzn-Invoked-Production-Variant", variant),
            ("X-Amzn-SageMaker-Custom-Attributes", custom_attributes),
        )
        for header_name, header_value in optional_headers:
            if header_value:
                response_headers[header_name] = header_value
        return 200, response_headers, body

    def invoke_endpoint_async(self) -> TYPE_RESPONSE:
        """Handle an InvokeEndpointAsync request.

        The endpoint name is the third ``/``-separated segment of the request
        path. The input location is read from the
        ``X-Amzn-SageMaker-InputLocation`` header. The output and failure
        locations returned by the backend are sent back as response headers.
        The JSON body carries ``InferenceId``: the caller's
        ``X-Amzn-SageMaker-Inference-Id`` header, or a generated UUID when
        that header is absent or empty.
        """
        path_segments = self.path.split("/")
        endpoint_name = path_segments[2]
        request_headers = self.headers
        input_location = request_headers.get("X-Amzn-SageMaker-InputLocation")
        inference_id = request_headers.get("X-Amzn-SageMaker-Inference-Id")

        locations = self.sagemakerruntime_backend.invoke_endpoint_async(
            endpoint_name, input_location
        )
        output_location, failure_location = locations

        # The fallback id is drawn only after the backend call.
        if not inference_id:
            inference_id = str(random.uuid4())

        response_headers = {
            "X-Amzn-SageMaker-OutputLocation": output_location,
            "X-Amzn-SageMaker-FailureLocation": failure_location,
        }
        payload = json.dumps({"InferenceId": inference_id})
        return 200, response_headers, payload