File: polyfield.py

package info (click to toggle)
python-marshmallow-polyfield 5.11-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 168 kB
  • sloc: python: 651; sh: 7; makefile: 4
file content (135 lines) | stat: -rw-r--r-- 5,368 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import abc
import contextlib

from marshmallow import Schema, ValidationError
from marshmallow.fields import Field


class PolyFieldBase(Field, metaclass=abc.ABCMeta):
    """Abstract marshmallow field that selects a schema or field per value.

    Subclasses implement :meth:`serialization_schema_selector` and
    :meth:`deserialization_schema_selector`. Each selector returns either a
    ``Field``/``Schema`` instance or a class, which is instantiated with no
    arguments. When ``many`` is true, the field value is treated as an
    iterable and a schema is selected separately for every item.
    """

    def __init__(self, many=False, **metadata):
        """
        :param many: If true, the value is an iterable of items. Each item is
            (de)serialized with its own selected schema, and a list is
            produced.
        :param metadata: Passed through unchanged to
            ``marshmallow.fields.Field``.
        """
        super().__init__(**metadata)
        self.many = many

    def _propagate_context(self, target):
        """Copy this field's context, if any, into ``target.context``.

        This is best effort. Targets without a mutable ``context`` attribute
        are skipped silently, as are fields whose own ``context`` lookup fails.
        """
        with contextlib.suppress(AttributeError, TypeError):
            target.context.update(getattr(self, 'context', {}))

    def _select_deserializer(self, v, parent):
        """Resolve and instantiate the deserializer for a single item ``v``.

        :raises ValidationError: if the selector fails, or if it returns
            something that is neither a ``Field`` nor a ``Schema``.
        """
        deserializer = None
        try:
            deserializer = self.deserialization_schema_selector(v, parent)
            if isinstance(deserializer, type):
                deserializer = deserializer()
            if not isinstance(deserializer, (Field, Schema)):
                raise Exception('Invalid deserializer type')
        except TypeError as te:
            raise ValidationError(str(te)) from te
        except ValidationError:
            raise
        except Exception as err:
            class_type = None
            # Use an identity check: a falsy but non-None return value is
            # still worth reporting.
            if deserializer is not None:
                class_type = str(type(deserializer))

            raise ValidationError(
                "Unable to use schema. Error: {err}\n"
                "Ensure there is a deserialization_schema_selector"
                " and then it returns a field or a schema when the function is passed in "
                "{value_passed}. This is the class I got. "
                "Make sure it is a field or a schema: {class_type}".format(
                    err=err,
                    value_passed=v,
                    class_type=class_type
                )
            ) from err
        return deserializer

    def _deserialize(self, value, attr, parent, partial=None, **kwargs):
        """Deserialize ``value``, or each item of it when ``many`` is set.

        :param partial: Forwarded to ``Schema.load`` when the selected
            deserializer is a schema.
        :raises ValidationError: if schema selection fails or the selected
            deserializer rejects the data.
        """
        items = value if self.many else [value]

        results = []
        for v in items:
            deserializer = self._select_deserializer(v, parent)

            # Will raise ValidationError if any problems
            if isinstance(deserializer, Field):
                data = deserializer.deserialize(v, attr, parent)
            else:
                # Guarded in the same way as serialization, so a schema
                # without a mutable context does not crash deserialization.
                self._propagate_context(deserializer)
                data = deserializer.load(v, partial=partial)

            results.append(data)

        if self.many:
            return results
        # The single-value path always processes exactly one item.
        return results[0]

    def _serialize_one(self, item, obj):
        """Select a schema for ``item`` and serialize it with that schema."""
        schema = self.serialization_schema_selector(item, obj)
        if isinstance(schema, type):
            schema = schema()
        self._propagate_context(schema)
        # Schemas expose ``dump``; plain fields are serialized directly.
        if hasattr(schema, 'dump'):
            return schema.dump(item)
        return schema._serialize(item, None, None)

    def _serialize(self, value, key, obj, **kwargs):
        """Serialize ``value``, or each item of it when ``many`` is set.

        Returns ``None`` when ``value`` is ``None``.

        :raises TypeError: wrapping any error raised while selecting a schema
            or serializing with it.
        """
        if value is None:
            return None
        try:
            if self.many:
                return [self._serialize_one(v, obj) for v in value]
            return self._serialize_one(value, obj)
        except Exception as err:
            raise TypeError(
                'Failed to serialize object. Error: {0}\n'
                ' Ensure the serialization_schema_selector exists and '
                ' returns a Schema and that schema'
                ' can serialize this value {1}'.format(err, value)) from err

    @abc.abstractmethod
    def serialization_schema_selector(self, value, obj):
        """Return the schema or field (instance or class) for ``value``.

        ``obj`` is the parent object being serialized.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def deserialization_schema_selector(self, value, obj):
        """Return the schema or field (instance or class) for raw ``value``.

        ``obj`` is the parent data being deserialized.
        """
        raise NotImplementedError


class PolyField(PolyFieldBase):
    """
    A field that (de)serializes to one of many types. Passed-in functions
    are called to decide which schema to use for the (de)serialization.
    Intended to assist in working with fields that can contain any subclass
    of a base type.
    """
    def __init__(
            self,
            serialization_schema_selector=None,
            deserialization_schema_selector=None,
            many=False,
            **metadata
    ):
        """
        :param serialization_schema_selector: Function that takes an object
            and its parent object and returns the appropriate schema or
            field (an instance or a class).
        :param deserialization_schema_selector: Function that takes the dict
            representing an object and the dict representing its parent, and
            returns the appropriate schema or field (an instance or a class).
        :param many: If true, the value is a collection and a schema is
            selected for each item.
        :param metadata: Passed through to ``PolyFieldBase`` / ``Field``.
        """
        super().__init__(many=many, **metadata)
        self._serialization_schema_selector_arg = serialization_schema_selector
        self._deserialization_schema_selector_arg = deserialization_schema_selector

    def serialization_schema_selector(self, value, obj):
        """Delegate to the serialization selector supplied at construction.

        :raises TypeError: if no serialization selector was supplied.
        """
        selector = self._serialization_schema_selector_arg
        if selector is None:
            # Give a clear message instead of "'NoneType' object is not
            # callable". The exception type stays TypeError, so existing
            # handling in the base class is unchanged.
            raise TypeError(
                'PolyField has no serialization_schema_selector configured')
        return selector(value, obj)

    def deserialization_schema_selector(self, value, obj):
        """Delegate to the deserialization selector supplied at construction.

        :raises TypeError: if no deserialization selector was supplied.
        """
        selector = self._deserialization_schema_selector_arg
        if selector is None:
            raise TypeError(
                'PolyField has no deserialization_schema_selector configured')
        return selector(value, obj)