File: models.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (133 lines) | stat: -rw-r--r-- 5,157 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

import base64
import os
from typing import Any, Optional

from moto.core.base_backend import BackendDict, BaseBackend
from moto.core.common_models import BaseModel
from moto.dynamodb.models import DynamoDBBackend, dynamodb_backends
from moto.dynamodb.models.table import StreamShard, Table


class ShardIterator(BaseModel):
    """Read position within a single DynamoDB stream shard.

    The starting position is derived from ``shard_iterator_type``. Each call to
    :meth:`get` returns the records from that position and registers a
    follow-up iterator on the owning backend.
    """

    def __init__(
        self,
        streams_backend: DynamoDBStreamsBackend,
        stream_shard: StreamShard,
        shard_iterator_type: str,
        sequence_number: Optional[int | str] = None,
    ):
        """Position a new iterator on ``stream_shard``.

        :param streams_backend: backend on which follow-up iterators are
            registered by :meth:`get`.
        :param stream_shard: shard whose items are read.
        :param shard_iterator_type: ``TRIM_HORIZON`` (the shard's starting
            sequence number), ``LATEST`` (just past the items currently in the
            shard), ``AT_SEQUENCE_NUMBER`` or ``AFTER_SEQUENCE_NUMBER``.
        :param sequence_number: required by the two ``*_SEQUENCE_NUMBER``
            types and ignored otherwise; numeric strings are accepted and
            converted to ``int``.
        :raises ValueError: if the iterator type is unknown, or a sequence
            number is required but missing or not numeric.
        """
        # Random opaque token; it becomes the tail of this iterator's ``arn``.
        self.id = base64.b64encode(os.urandom(472)).decode("utf-8")
        self.streams_backend = streams_backend
        self.stream_shard = stream_shard
        self.shard_iterator_type = shard_iterator_type
        if shard_iterator_type == "TRIM_HORIZON":
            self.sequence_number = stream_shard.starting_sequence_number
        elif shard_iterator_type == "LATEST":
            self.sequence_number = stream_shard.starting_sequence_number + len(
                stream_shard.items
            )
        elif shard_iterator_type in ("AT_SEQUENCE_NUMBER", "AFTER_SEQUENCE_NUMBER"):
            if sequence_number is None:
                raise ValueError(
                    "SequenceNumber is required for ShardIteratorType "
                    f"{shard_iterator_type}"
                )
            # Callers may hand the sequence number over as a string; normalise
            # to int so the arithmetic below and the shard lookup in get() work.
            start = int(sequence_number)
            if shard_iterator_type == "AFTER_SEQUENCE_NUMBER":
                start += 1
            self.sequence_number = start
        else:
            # Fail fast instead of leaving ``sequence_number`` unset, which would
            # only surface later as an AttributeError inside get().
            raise ValueError(f"Unknown ShardIteratorType: {shard_iterator_type}")

    @property
    def arn(self) -> str:
        """Token identifying this iterator.

        It is the key under which the backend stores the iterator and the value
        returned to callers as ``NextShardIterator``.
        """
        return f"{self.stream_shard.table.table_arn}/stream/{self.stream_shard.table.latest_stream_label}|1|{self.id}"

    def get(self, limit: int = 1000) -> dict[str, Any]:
        """Read up to ``limit`` records starting at this iterator's position.

        A follow-up iterator is created and registered on the backend. It is
        positioned just after the last record returned, or at the same position
        if no records were available.

        :returns: ``{"NextShardIterator": <arn>, "Records": <items>}``.
        """
        items = self.stream_shard.get(self.sequence_number, limit)
        # ``default`` avoids using a broad ``except ValueError`` as the
        # empty-result signal, which could also hide genuine errors.
        last_sequence_number = max(
            (int(item["dynamodb"]["SequenceNumber"]) for item in items),
            default=None,
        )
        if last_sequence_number is None:
            # Nothing new yet: the next read retries from the same position.
            new_shard_iterator = ShardIterator(
                self.streams_backend,
                self.stream_shard,
                "AT_SEQUENCE_NUMBER",
                self.sequence_number,
            )
        else:
            new_shard_iterator = ShardIterator(
                self.streams_backend,
                self.stream_shard,
                "AFTER_SEQUENCE_NUMBER",
                last_sequence_number,
            )
        self.streams_backend.shard_iterators[new_shard_iterator.arn] = (
            new_shard_iterator
        )
        return {"NextShardIterator": new_shard_iterator.arn, "Records": items}


