1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
|
import json
from django.core import checks
from django.db.models import NOT_PROVIDED, Field
from django.db.models.expressions import ColPairs
from django.db.models.fields.tuple_lookups import (
TupleExact,
TupleGreaterThan,
TupleGreaterThanOrEqual,
TupleIn,
TupleIsNull,
TupleLessThan,
TupleLessThanOrEqual,
)
from django.utils.functional import cached_property
class AttributeSetter:
    """Throwaway object that carries exactly one attribute.

    ``AttributeSetter("x", 5).x == 5``. It lets a single value be handed to
    an API that expects an object holding that value under a given attribute
    name.
    """

    def __init__(self, name, value):
        # The attribute name is only known at runtime, hence setattr().
        setattr(self, name, value)
class CompositeAttribute:
    """Descriptor exposing a composite key as a tuple of its component values.

    Reading returns a tuple with the current value of each component field's
    attribute. Writing distributes a list/tuple (or ``None``) across those
    attributes.
    """

    def __init__(self, field):
        # The composite field this descriptor serves; its ``fields`` attribute
        # lists the component fields and its ``name`` is used in error messages.
        self.field = field

    @property
    def attnames(self):
        """Attribute names of the component fields, in ``self.field.fields`` order."""
        return [field.attname for field in self.field.fields]

    def __get__(self, instance, cls=None):
        """Return the component values of *instance* as a tuple.

        When accessed on the class itself (``instance is None``), return the
        descriptor. Without this check the lookup would run
        ``getattr(None, ...)`` and raise a misleading AttributeError.
        """
        if instance is None:
            return self
        return tuple(getattr(instance, attname) for attname in self.attnames)

    def __set__(self, instance, values):
        """Assign each item of *values* to the matching component attribute.

        Args:
            instance: Object whose component attributes are written.
            values: A list or tuple with one item per component field, or
                ``None`` to set every component to ``None``.

        Raises:
            ValueError: If *values* is not a list/tuple, or has the wrong length.
        """
        attnames = self.attnames
        length = len(attnames)

        if values is None:
            # Clearing the key clears every component.
            values = (None,) * length

        if not isinstance(values, (list, tuple)):
            raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
        if length != len(values):
            raise ValueError(f"{self.field.name!r} must have {length} elements.")

        for attname, value in zip(attnames, values):
            setattr(instance, attname, value)
