File: generated.py

package info (click to toggle)
python-django 3:5.2.5-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 61,236 kB
  • sloc: python: 361,585; javascript: 19,250; xml: 211; makefile: 182; sh: 28
file content (197 lines) | stat: -rw-r--r-- 7,655 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
from django.utils.functional import cached_property

from . import NOT_PROVIDED, Field

__all__ = ["GeneratedField"]


class GeneratedField(Field):
    """
    Model field whose column value is computed by the database from an
    expression instead of being written by Django.

    ``expression`` is compiled into SQL by ``generated_sql()``.
    ``output_field`` supplies the internal type, the database parameters and
    the lookups of the column. ``db_persist`` must be True or False. The
    system checks require backend support for stored generated columns when
    it is True, and for virtual generated columns when it is False (unless
    the model lists the matching feature in ``required_db_features``).

    The field is always non-editable and blank. It may not declare a
    ``default`` or a ``db_default``.
    """

    # Class-level markers for this field kind. NOTE(review): their consumers
    # live outside this module.
    generated = True
    db_returning = True

    # Set by contribute_to_class(). It stays None until the field is bound
    # to a model.
    _query = None
    # Overwritten on the instance by __init__.
    output_field = None

    def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
        """
        Validate the generated-field constraints and store the configuration.

        ``editable`` and ``blank`` are defaulted to False and True in
        ``kwargs``. The remaining ``kwargs`` are then forwarded to
        ``Field.__init__``.

        Raise ValueError in any of these cases:

        - ``editable`` is truthy;
        - ``blank`` is falsy;
        - a ``default`` or ``db_default`` is supplied;
        - ``db_persist`` is neither True nor False. The comparison uses
          ``==``, so the default ``None`` is rejected and callers must pass
          it explicitly.
        """
        editable = kwargs.setdefault("editable", False)
        if editable:
            raise ValueError("GeneratedField cannot be editable.")
        blank = kwargs.setdefault("blank", True)
        if not blank:
            raise ValueError("GeneratedField must be blank.")
        for forbidden, message in (
            ("default", "GeneratedField cannot have a default."),
            ("db_default", "GeneratedField cannot have a database default."),
        ):
            if kwargs.get(forbidden, NOT_PROVIDED) is not NOT_PROVIDED:
                raise ValueError(message)
        if db_persist not in (True, False):
            raise ValueError("GeneratedField.db_persist must be True or False.")

        self.expression = expression
        self.output_field = output_field
        self.db_persist = db_persist
        super().__init__(**kwargs)

    @cached_property
    def cached_col(self):
        """Col for this field on the model's own table, typed by output_field."""
        from django.db.models.expressions import Col

        table = self.model._meta.db_table
        return Col(table, self, self.output_field)

    def get_col(self, alias, output_field=None):
        """
        Return a Col for ``alias``.

        Substitution applies only when ``alias`` is not the model's own
        table and the caller asked for no output field, or for this field
        itself. In that case the Col is typed with ``self.output_field``.
        NOTE(review): presumably this lets aliased references resolve via
        the output type. Confirm.
        """
        is_other_alias = alias != self.model._meta.db_table
        if is_other_alias and output_field in (None, self):
            output_field = self.output_field
        return super().get_col(alias, output_field)

    def contribute_to_class(self, *args, **kwargs):
        """
        Bind the field to its model and prepare it for compiling the
        expression.

        This builds the Query that ``generated_sql()`` compiles against. It
        also registers the output_field class's lookups via
        ``self.register_lookup``.
        """
        super().contribute_to_class(*args, **kwargs)

        self._query = Query(model=self.model, alias_cols=False)
        class_lookups = self.output_field.get_class_lookups()
        for name, lookup_cls in class_lookups.items():
            self.register_lookup(lookup_cls, lookup_name=name)

    def generated_sql(self, connection):
        """
        Compile ``expression`` for ``connection`` and return ``(sql, params)``.

        Conditional expressions are wrapped in a CASE that yields 1/0 on
        backends that can't select boolean expressions directly. This relies
        on ``self._query``, which exists only after contribute_to_class().
        """
        query = self._query
        compiler = connection.ops.compiler("SQLCompiler")(
            query, connection=connection, using=None
        )
        resolved = self.expression.resolve_expression(query, allow_joins=False)
        sql, params = compiler.compile(resolved)
        wrap_as_case = (
            getattr(self.expression, "conditional", False)
            and not connection.features.supports_boolean_expr_in_select_clause
        )
        if wrap_as_case:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def check(self, **kwargs):
        """
        Run Field's checks plus the generated-column support and persistence
        checks.

        The output_field's own check messages are folded into at most one
        combined Error (fields.E223) and one combined Warning (fields.W224).
        """
        databases = kwargs.get("databases") or []
        errors = list(super().check(**kwargs))
        errors += self._check_supported(databases)
        errors += self._check_persistence(databases)

        # Check a clone bound to self.model, so checks that consult .model
        # have one. The real output_field is left untouched.
        output_field_clone = self.output_field.clone()
        output_field_clone.model = self.model
        output_field_checks = output_field_clone.check(databases=databases)

        separator = "\n    "
        groups = (
            (
                checks.Error,
                "GeneratedField.output_field has errors:",
                "fields.E223",
            ),
            (
                checks.Warning,
                "GeneratedField.output_field has warnings:",
                "fields.W224",
            ),
        )
        for message_cls, header, message_id in groups:
            details = separator.join(
                f"{entry.msg} ({entry.id})"
                for entry in output_field_checks
                if isinstance(entry, message_cls)
            )
            if details:
                errors.append(
                    message_cls(
                        f"{header}{separator}{details}",
                        obj=self,
                        id=message_id,
                    )
                )
        return errors

    def _connections_to_check(self, databases):
        """
        Yield connections from ``databases`` that need checking, in order.

        An alias is yielded when the router allows this model to migrate to
        it, and its vendor matches the model's ``required_db_vendor`` (when
        one is set).
        """
        for alias in databases:
            if not router.allow_migrate_model(alias, self.model):
                continue
            connection = connections[alias]
            required_vendor = self.model._meta.required_db_vendor
            if required_vendor and required_vendor != connection.vendor:
                continue
            yield connection

    def _check_supported(self, databases):
        """Return fields.E220 errors for backends without generated columns."""
        errors = []
        for connection in self._connections_to_check(databases):
            features = connection.features
            required_features = self.model._meta.required_db_features
            # The backend must support at least one kind of generated column,
            # unless the model explicitly requires either feature.
            has_support = (
                features.supports_virtual_generated_columns
                or "supports_stored_generated_columns" in required_features
                or features.supports_stored_generated_columns
                or "supports_virtual_generated_columns" in required_features
            )
            if not has_support:
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support GeneratedFields.",
                        obj=self,
                        id="fields.E220",
                    )
                )
        return errors

    def _check_persistence(self, databases):
        """
        Return errors for backends that can't honor ``db_persist``.

        fields.E222 is returned when persisted columns are requested but
        stored generated columns are unsupported. fields.E221 is returned
        when non-persisted columns are requested but virtual generated
        columns are unsupported.
        """
        errors = []
        for connection in self._connections_to_check(databases):
            features = connection.features
            required_features = self.model._meta.required_db_features
            if self.db_persist:
                if not (
                    features.supports_stored_generated_columns
                    or "supports_stored_generated_columns" in required_features
                ):
                    errors.append(
                        checks.Error(
                            f"{connection.display_name} does not support persisted "
                            "GeneratedFields.",
                            obj=self,
                            id="fields.E222",
                            hint="Set db_persist=False on the field.",
                        )
                    )
            elif not (
                features.supports_virtual_generated_columns
                or "supports_virtual_generated_columns" in required_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support non-persisted "
                        "GeneratedFields.",
                        obj=self,
                        id="fields.E221",
                        hint="Set db_persist=True on the field.",
                    )
                )
        return errors

    def deconstruct(self):
        """
        Return the migration-serializable description of the field.

        ``blank`` and ``editable`` are dropped because __init__ pins them.
        The generated-field configuration is added instead.
        """
        name, path, args, kwargs = super().deconstruct()
        for pinned in ("blank", "editable"):
            del kwargs[pinned]
        kwargs.update(
            db_persist=self.db_persist,
            expression=self.expression,
            output_field=self.output_field,
        )
        return name, path, args, kwargs

    def get_internal_type(self):
        """Report the internal type of output_field."""
        return self.output_field.get_internal_type()

    def db_parameters(self, connection):
        """Return output_field's database parameters for ``connection``."""
        return self.output_field.db_parameters(connection)

    def db_type_parameters(self, connection):
        """Return output_field's db type parameters for ``connection``."""
        return self.output_field.db_type_parameters(connection)