class DynamoDBStreamsBackend(BaseBackend):
    """DynamoDB Streams backend for one account and region.

    Stream data is not stored here: tables, shards and stream labels are read
    from the DynamoDB backend of the same account/region. The only local state
    is the collection of shard iterators handed out to callers.
    """

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        # Issued iterators, keyed by the iterator's ``arn``; that same value is
        # what callers later pass to get_records().
        self.shard_iterators: dict[str, ShardIterator] = {}

    @property
    def dynamodb(self) -> DynamoDBBackend:
        """DynamoDB backend for the same account and region as this backend."""
        return dynamodb_backends[self.account_id][self.region_name]

    def _get_table_from_arn(self, arn: str) -> Table:
        """Return the table that a stream ARN belongs to.

        The table name is the second ``/``-separated component of the sixth
        ``:``-separated ARN field (presumably ``table/<name>/stream/<label>``).
        """
        table_name = arn.split(":", 6)[5].split("/")[1]
        return self.dynamodb.get_table(table_name)

    def describe_stream(self, arn: str) -> dict[str, Any]:
        """Describe the stream identified by ``arn``.

        ``StreamViewType`` and ``CreationRequestDateTime`` are included only
        when the table actually has a view type / stream shard, so describing a
        table whose stream is disabled does not crash.
        """
        table = self._get_table_from_arn(arn)
        stream_shard = table.stream_shard
        view_type = (table.stream_specification or {}).get("StreamViewType")

        # Keys are inserted in the same order as the full description would have.
        stream: dict[str, Any] = {
            "StreamArn": arn,
            "StreamLabel": table.latest_stream_label,
            "StreamStatus": ("ENABLED" if table.latest_stream_label else "DISABLED"),
        }
        if view_type is not None:
            stream["StreamViewType"] = view_type
        if stream_shard is not None:
            stream["CreationRequestDateTime"] = stream_shard.created_on
        stream["TableName"] = table.name
        stream["KeySchema"] = table.schema
        stream["Shards"] = [stream_shard.to_json()] if stream_shard else []
        return stream

    def list_streams(self, table_name: Optional[str] = None) -> list[dict[str, Any]]:
        """Summarise every table that has a stream label.

        :param table_name: if given, only that table is considered.
        :returns: one ``{"StreamArn", "TableName", "StreamLabel"}`` dict per
            matching table, taken from the table's own description.
        """
        streams = []
        for table in self.dynamodb.tables.values():
            if table_name is not None and table.name != table_name:
                continue
            if table.latest_stream_label:
                d = table.describe(base_key="Table")
                streams.append(
                    {
                        "StreamArn": d["Table"]["LatestStreamArn"],
                        "TableName": d["Table"]["TableName"],
                        "StreamLabel": d["Table"]["LatestStreamLabel"],
                    }
                )
        return streams

    def get_shard_iterator(
        self,
        arn: str,
        shard_id: str,
        shard_iterator_type: str,
        sequence_number: Optional[str] = None,
    ) -> ShardIterator:
        """Create, register and return an iterator over the stream's shard.

        :raises ValueError: if the table has no stream shard, or ``shard_id``
            does not identify it.
        """
        table = self._get_table_from_arn(arn)
        stream_shard = table.stream_shard
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, and a missing shard must not become an AttributeError.
        if stream_shard is None or stream_shard.id != shard_id:
            raise ValueError(f"Shard {shard_id} does not exist for stream {arn}")

        shard_iterator = ShardIterator(
            self,
            stream_shard,
            shard_iterator_type,
            sequence_number,  # type: ignore[arg-type]
        )
        self.shard_iterators[shard_iterator.arn] = shard_iterator

        return shard_iterator

    def get_records(self, iterator_arn: str, limit: int) -> dict[str, Any]:
        """Read up to ``limit`` records through a previously issued iterator.

        :raises KeyError: if ``iterator_arn`` was not issued by this backend.
        """
        shard_iterator = self.shard_iterators[iterator_arn]
        return shard_iterator.get(limit)


# Registry of DynamoDBStreamsBackend instances for the "dynamodbstreams" service.
# Presumably indexed as [account_id][region_name], like dynamodb_backends is
# indexed in DynamoDBStreamsBackend.dynamodb — confirm against BackendDict.
dynamodbstreams_backends = BackendDict(DynamoDBStreamsBackend, "dynamodbstreams")