1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
|
from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
from django.utils.functional import cached_property
from . import NOT_PROVIDED, Field
# Names exported by a star-import of this module.
__all__ = ["GeneratedField"]
class GeneratedField(Field):
    """
    Model field described by an ``expression`` and typed by ``output_field``.

    The field is never editable, is always blank, and accepts neither a
    ``default`` nor a ``db_default``. ``db_persist`` must be passed
    explicitly as True or False; the system checks compare it with the
    backend's stored/virtual generated column feature flags.
    """

    # Class-level flags; their consumers are outside this class.
    generated = True
    db_returning = True

    # Query bound to the owning model; built in contribute_to_class().
    _query = None
    # Field instance providing type information; assigned in __init__().
    output_field = None

    def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
        """
        Validate the options and store the generated-field configuration.

        Raise ValueError if ``editable`` is truthy, ``blank`` is falsy, a
        ``default`` or ``db_default`` is supplied, or ``db_persist`` is
        neither True nor False.
        """
        # setdefault() injects the forced value and, at the same time,
        # surfaces a conflicting value passed by the caller.
        is_editable = kwargs.setdefault("editable", False)
        if is_editable:
            raise ValueError("GeneratedField cannot be editable.")
        is_blank = kwargs.setdefault("blank", True)
        if not is_blank:
            raise ValueError("GeneratedField must be blank.")

        forbidden_options = (
            ("default", "GeneratedField cannot have a default."),
            ("db_default", "GeneratedField cannot have a database default."),
        )
        for option, message in forbidden_options:
            if kwargs.get(option, NOT_PROVIDED) is not NOT_PROVIDED:
                raise ValueError(message)

        # The None default only exists so that omitting db_persist fails here.
        if db_persist not in (True, False):
            raise ValueError("GeneratedField.db_persist must be True or False.")

        self.expression = expression
        self.output_field = output_field
        self.db_persist = db_persist
        super().__init__(**kwargs)

    @cached_property
    def cached_col(self):
        """Return a Col on the model's own table, typed by ``output_field``."""
        # Imported lazily inside the method, as in the original code.
        from django.db.models.expressions import Col

        table_name = self.model._meta.db_table
        return Col(table_name, self, self.output_field)

    def get_col(self, alias, output_field=None):
        """
        Return the column for ``alias``.

        When ``alias`` is not the model's table and no distinct
        ``output_field`` was requested, this field's ``output_field`` is
        used instead.
        """
        is_foreign_alias = alias != self.model._meta.db_table
        if is_foreign_alias and output_field in (None, self):
            output_field = self.output_field
        return super().get_col(alias, output_field)

    def contribute_to_class(self, *args, **kwargs):
        """Attach to the model, build the query and copy lookups over."""
        super().contribute_to_class(*args, **kwargs)
        self._query = Query(model=self.model, alias_cols=False)
        # Lookups registered on the output_field class are made available
        # on this field too.
        class_lookups = self.output_field.get_class_lookups()
        for name, lookup_cls in class_lookups.items():
            self.register_lookup(lookup_cls, lookup_name=name)

    def generated_sql(self, connection):
        """
        Compile ``expression`` for ``connection``.

        Return a ``(sql, params)`` pair.
        """
        compiler_cls = connection.ops.compiler("SQLCompiler")
        compiler = compiler_cls(self._query, connection=connection, using=None)
        resolved = self.expression.resolve_expression(
            self._query, allow_joins=False
        )
        sql, params = compiler.compile(resolved)
        # Conditional expressions are wrapped in CASE on backends that can't
        # select a boolean expression directly.
        needs_case_wrapper = (
            getattr(self.expression, "conditional", False)
            and not connection.features.supports_boolean_expr_in_select_clause
        )
        if needs_case_wrapper:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def check(self, **kwargs):
        """
        Run the system checks for this field.

        Combine the parent checks, the backend support and persistence
        checks, and a summary of the errors and warnings reported by a clone
        of ``output_field`` bound to this field's model.
        """
        databases = kwargs.get("databases") or []
        errors = list(super().check(**kwargs))
        errors.extend(self._check_supported(databases))
        errors.extend(self._check_persistence(databases))

        # Check a clone so that the configured output_field isn't modified.
        probe = self.output_field.clone()
        probe.model = self.model
        output_field_checks = probe.check(databases=databases)
        if not output_field_checks:
            return errors

        # NOTE(review): whitespace in this file looks collapsed; confirm the
        # intended indent width inside this separator.
        separator = "\n "
        summaries = (
            (checks.Error, "GeneratedField.output_field has errors:", "fields.E223"),
            (
                checks.Warning,
                "GeneratedField.output_field has warnings:",
                "fields.W224",
            ),
        )
        for level, heading, check_id in summaries:
            details = separator.join(
                f"{message.msg} ({message.id})"
                for message in output_field_checks
                if isinstance(message, level)
            )
            if details:
                errors.append(
                    level(f"{heading}{separator}{details}", obj=self, id=check_id)
                )
        return errors

    def _check_supported(self, databases):
        """
        Return a fields.E220 error for each database that satisfies neither
        generated-column support condition.
        """
        errors = []
        for alias in databases:
            if not router.allow_migrate_model(alias, self.model):
                continue
            connection = connections[alias]
            required_vendor = self.model._meta.required_db_vendor
            if required_vendor and required_vendor != connection.vendor:
                continue
            # Each feature flag is paired with the opposite feature name in
            # required_db_features; any single match means no error.
            if (
                connection.features.supports_virtual_generated_columns
                or "supports_stored_generated_columns"
                in self.model._meta.required_db_features
                or connection.features.supports_stored_generated_columns
                or "supports_virtual_generated_columns"
                in self.model._meta.required_db_features
            ):
                continue
            errors.append(
                checks.Error(
                    f"{connection.display_name} does not support GeneratedFields.",
                    obj=self,
                    id="fields.E220",
                )
            )
        return errors

    def _check_persistence(self, databases):
        """
        Return fields.E221 / fields.E222 errors for databases that don't
        support the persistence mode selected by ``db_persist``.
        """
        errors = []
        for alias in databases:
            if not router.allow_migrate_model(alias, self.model):
                continue
            connection = connections[alias]
            required_vendor = self.model._meta.required_db_vendor
            if required_vendor and required_vendor != connection.vendor:
                continue
            if self.db_persist:
                # Persisted fields are checked against the stored flag.
                stored_ok = (
                    connection.features.supports_stored_generated_columns
                    or "supports_stored_generated_columns"
                    in self.model._meta.required_db_features
                )
                if not stored_ok:
                    errors.append(
                        checks.Error(
                            f"{connection.display_name} does not support persisted "
                            "GeneratedFields.",
                            obj=self,
                            id="fields.E222",
                            hint="Set db_persist=False on the field.",
                        )
                    )
            else:
                # Non-persisted fields are checked against the virtual flag.
                virtual_ok = (
                    connection.features.supports_virtual_generated_columns
                    or "supports_virtual_generated_columns"
                    in self.model._meta.required_db_features
                )
                if not virtual_ok:
                    errors.append(
                        checks.Error(
                            f"{connection.display_name} does not support non-persisted "
                            "GeneratedFields.",
                            obj=self,
                            id="fields.E221",
                            hint="Set db_persist=True on the field.",
                        )
                    )
        return errors

    def deconstruct(self):
        """
        Return ``(name, path, args, kwargs)`` for this field.

        ``blank`` and ``editable`` are removed because __init__() forces
        them; the three generated-field arguments are added.
        """
        name, path, args, kwargs = super().deconstruct()
        for forced_option in ("blank", "editable"):
            del kwargs[forced_option]
        kwargs.update(
            db_persist=self.db_persist,
            expression=self.expression,
            output_field=self.output_field,
        )
        return name, path, args, kwargs

    def get_internal_type(self):
        """Return the internal type reported by ``output_field``."""
        delegate = self.output_field
        return delegate.get_internal_type()

    def db_parameters(self, connection):
        """Return the ``db_parameters()`` of ``output_field``."""
        delegate = self.output_field
        return delegate.db_parameters(connection)

    def db_type_parameters(self, connection):
        """Return the ``db_type_parameters()`` of ``output_field``."""
        delegate = self.output_field
        return delegate.db_type_parameters(connection)
|