File: _builder.py

package info (click to toggle)
python-mashumaro 3.17-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,408 kB
  • sloc: python: 19,981; sh: 16; makefile: 5
file content (106 lines) | stat: -rw-r--r-- 4,065 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import re
from collections.abc import Callable
from typing import Any, Optional, Type

from mashumaro.core.meta.code.builder import CodeBuilder
from mashumaro.core.meta.helpers import is_optional, is_type_var_any
from mashumaro.core.meta.types.common import (
    AttrsHolder,
    FieldContext,
    ValueSpec,
)
from mashumaro.core.meta.types.pack import PackerRegistry
from mashumaro.core.meta.types.unpack import UnpackerRegistry

# Matches a generated (un)packing expression that is nothing more than a
# single call on ``value``, e.g. ``int(value)`` or ``some.callee(value)``.
# Group 1 is the callee, which CodecCodeBuilder binds directly instead of
# wrapping it in a ``decode``/``encode`` function. The callee is restricted to
# a (dotted) name because it is spliced into module-level generated source
# where ``value`` is not bound; a permissive ``[^ ]+`` would also accept
# callees such as ``a(value)+b`` taken from ``a(value)+b(value)``.
# Expressions that do not match simply keep the wrapper function.
CALL_EXPR = re.compile(r"^([\w.]+)\(value\)$")


class CodecCodeBuilder(CodeBuilder):
    """Code builder that generates standalone ``decode``/``encode`` functions.

    The generated function handles an arbitrary ``shape_type`` and is attached
    with ``setattr`` to a caller-supplied decoder/encoder object.
    """

    @classmethod
    def new(cls, **kwargs: Any) -> "CodecCodeBuilder":
        """Create a builder with a placeholder ``AttrsHolder("__root__")``.

        The placeholder is passed as the first positional constructor
        argument (presumably the builder's ``cls`` — it is what
        ``self.cls`` refers to below; confirm in ``CodeBuilder``). An empty
        ``AttrsHolder`` is supplied as ``attrs`` unless the caller passed
        one. All other keyword arguments are forwarded unchanged.
        """
        if "attrs" not in kwargs:
            kwargs["attrs"] = AttrsHolder()
        return cls(AttrsHolder("__root__"), **kwargs)  # type: ignore

    def _codec_could_be_none(self, shape_type: Type) -> bool:
        """Return whether ``None`` is an acceptable value for ``shape_type``.

        True for ``Any``, ``NoneType``/``None``, a type variable accepted by
        ``is_type_var_any``, or a type that ``is_optional`` reports as
        optional.
        """
        return (
            shape_type in (Any, type(None), None)
            or is_type_var_any(self.get_real_type("", shape_type))
            or is_optional(
                shape_type, self.get_field_resolved_type_params("")
            )
        )

    def _codec_bind_direct_call(
        self, target: str, attr: str, expression: str
    ) -> None:
        """Bind the callee directly when the body is just ``callee(value)``.

        If ``expression`` matches ``CALL_EXPR``, the generated wrapper would
        only forward ``value`` to that callee. In that case every buffered
        line is discarded and replaced by a single
        ``setattr(<target>, '<attr>', <callee>)``, which avoids one extra
        function call per invocation.
        """
        match = CALL_EXPR.match(expression)
        if match:
            self.lines.reset()
            self.add_line(f"setattr({target}, '{attr}', {match.group(1)})")

    def add_decode_method(
        self,
        shape_type: Type,
        decoder_obj: Any,
        pre_decoder_func: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """Generate ``decode(value)`` for ``shape_type`` and attach it.

        The generated function is stored as ``decoder_obj.decode``.

        :param shape_type: type handed to the unpacker registry as the
            target type of the generated decoding logic.
        :param decoder_obj: object that receives the ``decode`` attribute.
        :param pre_decoder_func: optional callable applied to the raw input
            before the generated unpacking logic runs.
        """
        self.reset()
        with self.indent("def decode(value):"):
            # ``is not None`` keeps this consistent with the shortcut check
            # below; a callable that happens to be falsy must still be used.
            if pre_decoder_func is not None:
                self.ensure_object_imported(pre_decoder_func, "decoder")
                self.add_line("value = decoder(value)")
            unpacked_value = UnpackerRegistry.get(
                ValueSpec(
                    type=shape_type,
                    expression="value",
                    builder=self,
                    field_ctx=FieldContext(name="", metadata={}),
                    could_be_none=self._codec_could_be_none(shape_type),
                )
            )
            self.add_line(f"return {unpacked_value}")
        self.add_line("setattr(decoder_obj, 'decode', decode)")
        # With a pre-decoder the wrapper does more than forward ``value``,
        # so the direct-binding shortcut only applies without one.
        if pre_decoder_func is None:
            self._codec_bind_direct_call(
                "decoder_obj", "decode", unpacked_value
            )
        self.ensure_object_imported(decoder_obj, "decoder_obj")
        self.ensure_object_imported(self.cls, "cls")
        # Compile the generated source; it contains the ``setattr`` that
        # performs the actual binding.
        self.compile()

    def add_encode_method(
        self,
        shape_type: Type,
        encoder_obj: Any,
        post_encoder_func: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        """Generate ``encode(value)`` for ``shape_type`` and attach it.

        The generated function is stored as ``encoder_obj.encode``.

        :param shape_type: type handed to the packer registry as the source
            type of the generated encoding logic.
        :param encoder_obj: object that receives the ``encode`` attribute.
        :param post_encoder_func: optional callable applied to the packed
            result before it is returned.
        """
        self.reset()
        with self.indent("def encode(value):"):
            packed_value = PackerRegistry.get(
                ValueSpec(
                    type=shape_type,
                    expression="value",
                    builder=self,
                    field_ctx=FieldContext(name="", metadata={}),
                    could_be_none=self._codec_could_be_none(shape_type),
                    no_copy_collections=self.get_dialect_or_config_option(
                        "no_copy_collections", ()
                    ),
                )
            )
            # ``is not None`` keeps this consistent with the shortcut check
            # below; a callable that happens to be falsy must still be used.
            if post_encoder_func is not None:
                self.ensure_object_imported(post_encoder_func, "encoder")
                self.add_line(f"return encoder({packed_value})")
            else:
                self.add_line(f"return {packed_value}")
        self.add_line("setattr(encoder_obj, 'encode', encode)")
        # With a post-encoder the wrapper does more than forward ``value``,
        # so the direct-binding shortcut only applies without one.
        if post_encoder_func is None:
            self._codec_bind_direct_call(
                "encoder_obj", "encode", packed_value
            )
        self.ensure_object_imported(encoder_obj, "encoder_obj")
        self.ensure_object_imported(self.cls, "cls")
        # Generated pack expressions may also refer to ``self``; bind it to
        # the same object as ``cls`` (NOTE(review): confirm which packers
        # emit ``self`` references).
        self.ensure_object_imported(self.cls, "self")
        # Compile the generated source; it contains the ``setattr`` that
        # performs the actual binding.
        self.compile()