class CompositePrimaryKey(Field):
    """Primary key made up of two or more existing fields of the same model.

    It is declared with the *names* of its component fields, e.g.
    ``CompositePrimaryKey("tenant_id", "id")``. It has no database column of
    its own (``get_attname_column`` returns ``None`` for the column). Instance
    access goes through ``descriptor_class``.
    """

    descriptor_class = CompositeAttribute

    def __init__(self, *args, **kwargs):
        """Validate the component field names and the permitted field options.

        Args:
            *args: Names of the component fields; at least two unique strings.
            **kwargs: Regular ``Field`` options, with these constraints:
                ``default`` and ``db_default`` must not be given, and
                ``db_column`` must be ``None``. ``editable`` defaults to (and
                must remain) ``False``. ``primary_key`` and ``blank`` default
                to (and must remain) ``True``.

        Raises:
            ValueError: If any of the constraints above is violated.
        """
        if (
            not args
            or not all(isinstance(field, str) for field in args)
            or len(set(args)) != len(args)
        ):
            raise ValueError("CompositePrimaryKey args must be unique strings.")
        if len(args) == 1:
            raise ValueError("CompositePrimaryKey must include at least two fields.")
        if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("CompositePrimaryKey cannot have a default.")
        if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("CompositePrimaryKey cannot have a database default.")
        if kwargs.get("db_column", None) is not None:
            raise ValueError("CompositePrimaryKey cannot have a db_column.")
        # setdefault() both fills in the required value when it is omitted and
        # returns an explicitly passed one so a conflicting value is rejected.
        if kwargs.setdefault("editable", False):
            raise ValueError("CompositePrimaryKey cannot be editable.")
        if not kwargs.setdefault("primary_key", True):
            raise ValueError("CompositePrimaryKey must be a primary key.")
        if not kwargs.setdefault("blank", True):
            raise ValueError("CompositePrimaryKey must be blank.")

        self.field_names = args
        super().__init__(**kwargs)

    def deconstruct(self):
        """Return the base deconstruction with the component names as args."""
        # args is always [] so it can be ignored.
        name, path, _, kwargs = super().deconstruct()
        return name, path, self.field_names, kwargs

    @cached_property
    def fields(self):
        """Component ``Field`` instances, resolved from ``field_names`` on the model."""
        meta = self.model._meta
        return tuple(meta.get_field(field_name) for field_name in self.field_names)

    @cached_property
    def columns(self):
        """Database column name of each component field, in order."""
        return tuple(field.column for field in self.fields)

    def contribute_to_class(self, cls, name, private_only=False):
        """Attach to *cls*, register as its pk and install the tuple descriptor."""
        super().contribute_to_class(cls, name, private_only=private_only)
        cls._meta.pk = self
        setattr(cls, self.attname, self.descriptor_class(self))

    def get_attname_column(self):
        """Return ``(attname, None)``: this field has no column of its own."""
        return self.get_attname(), None

    def __iter__(self):
        """Iterate over the component fields."""
        return iter(self.fields)

    def __len__(self):
        """Return the number of component fields."""
        return len(self.field_names)

    @cached_property
    def cached_col(self):
        """``ColPairs`` over the components on the model's own table."""
        return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)

    def get_col(self, alias, output_field=None):
        """Return a ``ColPairs`` expression over the component columns.

        The cached expression is reused when *alias* is the model's own table
        and no different *output_field* was requested.
        """
        if alias == self.model._meta.db_table and (
            output_field is None or output_field == self
        ):
            return self.cached_col

        return ColPairs(alias, self.fields, self.fields, output_field)

    def get_pk_value_on_save(self, instance):
        """Return the key tuple to use when saving *instance*.

        A component whose current value is ``None`` is filled in by that
        component field's own ``get_pk_value_on_save``.
        """
        values = []
        for field in self.fields:
            value = field.value_from_object(instance)
            if value is None:
                value = field.get_pk_value_on_save(instance)
            values.append(value)
        return tuple(values)

    def _check_field_name(self):
        """System check: a CompositePrimaryKey must be named ``pk``."""
        if self.name == "pk":
            return []
        return [
            checks.Error(
                "'CompositePrimaryKey' must be named 'pk'.",
                obj=self,
                id="fields.E013",
            )
        ]

    def value_to_string(self, obj):
        """Serialize the key of *obj* as a JSON array.

        Each component value is rendered by its own field's
        ``value_to_string``. That method is given a throwaway
        ``AttributeSetter`` carrying just that value under the field's attname.
        """
        pk_values = self.value_from_object(obj)
        values = []
        for field, value in zip(self.fields, pk_values):
            # Build a fresh stand-in instead of rebinding the ``obj`` parameter.
            holder = AttributeSetter(field.attname, value)
            values.append(field.value_to_string(holder))
        return json.dumps(values, ensure_ascii=False)

    def to_python(self, value):
        """Convert *value* to Python component values.

        A string is treated as the JSON produced by ``value_to_string``. It is
        decoded and each item converted by the matching component field. A
        length mismatch raises ValueError via ``zip(strict=True)``. Any other
        value is returned unchanged.
        """
        if isinstance(value, str):
            # Assume we're deserializing.
            vals = json.loads(value)
            value = [
                field.to_python(val)
                for field, val in zip(self.fields, vals, strict=True)
            ]
        return value
# Make the tuple-aware lookups imported above available on composite
# primary keys (registration order is kept as originally written).
for _tuple_lookup in (
    TupleExact,
    TupleGreaterThan,
    TupleGreaterThanOrEqual,
    TupleLessThan,
    TupleLessThanOrEqual,
    TupleIn,
    TupleIsNull,
):
    CompositePrimaryKey.register_lookup(_tuple_lookup)
# Don't leak the loop variable into the module namespace.
del _tuple_lookup
def unnest(fields):
    """Return a new list in which composite primary keys are expanded.

    Every ``CompositePrimaryKey`` in *fields* is replaced by its component
    fields (``.fields``). Any other item is kept as-is, and order is
    preserved.
    """
    flattened = []
    for item in fields:
        flattened.extend(
            item.fields if isinstance(item, CompositePrimaryKey) else (item,)
        )
    return flattened
|