File: iterative.py

package info (click to toggle)
python-beanie 2.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,484 kB
  • sloc: python: 14,427; makefile: 6; sh: 6
file content (134 lines) | stat: -rw-r--r-- 4,939 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import asyncio
from inspect import isclass, signature
from typing import Any, List, Optional, Type, Union

from beanie.migrations.controllers.base import BaseMigrationController
from beanie.migrations.utils import update_dict
from beanie.odm.documents import Document
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2, parse_model


class DummyOutput:
    """
    Attribute-style recorder for values assigned by a migration function.

    Assigned attributes are stored in an internal mapping rather than in the
    regular instance ``__dict__``. Reading a name that was never assigned
    creates a nested ``DummyOutput`` on the fly, so ``output.a.b = 1`` works
    without preparing ``a`` first. ``dict()`` converts the recorded structure
    into plain nested dictionaries.
    """

    def __init__(self):
        # Bypass the overridden __setattr__: the storage itself has to live
        # in the real instance __dict__, not inside the recorded structure.
        super(DummyOutput, self).__setattr__("_internal_structure_dict", {})

    def __setattr__(self, key, value):
        """Record ``value`` under ``key`` in the internal structure."""
        self._internal_structure_dict[key] = value

    def __getattr__(self, item):
        """
        Return the recorded value for ``item``, creating a nested
        ``DummyOutput`` when nothing was recorded yet.

        :param item: attribute name that normal lookup did not find
        :return: the recorded value or a freshly created nested DummyOutput
        :raises AttributeError: for dunder names and for the storage
            attribute itself
        """
        # Dunder names are protocol probes (copy, pickle, hasattr checks),
        # not document fields: auto-creating them would pollute the output
        # and hand callers a non-callable object. The storage attribute is
        # only looked up here when __init__ has not run (e.g. while copy
        # rebuilds an instance); failing fast avoids infinite recursion.
        if item == "_internal_structure_dict" or (
            item.startswith("__") and item.endswith("__")
        ):
            raise AttributeError(item)
        try:
            return self._internal_structure_dict[item]
        except KeyError:
            self._internal_structure_dict[item] = DummyOutput()
            return self._internal_structure_dict[item]

    def dict(self, to_parse: Optional[Union[dict, "DummyOutput"]] = None):
        """
        Convert the recorded structure into plain nested dictionaries.

        :param to_parse: structure to convert; defaults to this instance
        :return: a new dict in which every nested DummyOutput or dict value
            has been converted recursively; other values are kept as they are
        """
        if to_parse is None:
            to_parse = self
        input_dict = (
            to_parse._internal_structure_dict
            if isinstance(to_parse, DummyOutput)
            else to_parse
        )
        result_dict = {}
        for key, value in input_dict.items():
            if isinstance(value, (DummyOutput, dict)):
                result_dict[key] = self.dict(to_parse=value)
            else:
                result_dict[key] = value
        return result_dict


def iterative_migration(
    document_models: Optional[List[Type[Document]]] = None,
    batch_size: int = 10000,
):
    """
    Build a controller class that migrates documents one at a time.

    The returned class wraps an async function that has ``input_document``
    and ``output_document`` parameters, both annotated with ``Document``
    subclasses. When run, the function is called once per stored input
    document and the recorded changes are written back in batches.

    :param document_models: extra document models listed by ``models``
        ahead of the input and output models
    :param batch_size: number of migrated documents handed to a single
        ``replace_many`` call
    :return: the ``IterativeMigration`` controller class
    """

    class IterativeMigration(BaseMigrationController):
        def __init__(self, function):
            """
            Inspect ``function`` and remember its input and output models.

            :param function: the migration function to wrap
            :raises RuntimeError: if ``input_document`` or
                ``output_document`` is missing from the signature
            :raises TypeError: if either annotation is not a ``Document``
                subclass
            """
            self.function = function
            self.function_signature = signature(function)
            input_signature = self.function_signature.parameters.get(
                "input_document"
            )
            if input_signature is None:
                raise RuntimeError("input_signature must not be None")
            self.input_document_model: Type[Document] = (
                input_signature.annotation
            )
            output_signature = self.function_signature.parameters.get(
                "output_document"
            )
            if output_signature is None:
                raise RuntimeError("output_signature must not be None")
            self.output_document_model: Type[Document] = (
                output_signature.annotation
            )

            if (
                not isclass(self.input_document_model)
                or not issubclass(self.input_document_model, Document)
                or not isclass(self.output_document_model)
                or not issubclass(self.output_document_model, Document)
            ):
                raise TypeError(
                    "input_document and output_document "
                    "must have annotation of Document subclass"
                )

            self.batch_size = batch_size

        def __call__(self, *args: Any, **kwargs: Any):
            # Intentionally a no-op: the wrapped function is only invoked
            # from run().
            pass

        @property
        def models(self) -> List[Type[Document]]:
            """Preset models followed by the input and output models."""
            preset_models = document_models
            if preset_models is None:
                preset_models = []
            # Concatenation builds a new list, so the caller's
            # document_models list is never mutated.
            return preset_models + [
                self.input_document_model,
                self.output_document_model,
            ]

        async def run(self, session):
            """
            Apply the migration function to every input document.

            All input documents are read and converted first; the resulting
            batches are written with ``replace_many`` only after the read
            loop has finished.

            :param session: session forwarded to ``find_all`` and
                ``replace_many``; may be None
            """
            batches: List[List[Document]] = []
            current_batch: List[Document] = []
            # The signature never changes, so check it once for all
            # documents.
            takes_self = "self" in self.function_signature.parameters

            async for input_document in self.input_document_model.find_all(
                session=session
            ):
                output = DummyOutput()
                function_kwargs = {
                    "input_document": input_document,
                    "output_document": output,
                }
                if takes_self:
                    function_kwargs["self"] = None
                await self.function(**function_kwargs)
                # Start from the full input document and overlay only the
                # fields the migration function recorded.
                output_dict = (
                    input_document.dict()
                    if not IS_PYDANTIC_V2
                    else input_document.model_dump()
                )
                update_dict(output_dict, output.dict())
                current_batch.append(
                    parse_model(self.output_document_model, output_dict)
                )

                if len(current_batch) == self.batch_size:
                    batches.append(current_batch)
                    current_batch = []

            if current_batch:
                batches.append(current_batch)

            # Only plain document lists were collected above; the write
            # calls are created here, so a failure during the read loop
            # does not leave un-awaited coroutine objects behind.
            if session is None:
                await asyncio.gather(
                    *(
                        self.output_document_model.replace_many(
                            documents=batch, session=session
                        )
                        for batch in batches
                    )
                )
            else:
                # A single driver session must not run several operations
                # concurrently, so batches sharing it are written in turn.
                for batch in batches:
                    await self.output_document_model.replace_many(
                        documents=batch, session=session
                    )

    return IterativeMigration