File: msgspec_factory.py

package info (click to toggle)
python-polyfactory 2.22.2-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,892 kB
  • sloc: python: 11,338; makefile: 103; sh: 37
file content (71 lines) | stat: -rw-r--r-- 2,421 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from __future__ import annotations

from inspect import isclass
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar

from typing_extensions import get_type_hints

from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import FieldMeta, Null
from polyfactory.value_generators.constrained_numbers import handle_constrained_int
from polyfactory.value_generators.primitives import create_random_bytes

if TYPE_CHECKING:
    from typing_extensions import TypeGuard

try:
    import msgspec
    from msgspec.structs import fields
except ImportError as e:
    msg = "msgspec is not installed"
    raise MissingDependencyException(msg) from e

# Type variable for the model type a factory builds. It is bound to
# ``msgspec.Struct`` so static type checkers restrict it to Struct subclasses.
T = TypeVar("T", bound=msgspec.Struct)


class MsgspecFactory(Generic[T], BaseFactory[T]):
    """Factory base class for ``msgspec.Struct`` models.

    Concrete factories set ``__model__`` to a Struct type. This class itself
    is flagged with ``__is_base_factory__ = True``.
    """

    __is_base_factory__ = True

    @classmethod
    def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
        """Return the inherited provider map extended with msgspec types.

        Two entries are added on top of ``BaseFactory.get_provider_map()``:

        * ``msgspec.UnsetType`` -> a provider that returns ``msgspec.UNSET``.
        * ``msgspec.msgpack.Ext`` -> a provider that builds an ``Ext`` from a
          random type code in ``[-128, 127]`` and random bytes.

        :returns: The mapping from type to zero-argument value provider.
        """

        def make_ext() -> msgspec.msgpack.Ext:
            # ``cls.__random__`` is read on every call, not captured here, so a
            # later reseed of the factory's RNG is honoured. The code is drawn
            # before the payload so random-stream consumption order is fixed.
            # The bounds -128..127 span the signed 8-bit range of ext codes.
            ext_code = handle_constrained_int(cls.__random__, ge=-128, le=127)
            payload = create_random_bytes(cls.__random__)
            return msgspec.msgpack.Ext(ext_code, payload)

        providers = super().get_provider_map()
        providers[msgspec.UnsetType] = lambda: msgspec.UNSET
        providers[msgspec.msgpack.Ext] = make_ext
        return providers

    @classmethod
    def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:
        """Report whether ``value`` is a class carrying ``__struct_fields__``.

        :param value: Any object to inspect.
        :returns: ``True`` only for classes that define ``__struct_fields__``.
        """
        if not isclass(value):
            return False
        return hasattr(value, "__struct_fields__")

    @classmethod
    def get_model_fields(cls) -> list[FieldMeta]:
        """Describe every field of ``cls.__model__`` as a ``FieldMeta``.

        Fields are visited in the order yielded by ``msgspec.structs.fields``.
        Annotations come from ``get_type_hints(..., include_extras=True)``, so
        ``Annotated`` metadata is preserved.

        A field's default is resolved as follows: its explicit ``default`` is
        used if set. Otherwise the value produced by calling its
        ``default_factory`` is used. ``Null`` is used if neither is set.

        Note that ``default_factory`` is invoked here, once per call of this
        method.

        :returns: One ``FieldMeta`` per struct field.
        """
        hints = get_type_hints(cls.__model__, include_extras=True)

        def default_for(info: Any) -> Any:
            # msgspec marks "no default" with the NODEFAULT sentinel on both
            # ``default`` and ``default_factory``.
            if info.default is not msgspec.NODEFAULT:
                return info.default
            if info.default_factory is not msgspec.NODEFAULT:
                return info.default_factory()
            return Null

        return [
            FieldMeta.from_type(
                annotation=hints[info.name],
                name=info.name,
                default=default_for(info),
            )
            for info in fields(cls.__model__)
        ]