1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381
|
import sys
import warnings
from typing import Generic, Optional, TypeVar, Union, cast
import strawberry
from django.db import DEFAULT_DB_ALIAS
from django.db.models import Count, QuerySet, Window
from django.db.models.functions import RowNumber
from django.db.models.query import MAX_GET_RESULTS # type: ignore
from strawberry.types import Info
from strawberry.types.arguments import StrawberryArgument
from strawberry.types.unset import UNSET, UnsetType
from typing_extensions import Self
from strawberry_django.fields.base import StrawberryDjangoFieldBase
from strawberry_django.resolvers import django_resolver
from .arguments import argument
from .settings import strawberry_django_settings
# Type variable for the item type held by `OffsetPaginated.results`.
NodeType = TypeVar("NodeType")
# Type variable bound to QuerySet so pagination helpers return the same
# queryset type they were given.
_QS = TypeVar("_QS", bound=QuerySet)
# Name of the pagination argument; same value as the argument name declared
# by `StrawberryDjangoPagination.arguments`.
PAGINATION_ARG = "pagination"
# Offset pagination values, returned to clients through `OffsetPaginated.page_info`.
@strawberry.type
class OffsetPaginationInfo:
    # Number of records to skip before the page starts.
    offset: int = 0
    # Maximum number of records in the page. When UNSET, `apply` falls back to
    # the PAGINATION_DEFAULT_LIMIT setting; None or a negative value means
    # "no limit".
    limit: Optional[int] = UNSET
# GraphQL input counterpart of `OffsetPaginationInfo`: inherits the same
# `offset` / `limit` fields and defaults, exposed as an input type.
@strawberry.input
class OffsetPaginationInput(OffsetPaginationInfo): ...
# Generic offset-paginated result wrapper. The original (un-paginated)
# queryset is stored as-is; pagination is only applied when `results` is
# resolved, while `total_count` counts the un-paginated queryset.
@strawberry.type
class OffsetPaginated(Generic[NodeType]):
    # Queryset to paginate; None resolves to an empty result with a count of 0.
    queryset: strawberry.Private[Optional[QuerySet]]
    # Pagination input requested for this result.
    pagination: strawberry.Private[OffsetPaginationInput]

    @strawberry.field
    def page_info(self) -> OffsetPaginationInfo:
        # Report back the offset/limit values of the pagination input as given.
        return OffsetPaginationInfo(
            limit=self.pagination.limit,
            offset=self.pagination.offset,
        )

    @strawberry.field(description="Total count of existing results.")
    @django_resolver
    def total_count(self) -> int:
        return self.get_total_count()

    @strawberry.field(description="List of paginated results.")
    @django_resolver
    def results(self) -> list[NodeType]:
        paginated_queryset = self.get_paginated_queryset()

        # A missing queryset resolves to an empty list instead of null.
        return cast(
            "list[NodeType]",
            paginated_queryset if paginated_queryset is not None else [],
        )

    @classmethod
    def resolve_paginated(
        cls,
        queryset: QuerySet,
        *,
        info: Info,
        pagination: Optional[OffsetPaginationInput] = None,
        **kwargs,
    ) -> Self:
        """Resolve the paginated queryset.

        Args:
            queryset: The queryset to be paginated.
            info: The strawberry execution info. Not used by this
                implementation.
            pagination: The pagination input to be applied. When it is falsy
                (e.g. ``None`` or ``UNSET``), a default
                ``OffsetPaginationInput()`` is used instead.
            kwargs: Additional arguments passed to the resolver. Ignored by
                this implementation.

        Returns:
            The resolved `OffsetPaginated`.

        """
        return cls(
            queryset=queryset,
            pagination=pagination or OffsetPaginationInput(),
        )

    def get_total_count(self) -> int:
        """Retrieve the total count of the queryset without pagination.

        Returns 0 when there is no queryset.
        """
        return get_total_count(self.queryset) if self.queryset is not None else 0

    def get_paginated_queryset(self) -> Optional[QuerySet]:
        """Retrieve the queryset with pagination applied.

        This will apply the paginated arguments to the queryset and return it.
        To use the original queryset, access `.queryset` directly.

        Returns ``None`` when there is no queryset. When the queryset was
        optimized by prefetching, its already-fetched result cache (a list,
        not a QuerySet) is returned as-is instead of re-applying pagination.
        """
        from strawberry_django.optimizer import is_optimized_by_prefetching

        if self.queryset is None:
            return None

        return (
            self.queryset._result_cache  # type: ignore
            if is_optimized_by_prefetching(self.queryset)
            else apply(self.pagination, self.queryset)
        )
