1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
|
import json
import time
from collections import defaultdict
from typing import Any, Optional
from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.moto_api._internal import mock_random
from moto.sns import sns_backends
from moto.sns.exceptions import TopicNotFound
from .exceptions import InvalidJobIdException, InvalidParameterException
class TextractJobStatus:
    """String constants naming the ``JobStatus`` values a Textract job reports."""

    in_progress: str = "IN_PROGRESS"
    succeeded: str = "SUCCEEDED"
    failed: str = "FAILED"
    partial_success: str = "PARTIAL_SUCCESS"
class TextractJob(BaseModel):
    """A single asynchronous Textract job together with its mocked result.

    The response payload is stored verbatim in ``self.job`` and returned by
    ``to_dict``; every instance gets its own random ``job_id``.
    """

    def __init__(
        self, job: dict[str, Any], notification_channel: Optional[dict[str, str]] = None
    ):
        """
        :param job: Response payload for this job. It must contain a
            ``"JobStatus"`` key if a completion notification is sent.
        :param notification_channel: Optional ``NotificationChannel`` request
            parameter; only its ``"SNSTopicArn"`` entry is used.
        """
        self.job = job
        self.notification_channel = notification_channel
        self.job_id = str(mock_random.uuid4())

    def to_dict(self) -> dict[str, Any]:
        """Return the stored response payload (the same object, not a copy)."""
        return self.job

    def send_completion_notification(
        self,
        account_id: str,
        region_name: str,
        document_location: dict[str, Any],
        api: str = "StartDocumentTextDetection",
        job_tag: Optional[str] = None,
    ) -> None:
        """Publish a job-completion message to the configured SNS topic, if any.

        Does nothing when no notification channel or no ``SNSTopicArn`` was
        given. A topic that does not exist in the SNS backend is ignored.

        :param account_id: Account whose SNS backend receives the message.
        :param region_name: Region whose SNS backend receives the message.
        :param document_location: The request's ``DocumentLocation``, shaped as
            ``{"S3Object": {"Bucket": ..., "Name": ...}}``.
        :param api: Name of the Start* operation reported in the ``API`` field.
        :param job_tag: Value reported in the ``JobTag`` field. ``None`` is
            reported as an empty string.
        """
        if not self.notification_channel:
            return
        topic_arn = self.notification_channel.get("SNSTopicArn")
        if not topic_arn:
            return

        # The notification flattens {'S3Object': {'Bucket', 'Name'}} into
        # {'S3Bucket', 'S3ObjectName'}, as per AWS docs. `or {}` also covers an
        # explicit None value.
        s3_object = (document_location or {}).get("S3Object") or {}
        notification = {
            "JobId": self.job_id,
            "Status": self.job["JobStatus"],
            "API": api,
            "JobTag": job_tag or "",
            "Timestamp": int(time.time() * 1000),  # milliseconds since epoch
            "DocumentLocation": {
                "S3Bucket": s3_object.get("Bucket", ""),
                "S3ObjectName": s3_object.get("Name", ""),
            },
        }

        sns_backend = sns_backends[account_id][region_name]
        try:
            sns_backend.publish(
                message=json.dumps(notification),  # SNS requires a string message
                arn=topic_arn,
                subject="Amazon Textract Job Completion",
            )
        except TopicNotFound:
            # Deliberately best-effort: a missing topic must not fail the
            # Start* call that triggered this notification.
            pass
class TextractBackend(BaseBackend):
    """Implementation of Textract APIs.

    Every asynchronous job completes immediately. Results are built from the
    class-level ``JOB_STATUS``, ``PAGES`` and ``BLOCKS`` values, which can be
    overridden on the class to shape the mocked responses.
    """

    JOB_STATUS = TextractJobStatus.succeeded
    # Reported as ``DocumentMetadata.Pages``, which is an integer in the AWS API.
    # (Wrapping this in {"Pages": ...} here used to nest the value twice.)
    PAGES = mock_random.randint(5, 500)
    BLOCKS: list[dict[str, Any]] = []

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Job registries, keyed by each job's own TextractJob.job_id.
        self.async_text_detection_jobs: dict[str, TextractJob] = {}
        self.async_document_analysis_jobs: dict[str, TextractJob] = {}

    @staticmethod
    def _text_detection_result() -> dict[str, Any]:
        """Build the payload shared by DetectDocumentText and text-detection jobs."""
        return {
            "Blocks": TextractBackend.BLOCKS,
            "DetectDocumentTextModelVersion": "1.0",
            "DocumentMetadata": {"Pages": TextractBackend.PAGES},
        }

    @staticmethod
    def _get_job(jobs: dict[str, TextractJob], job_id: str) -> TextractJob:
        """Return ``jobs[job_id]``, raising InvalidJobIdException if it is unknown."""
        job = jobs.get(job_id)
        if job is None:
            raise InvalidJobIdException()
        return job

    def get_document_text_detection(self, job_id: str) -> TextractJob:
        """
        Pagination has not yet been implemented
        """
        return self._get_job(self.async_text_detection_jobs, job_id)

    def detect_document_text(self) -> dict[str, Any]:
        """Return a mocked synchronous DetectDocumentText response."""
        return self._text_detection_result()

    def start_document_text_detection(
        self,
        document_location: dict[str, Any],
        notification_channel: Optional[dict[str, str]] = None,
    ) -> str:
        """
        The following parameters have not yet been implemented: ClientRequestToken, JobTag, OutputConfig, KmsKeyID
        """
        if not document_location:
            raise InvalidParameterException()
        job = TextractJob(
            {**self._text_detection_result(), "JobStatus": TextractBackend.JOB_STATUS},
            notification_channel=notification_channel,
        )
        self.async_text_detection_jobs[job.job_id] = job
        # The mock completes jobs immediately, so notify right away.
        job.send_completion_notification(
            self.account_id, self.region_name, document_location
        )
        return job.job_id

    def start_document_analysis(
        self, document_location: dict[str, Any], feature_types: list[str]
    ) -> str:
        """
        The following parameters have not yet been implemented: ClientRequestToken, JobTag, NotificationChannel, OutputConfig, KmsKeyID
        """
        if not document_location or not feature_types:
            raise InvalidParameterException()
        job = TextractJob(
            {
                "Blocks": TextractBackend.BLOCKS,
                # GetDocumentAnalysis reports AnalyzeDocumentModelVersion,
                # not DetectDocumentTextModelVersion.
                "AnalyzeDocumentModelVersion": "1.0",
                "DocumentMetadata": {"Pages": TextractBackend.PAGES},
                "JobStatus": TextractBackend.JOB_STATUS,
            }
        )
        # Register under the job's own id so the stored job and the returned
        # id agree (previously a second, unrelated UUID was used).
        self.async_document_analysis_jobs[job.job_id] = job
        return job.job_id

    def get_document_analysis(
        self, job_id: str, max_results: Optional[int], next_token: Optional[str] = None
    ) -> TextractJob:
        """
        Pagination has not yet been implemented: max_results and next_token are accepted but ignored.
        """
        return self._get_job(self.async_document_analysis_jobs, job_id)
# Registry of TextractBackend instances for the "textract" service; presumably
# indexed as textract_backends[account_id][region_name], like sns_backends.
textract_backends = BackendDict(TextractBackend, "textract")
|