File: utils.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (135 lines) | stat: -rw-r--r-- 4,578 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
import typing
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional

from dateutil.tz import tzutc

from moto.s3.models import s3_backends
from moto.utilities.utils import get_partition

from .exceptions import ValidationError

if typing.TYPE_CHECKING:
    from .models import FakeModelCard, FakePipeline, FakePipelineExecution


def get_pipeline_from_name(
    pipelines: dict[str, "FakePipeline"], pipeline_name: str
) -> "FakePipeline":
    """Return the pipeline stored under ``pipeline_name``.

    Args:
        pipelines: Mapping of pipeline name to pipeline object.
        pipeline_name: Name to look up in ``pipelines``.

    Raises:
        ValidationError: If the lookup in ``pipelines`` raises ``KeyError``.
    """
    try:
        found = pipelines[pipeline_name]
    except KeyError:
        raise ValidationError(
            message=f"Could not find pipeline with PipelineName {pipeline_name}."
        )
    else:
        return found


def get_pipeline_name_from_execution_arn(pipeline_execution_arn: str) -> str:
    """Extract the pipeline name from a pipeline execution ARN.

    The name is the second "/"-separated segment of the ARN, with anything up
    to and including its last ":" removed.

    Raises:
        IndexError: If the ARN contains no "/" separator.
    """
    arn_segments = pipeline_execution_arn.split("/")
    name_segment = arn_segments[1]
    # Keep only the text after the last ":" (the whole segment if there is none).
    _, _, pipeline_name = name_segment.rpartition(":")
    return pipeline_name


def get_pipeline_execution_from_arn(
    pipelines: dict[str, "FakePipeline"], pipeline_execution_arn: str
) -> "FakePipelineExecution":
    """Return the pipeline execution identified by ``pipeline_execution_arn``.

    The pipeline name is parsed out of the ARN, the pipeline is resolved via
    ``get_pipeline_from_name`` and the execution is then looked up in that
    pipeline's ``pipeline_executions`` mapping, keyed by the full ARN.

    Args:
        pipelines: Mapping of pipeline name to pipeline object.
        pipeline_execution_arn: ARN of the execution to find.

    Raises:
        ValidationError: If the ARN is malformed (no "/" separator) or the
            execution is not registered on the resolved pipeline.
    """
    try:
        pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn)
        pipeline = get_pipeline_from_name(pipelines, pipeline_name)
        return pipeline.pipeline_executions[pipeline_execution_arn]
    # IndexError comes from parsing an ARN without a "/" separator; report it
    # as a validation failure instead of leaking an internal error.
    except (KeyError, IndexError):
        raise ValidationError(
            message=f"Could not find pipeline execution with PipelineExecutionArn {pipeline_execution_arn}."
        )


def load_pipeline_definition_from_s3(
    pipeline_definition_s3_location: dict[str, Any], account_id: str, partition: str
) -> dict[str, Any]:
    """Fetch a pipeline definition from the mocked S3 backend and parse it as JSON.

    Args:
        pipeline_definition_s3_location: Mapping read for its ``"Bucket"`` and
            ``"ObjectKey"`` entries.
        account_id: First-level key into ``s3_backends``.
        partition: Second-level key into ``s3_backends``.

    Returns:
        The result of ``json.loads`` on the fetched object's ``value``.
    """
    backend = s3_backends[account_id][partition]
    bucket = pipeline_definition_s3_location["Bucket"]
    object_key = pipeline_definition_s3_location["ObjectKey"]
    definition_object = backend.get_object(bucket_name=bucket, key_name=object_key)
    return json.loads(definition_object.value)  # type: ignore[union-attr]


def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str:
    """Build a SageMaker ARN of the form ``arn:<partition>:sagemaker:<region>:<account>:<type>/<id>``.

    The partition is obtained from ``get_partition(region_name)``.
    """
    partition = get_partition(region_name)
    return f"arn:{partition}:sagemaker:{region_name}:{account_id}:{_type}/{_id}"


def validate_model_approval_status(model_approval_status: typing.Optional[str]) -> None:
    """Check that ``model_approval_status`` is one of the accepted values.

    ``None`` is accepted and means no status was supplied.

    Raises:
        ValidationError: If the value is neither ``None`` nor one of
            "Approved", "Rejected" or "PendingManualApproval".
    """
    if model_approval_status is None:
        return

    allowed_statuses = ("Approved", "Rejected", "PendingManualApproval")
    if model_approval_status in allowed_statuses:
        return

    raise ValidationError(
        f"Value '{model_approval_status}' at 'modelApprovalStatus' failed to satisfy constraint: "
        "Member must satisfy enum value set: [PendingManualApproval, Approved, Rejected]"
    )


def _time_bound_as_str(bound: Any) -> Optional[str]:
    """Normalise a time filter bound to the string form used for comparison.

    Falsy bounds mean "no filter" and yield ``None``. Numeric bounds are
    treated as epoch seconds and converted to a UTC datetime first.
    """
    if not bound:
        return None
    if isinstance(bound, (int, float)):
        bound = datetime.fromtimestamp(bound, tz=tzutc())
    return str(bound)


def filter_model_cards(
    model_cards: defaultdict[str, list["FakeModelCard"]],
    creation_time_after: Optional[datetime],
    creation_time_before: Optional[datetime],
    name_contains: Optional[str],
    model_card_status: Optional[str],
    sort_by: Optional[str],
    sort_order: Optional[str],
) -> list["FakeModelCard"]:
    """Select one version per model card that passes the filters, then sort.

    For each card name, the versions are filtered and the surviving version
    with the greatest ``last_modified_time`` is kept.

    Args:
        model_cards: Mapping of model card name to its list of versions.
        creation_time_after: If truthy, keep only versions whose
            ``last_modified_time`` compares greater than this bound.
        creation_time_before: If truthy, keep only versions whose
            ``last_modified_time`` compares less than this bound.
        name_contains: If truthy, keep only cards whose name contains this
            text, compared case-insensitively.
        model_card_status: If truthy, keep only versions with exactly this
            ``model_card_status``.
        sort_by: ``"Name"`` sorts by ``model_card_name``; any other value
            sorts by ``creation_time``.
        sort_order: ``"Descending"`` reverses the sort; any other value sorts
            ascending.

    Returns:
        The selected versions in sorted order, or an empty list if none match.
    """
    # The bounds are compared as strings against ``last_modified_time``; they
    # are normalised once here because they do not change per card.
    # NOTE(review): the creation-time bounds are applied to
    # ``last_modified_time``, not ``creation_time`` - kept as in the original.
    after_bound = _time_bound_as_str(creation_time_after)
    before_bound = _time_bound_as_str(creation_time_before)
    name_fragment = name_contains.lower() if name_contains else None

    result: list["FakeModelCard"] = []
    for card_name, versions in model_cards.items():
        if name_fragment is not None and name_fragment not in card_name.lower():
            continue

        candidates = [
            v
            for v in versions
            if (after_bound is None or v.last_modified_time > after_bound)
            and (before_bound is None or v.last_modified_time < before_bound)
            and (not model_card_status or v.model_card_status == model_card_status)
        ]
        if candidates:
            result.append(max(candidates, key=lambda v: v.last_modified_time))

    def sort_key(x: "FakeModelCard") -> str:
        if sort_by == "Name":
            return x.model_card_name
        return x.creation_time

    return sorted(result, key=sort_key, reverse=sort_order == "Descending")