1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
|
"""
We perform uniqueness checks explicitly on the serializer class, rather
the using Django's `.full_clean()`.
This gives us better separation of concerns, allows us to use single-step
object creation, and makes it possible to switch between using the implicit
`ModelSerializer` class and an equivalent explicit `Serializer` class.
"""
from django.core.exceptions import FieldError
from django.db import DataError
from django.db.models import Exists
from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import ValidationError
from rest_framework.utils.representation import smart_repr
# Robust filter and exist implementations. Ensures that queryset.exists() for
# an invalid value returns `False`, rather than raising an error.
# Refs https://github.com/encode/django-rest-framework/issues/3381
def qs_exists(queryset):
try:
return queryset.exists()
except (TypeError, ValueError, DataError):
return False
def qs_exists_with_condition(queryset, condition, against):
if condition is None:
return qs_exists(queryset)
try:
# use the same query as UniqueConstraint.validate
# https://github.com/django/django/blob/7ba2a0db20c37a5b1500434ca4ed48022311c171/django/db/models/constraints.py#L672
return (condition & Exists(queryset.filter(condition))).check(against)
except (TypeError, ValueError, DataError, FieldError):
return False
def qs_filter(queryset, **kwargs):
try:
return queryset.filter(**kwargs)
except (TypeError, ValueError, DataError):
return queryset.none()
class UniqueValidator:
"""
Validator that corresponds to `unique=True` on a model field.
Should be applied to an individual field on the serializer.
"""
message = _('This field must be unique.')
requires_context = True
def __init__(self, queryset, message=None, lookup='exact'):
self.queryset = queryset
self.message = message or self.message
self.lookup = lookup
def filter_queryset(self, value, queryset, field_name):
"""
Filter the queryset to all instances matching the given attribute.
"""
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset
def __call__(self, value, serializer_field):
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer_field.parent, 'instance', None)
queryset = self.queryset
queryset = self.filter_queryset(value, queryset, field_name)
queryset = self.exclude_current_instance(queryset, instance)
if qs_exists(queryset):
raise ValidationError(self.message, code='unique')
def __repr__(self):
return '<%s(queryset=%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset)
)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.lookup == other.lookup
)
class UniqueTogetherValidator:
"""
Validator that corresponds to `unique_together = (...)` on a model class.
Should be applied to the serializer class, not to an individual field.
"""
message = _('The fields {field_names} must make a unique set.')
missing_message = _('This field is required.')
requires_context = True
def __init__(self, queryset, fields, message=None, condition_fields=None, condition=None):
self.queryset = queryset
self.fields = fields
self.message = message or self.message
self.condition_fields = [] if condition_fields is None else condition_fields
self.condition = condition
def enforce_required_fields(self, attrs, serializer):
"""
The `UniqueTogetherValidator` always forces an implied 'required'
state on the fields it applies to.
"""
if serializer.instance is not None:
return
missing_items = {
field_name: self.missing_message
for field_name in (*self.fields, *self.condition_fields)
if serializer.fields[field_name].source not in attrs
}
if missing_items:
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset, serializer):
"""
Filter the queryset to all instances matching the given attributes.
"""
# field names => field sources
sources = [
serializer.fields[field_name].source
for field_name in self.fields
]
# If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute.
if serializer.instance is not None:
for source in sources:
if source not in attrs:
attrs[source] = getattr(serializer.instance, source)
# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = {
source: attrs[source]
for source in sources
}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, attrs, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset
def __call__(self, attrs, serializer):
self.enforce_required_fields(attrs, serializer)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset, serializer)
queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
checked_names = [
serializer.fields[field_name].source for field_name in self.fields
]
# Ignore validation if any field is None
if serializer.instance is None:
checked_values = [attrs[field_name] for field_name in checked_names]
else:
# Ignore validation if all field values are unchanged
checked_values = [
attrs[field_name]
for field_name in checked_names
if attrs[field_name] != getattr(serializer.instance, field_name)
]
condition_sources = (serializer.fields[field_name].source for field_name in self.condition_fields)
condition_kwargs = {source: attrs[source] for source in condition_sources}
if checked_values and None not in checked_values and qs_exists_with_condition(queryset, self.condition, condition_kwargs):
field_names = ', '.join(self.fields)
message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique')
def __repr__(self):
return '<{}({})>'.format(
self.__class__.__name__,
', '.join(
f'{attr}={smart_repr(getattr(self, attr))}'
for attr in ('queryset', 'fields', 'condition')
if getattr(self, attr) is not None)
)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.missing_message == other.missing_message
and self.queryset == other.queryset
and self.fields == other.fields
)
class ProhibitSurrogateCharactersValidator:
message = _('Surrogate characters are not allowed: U+{code_point:X}.')
code = 'surrogate_characters_not_allowed'
def __call__(self, value):
for surrogate_character in (ch for ch in str(value)
if 0xD800 <= ord(ch) <= 0xDFFF):
message = self.message.format(code_point=ord(surrogate_character))
raise ValidationError(message, code=self.code)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.code == other.code
)
class BaseUniqueForValidator:
message = None
missing_message = _('This field is required.')
requires_context = True
def __init__(self, queryset, field, date_field, message=None):
self.queryset = queryset
self.field = field
self.date_field = date_field
self.message = message or self.message
def enforce_required_fields(self, attrs):
"""
The `UniqueFor<Range>Validator` classes always force an implied
'required' state on the fields they are applied to.
"""
missing_items = {
field_name: self.missing_message
for field_name in [self.field, self.date_field]
if field_name not in attrs
}
if missing_items:
raise ValidationError(missing_items, code='required')
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
raise NotImplementedError('`filter_queryset` must be implemented.')
def exclude_current_instance(self, attrs, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset
def __call__(self, attrs, serializer):
# Determine the underlying model field names. These may not be the
# same as the serializer field names if `source=<>` is set.
field_name = serializer.fields[self.field].source_attrs[-1]
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
self.enforce_required_fields(attrs)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
queryset = self.exclude_current_instance(attrs, queryset, serializer.instance)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({
self.field: message
}, code='unique')
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.missing_message == other.missing_message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.field == other.field
and self.date_field == other.date_field
)
def __repr__(self):
return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset),
smart_repr(self.field),
smart_repr(self.date_field)
)
class UniqueForDateValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" date.')
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]
filter_kwargs = {}
filter_kwargs[field_name] = value
filter_kwargs['%s__day' % date_field_name] = date.day
filter_kwargs['%s__month' % date_field_name] = date.month
filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)
class UniqueForMonthValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" month.')
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]
filter_kwargs = {}
filter_kwargs[field_name] = value
filter_kwargs['%s__month' % date_field_name] = date.month
return qs_filter(queryset, **filter_kwargs)
class UniqueForYearValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" year.')
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]
filter_kwargs = {}
filter_kwargs[field_name] = value
filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)
|