def apply(
    pagination: Optional[object],
    queryset: _QS,
    *,
    related_field_id: Optional[str] = None,
) -> _QS:
    """Paginate ``queryset`` according to an offset pagination input.

    Args:
    ----
        pagination: The pagination input. ``None`` or ``UNSET`` disables
            pagination and returns the queryset untouched.
        queryset: The queryset to paginate.
        related_field_id: When given, pagination is done per related object
            with window functions (``apply_window_pagination``) instead of
            slicing. Prefetch querysets need this, since they cannot be sliced
            once filtered.

    Returns:
    -------
        The paginated queryset.

    Raises:
    ------
        TypeError: If ``pagination`` is not an ``OffsetPaginationInput``.

    """
    if pagination in (None, strawberry.UNSET):  # noqa: PLR6201
        return queryset

    if not isinstance(pagination, OffsetPaginationInput):
        raise TypeError(f"Don't know how to resolve pagination {pagination!r}")

    offset = pagination.offset
    limit = pagination.limit

    if related_field_id is not None:
        # Prefetch path: row-number windows partitioned by the related field.
        return apply_window_pagination(
            queryset,
            related_field_id=related_field_id,
            offset=offset,
            limit=limit,
        )

    # Slicing path: an unset limit falls back to the configured default.
    if limit is UNSET:
        configured = strawberry_django_settings()
        limit = configured["PAGINATION_DEFAULT_LIMIT"]

    # No limit at all (None) or a negative one means "everything after offset".
    if limit is None or limit < 0:
        return queryset[offset:]

    return queryset[offset : offset + limit]
class _PaginationWindow(Window):
    """Marker subclass of Django's `Window` used for pagination annotations.

    It behaves exactly like `Window`. Its only purpose is to be identifiable:
    a queryset may contain other window functions, and the distinct type lets
    `remove_window_pagination` drop only the filters built on the pagination
    annotations.
    """
def apply_window_pagination(
    queryset: _QS,
    *,
    related_field_id: str,
    offset: int = 0,
    limit: Optional[int] = UNSET,
    max_results: Optional[int] = None,
) -> _QS:
    """Apply pagination using window functions.

    Useful for prefetches, as those cannot be sliced after being filtered.

    This is based on the same solution that Django uses, which was implemented
    in https://github.com/django/django/pull/15957

    Two annotations are added, both partitioned by ``related_field_id``:
    ``_strawberry_row_number`` (the row's position inside its partition,
    following the queryset's current ordering) and ``_strawberry_total_count``
    (the row count of the partition). The page is then selected by filtering
    on the row number.

    Args:
    ----
        queryset: The queryset to apply pagination to.
        related_field_id: The related field id to partition rows by, so each
            related object gets its own page.
        offset: The number of rows to skip in each partition.
        limit: The maximum number of rows to keep in each partition. When
            ``UNSET``, ``max_results`` is used if given, otherwise the
            ``PAGINATION_DEFAULT_LIMIT`` setting. ``None``, a negative value,
            or ``sys.maxsize`` means no limit.
        max_results: Fallback limit, only consulted when ``limit`` is
            ``UNSET``; takes precedence over ``PAGINATION_DEFAULT_LIMIT``.

    Returns:
    -------
        The annotated and filtered queryset.

    """
    # Reuse the queryset's own ordering so row numbers follow the same order
    # the results would be returned in.
    compiler = queryset.query.get_compiler(
        using=queryset._db or DEFAULT_DB_ALIAS  # type: ignore
    )
    order_by = [expr for expr, _ in compiler.get_order_by()]

    queryset = queryset.annotate(
        _strawberry_row_number=_PaginationWindow(
            RowNumber(),
            partition_by=related_field_id,
            order_by=order_by,
        ),
        _strawberry_total_count=_PaginationWindow(
            Count(1),
            partition_by=related_field_id,
        ),
    )

    if offset:
        queryset = queryset.filter(_strawberry_row_number__gt=offset)

    if limit is UNSET:
        settings = strawberry_django_settings()
        limit = (
            max_results
            if max_results is not None
            else settings["PAGINATION_DEFAULT_LIMIT"]
        )

    # Limit == -1 means no limit. sys.maxsize is set by relay when paginating
    # from the end, as a way to mimic "no limit" as well.
    if limit is not None and limit >= 0 and limit != sys.maxsize:
        queryset = queryset.filter(_strawberry_row_number__lte=offset + limit)

    return queryset
