File: models.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (157 lines) | stat: -rw-r--r-- 5,438 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import json
import time
from collections import defaultdict
from typing import Any, Optional

from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.moto_api._internal import mock_random
from moto.sns import sns_backends
from moto.sns.exceptions import TopicNotFound

from .exceptions import InvalidJobIdException, InvalidParameterException


class TextractJobStatus:
    """String constants used as the ``JobStatus`` of Textract async jobs."""

    in_progress: str = "IN_PROGRESS"
    succeeded: str = "SUCCEEDED"
    failed: str = "FAILED"
    partial_success: str = "PARTIAL_SUCCESS"


class TextractJob(BaseModel):
    """One asynchronous Textract job together with its canned result payload."""

    def __init__(
        self, job: dict[str, Any], notification_channel: Optional[dict[str, str]] = None
    ):
        # Response body for the Get* call; returned verbatim by ``to_dict``.
        self.job = job
        # Optional NotificationChannel from the Start* request; only its
        # ``SNSTopicArn`` entry is read by ``send_completion_notification``.
        self.notification_channel = notification_channel
        self.job_id = str(mock_random.uuid4())

    def to_dict(self) -> dict[str, Any]:
        """Return the stored job payload (the same dict object, not a copy)."""
        return self.job

    def send_completion_notification(
        self,
        account_id: str,
        region_name: str,
        document_location: dict[str, Any],
        *,
        api: str = "StartDocumentTextDetection",
        job_tag: str = "",
    ) -> None:
        """Publish a job-completion message to the job's SNS topic, if any.

        Returns without doing anything when the job has no notification
        channel, or when the channel carries no ``SNSTopicArn``. Publishing is
        best-effort: if the topic does not exist the notification is dropped.

        :param account_id: account whose SNS backend receives the message.
        :param region_name: region whose SNS backend receives the message.
        :param document_location: the request's ``DocumentLocation`` in
            ``{"S3Object": {"Bucket": ..., "Name": ...}}`` form.
        :param api: name of the Start* operation reported in the ``API`` field.
        :param job_tag: value reported in the ``JobTag`` field.
        """
        if not self.notification_channel:
            return

        topic_arn = self.notification_channel.get("SNSTopicArn")
        if not topic_arn:
            return

        # The completion message flattens the request's S3Object into
        # {'S3Bucket': ..., 'S3ObjectName': ...}, as per the AWS docs.
        s3_object = document_location.get("S3Object", {})
        doc_location = {
            "S3Bucket": s3_object.get("Bucket", ""),
            "S3ObjectName": s3_object.get("Name", ""),
        }

        notification = {
            "JobId": self.job_id,
            "Status": self.job["JobStatus"],
            "API": api,
            "JobTag": job_tag,
            "Timestamp": int(time.time() * 1000),  # epoch milliseconds
            "DocumentLocation": doc_location,
        }

        sns_backend = sns_backends[account_id][region_name]
        try:
            sns_backend.publish(
                message=json.dumps(notification),  # SNS requires a string body
                arn=topic_arn,
                subject="Amazon Textract Job Completion",
            )
        except TopicNotFound:
            # Deliberately best-effort: a deleted/unknown topic must not make
            # the Start* call fail.
            pass


class TextractBackend(BaseBackend):
    """Implementation of Textract APIs.

    Async jobs complete immediately: each job stores the class-level
    ``JOB_STATUS``, ``BLOCKS`` and ``PAGES`` values as they are at the moment
    the job is started, so tests may override those attributes on the class.
    """

    JOB_STATUS = TextractJobStatus.succeeded
    # Page count reported as ``DocumentMetadata.Pages``. AWS models that field
    # as an integer, so the default is a bare int (it used to be a dict, which
    # produced a nested {"Pages": {"Pages": N}} structure).
    PAGES = mock_random.randint(5, 500)
    BLOCKS: list[dict[str, Any]] = []

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Job stores, keyed by the JobId handed back to the caller.
        self.async_text_detection_jobs: dict[str, TextractJob] = {}
        self.async_document_analysis_jobs: dict[str, TextractJob] = {}

    @staticmethod
    def _document_metadata() -> dict[str, Any]:
        """Build the ``DocumentMetadata`` structure from the current ``PAGES``."""
        return {"Pages": TextractBackend.PAGES}

    @staticmethod
    def _job_payload() -> dict[str, Any]:
        """Build an async job's stored result from the current class values."""
        # NOTE(review): AWS's GetDocumentAnalysis response names its version
        # field AnalyzeDocumentModelVersion; the shared payload keeps the
        # text-detection key for backward compatibility — confirm if needed.
        return {
            "Blocks": TextractBackend.BLOCKS,
            "DetectDocumentTextModelVersion": "1.0",
            "DocumentMetadata": TextractBackend._document_metadata(),
            "JobStatus": TextractBackend.JOB_STATUS,
        }

    def get_document_text_detection(self, job_id: str) -> TextractJob:
        """
        Pagination has not yet been implemented
        """
        job = self.async_text_detection_jobs.get(job_id)
        if job is None:
            raise InvalidJobIdException()
        return job

    def detect_document_text(self) -> dict[str, Any]:
        return {
            "Blocks": TextractBackend.BLOCKS,
            "DetectDocumentTextModelVersion": "1.0",
            "DocumentMetadata": self._document_metadata(),
        }

    def start_document_text_detection(
        self,
        document_location: dict[str, Any],
        notification_channel: Optional[dict[str, str]] = None,
    ) -> str:
        """
        The following parameters have not yet been implemented: ClientRequestToken, JobTag, OutputConfig, KmsKeyID
        """
        if not document_location:
            raise InvalidParameterException()

        job = TextractJob(
            self._job_payload(), notification_channel=notification_channel
        )
        self.async_text_detection_jobs[job.job_id] = job

        # The mocked job completes immediately, so notify right away.
        job.send_completion_notification(
            self.account_id, self.region_name, document_location
        )

        return job.job_id

    def start_document_analysis(
        self, document_location: dict[str, Any], feature_types: list[str]
    ) -> str:
        """
        The following parameters have not yet been implemented: ClientRequestToken, JobTag, NotificationChannel, OutputConfig, KmsKeyID
        """
        if not document_location or not feature_types:
            raise InvalidParameterException()
        job = TextractJob(self._job_payload())
        # Key the store by the job's own id so the JobId returned to the caller
        # and TextractJob.job_id always agree.
        self.async_document_analysis_jobs[job.job_id] = job
        return job.job_id

    def get_document_analysis(
        self, job_id: str, max_results: Optional[int], next_token: Optional[str] = None
    ) -> TextractJob:
        # max_results / next_token are accepted for API compatibility only;
        # pagination is not implemented.
        job = self.async_document_analysis_jobs.get(job_id)
        if job is None:
            raise InvalidJobIdException()
        return job


# Registry of TextractBackend instances for the "textract" service.
# NOTE(review): presumably indexed as backends[account_id][region_name], the
# same way sns_backends is used in TextractJob — confirm against BackendDict.
textract_backends = BackendDict(TextractBackend, "textract")