def remove_window_pagination(queryset: _QS) -> _QS:
    """Remove pagination window functions from a queryset.

    Utility function to remove the pagination `WHERE` clause added by
    the `apply_window_pagination` function.

    Only top-level ``WHERE`` conditions whose left-hand side is a
    `_PaginationWindow` are dropped. Every other filter is kept, including
    filters on other window functions.

    Args:
    ----
        queryset: The queryset to remove the pagination filters from.

    Returns:
    -------
        A chained copy of the queryset without the pagination filters. The
        pagination annotations themselves are left in place.

    """
    # Work on a chained clone rather than the caller's queryset.
    queryset = queryset._chain()  # type: ignore
    where = queryset.query.where
    where.children = [
        child
        for child in where.children
        # Conditions without an `lhs` (e.g. nested nodes) are always kept.
        if not isinstance(getattr(child, "lhs", None), _PaginationWindow)
    ]
    return queryset
def get_total_count(queryset: QuerySet) -> int:
    """Return the total number of rows in ``queryset``, ignoring pagination.

    For querysets optimized by prefetching, the total is read from the
    ``_strawberry_total_count`` annotation (added by
    ``apply_window_pagination``) on the first cached row, avoiding a query.
    If the cache is empty, or the annotation is missing (a ``RuntimeWarning``
    is emitted), the pagination filters are stripped with
    ``remove_window_pagination`` and ``QuerySet.count()`` is used. Any other
    queryset is simply counted with ``QuerySet.count()``.
    """
    from strawberry_django.optimizer import is_optimized_by_prefetching

    if not is_optimized_by_prefetching(queryset):
        return queryset.count()

    cached_rows = queryset._result_cache  # type: ignore
    if cached_rows:
        try:
            return cached_rows[0]._strawberry_total_count
        except AttributeError:
            warnings.warn(
                (
                    "Pagination annotations not found, falling back to QuerySet resolution. "
                    "This might cause n+1 issues..."
                ),
                RuntimeWarning,
                stacklevel=2,
            )

    # No cached row carries the total: either the page is empty or the
    # annotation was missing. Drop the pagination filter so `.count()` covers
    # the whole queryset with its original filters.
    return remove_window_pagination(queryset).count()
class StrawberryDjangoPagination(StrawberryDjangoFieldBase):
    """Field base that adds offset pagination support to Django fields.

    Exposes an optional ``pagination`` argument on eligible fields and applies
    it (or an implicit limit for single-object fields) in ``get_queryset``.
    """

    def __init__(self, pagination: Union[bool, UnsetType] = UNSET, **kwargs):
        # An explicit bool forces pagination on/off; UNSET defers to the
        # fallbacks in `_has_pagination`.
        self.pagination = pagination
        super().__init__(**kwargs)

    def __copy__(self) -> Self:
        new_field = super().__copy__()
        # Carry the explicit pagination setting over to the copy.
        new_field.pagination = self.pagination
        return new_field

    def _has_pagination(self) -> bool:
        """Return whether this field should accept a pagination argument.

        An explicit boolean ``pagination`` wins. Otherwise paginated fields
        always paginate, and non-relay Django types use their definition's
        ``pagination`` setting. Anything else has no pagination.
        """
        if isinstance(self.pagination, bool):
            return self.pagination

        if self.is_paginated:
            return True

        django_type = self.django_type
        if django_type is not None and not issubclass(
            django_type, strawberry.relay.Node
        ):
            return django_type.__strawberry_django_definition__.pagination

        return False

    @property
    def arguments(self) -> list[StrawberryArgument]:
        """Field arguments, plus an optional pagination input when enabled.

        The pagination argument is only added to resolver-less list or
        paginated fields that are not model properties, and only when
        ``get_pagination`` returns a type.
        """
        arguments = []
        if (
            self.base_resolver is None
            and (self.is_list or self.is_paginated)
            and not self.is_model_property
        ):
            pagination = self.get_pagination()
            if pagination is not None:
                arguments.append(
                    argument(PAGINATION_ARG, OffsetPaginationInput, is_optional=True),
                )
        return super().arguments + arguments

    @arguments.setter
    def arguments(self, value: list[StrawberryArgument]):
        # Delegate to the parent's property setter.
        args_prop = super(StrawberryDjangoPagination, self.__class__).arguments
        return args_prop.fset(self, value)  # type: ignore

    def get_pagination(self) -> Optional[type]:
        """Return the pagination input type, or None when pagination is disabled."""
        return OffsetPaginationInput if self._has_pagination() else None

    def apply_pagination(
        self,
        queryset: _QS,
        pagination: Optional[object] = None,
        *,
        related_field_id: Optional[str] = None,
    ) -> _QS:
        """Apply ``pagination`` to ``queryset`` via the module-level ``apply``."""
        return apply(pagination, queryset, related_field_id=related_field_id)

    def get_queryset(
        self,
        queryset: _QS,
        info: Info,
        *,
        pagination: Optional[OffsetPaginationInput] = None,
        _strawberry_related_field_id: Optional[str] = None,
        **kwargs,
    ) -> _QS:
        """Return ``queryset`` with deterministic ordering and pagination applied.

        Args:
            queryset: The queryset to resolve.
            info: The strawberry execution info.
            pagination: The pagination input given to the field, if any.
            _strawberry_related_field_id: When set, pagination is applied with
                window functions partitioned by this field (prefetch case).
            kwargs: Forwarded to the parent ``get_queryset``.

        Returns:
            The ordered queryset; paginated unless this field returns a
            ``Paginated`` result, which paginates lazily itself.

        """
        queryset = super().get_queryset(queryset, info, **kwargs)

        # If the queryset is not ordered, and this field is either going to return
        # multiple records, or call `.first()`, then order by the primary key to ensure
        # deterministic results.
        if not queryset.ordered and (
            self.is_list or self.is_paginated or self.is_connection or self.is_optional
        ):
            queryset = queryset.order_by("pk")

        # This is counter intuitive, but in case we are returning a `Paginated`
        # result, we want to set the original queryset _as is_ as it will apply
        # the pagination later on when resolving its `.results` field.
        # Check `get_wrapped_result` (not defined in this class) for more details.
        if self.is_paginated:
            return queryset

        # Add implicit pagination if this field is not a list
        # that way when first() / get() is called on the QuerySet it does not cause extra queries
        # and we don't prefetch more than necessary.
        # (`is_paginated` fields already returned above.)
        if (
            not pagination
            and not (self.is_list or self.is_connection)
            and not _strawberry_related_field_id
        ):
            if self.is_optional:
                pagination = OffsetPaginationInput(offset=0, limit=1)
            else:
                pagination = OffsetPaginationInput(offset=0, limit=MAX_GET_RESULTS)

        return self.apply_pagination(
            queryset,
            pagination,
            related_field_id=_strawberry_related_field_id,
        )
|