1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
from typing import (
TYPE_CHECKING,
Any,
Generic,
List,
Mapping,
Optional,
Type,
TypeVar,
)
from pydantic import BaseModel
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from beanie.odm.cache import LRUCache
from beanie.odm.interfaces.clone import CloneInterface
from beanie.odm.interfaces.session import SessionMethods
from beanie.odm.queries.cursor import BaseCursorQuery
from beanie.odm.utils.projection import get_projection
# Type parameter threaded through to BaseCursorQuery; presumably the type of
# each item produced by the query (a projection model or raw dict) — confirm
# against BaseCursorQuery.
AggregationProjectionType = TypeVar("AggregationProjectionType")

if TYPE_CHECKING:
    # Needed only for the ``Type["DocType"]`` annotation in AggregationQuery;
    # kept out of the runtime import graph (likely to avoid an import cycle).
    from beanie.odm.documents import DocType
class AggregationQuery(
    Generic[AggregationProjectionType],
    BaseCursorQuery[AggregationProjectionType],
    SessionMethods,
    CloneInterface,
):
    """
    Aggregation Query

    Describes a MongoDB aggregation over ``document_model``'s collection.
    The pipeline sent to the server is built by
    :meth:`get_aggregation_pipeline` as:

    1. an optional ``$match`` stage built from ``find_query``,
    2. the caller's ``aggregation_pipeline`` stages,
    3. an optional ``$project`` stage derived from ``projection_model``.

    ``_get_cache`` / ``_set_cache`` read and write the document model's
    cache. They do so only when the model's settings enable caching and this
    query was not created with ``ignore_cache``. Their callers are presumably
    in BaseCursorQuery, which is not visible here.
    """

    def __init__(
        self,
        document_model: Type["DocType"],
        aggregation_pipeline: List[Mapping[str, Any]],
        find_query: Mapping[str, Any],
        projection_model: Optional[Type[BaseModel]] = None,
        ignore_cache: bool = False,
        **pymongo_kwargs: Any,
    ):
        """
        :param document_model: document class whose collection is aggregated
            and whose settings/cache are consulted.
        :param aggregation_pipeline: user-supplied pipeline stages, inserted
            between the generated ``$match`` and ``$project`` stages.
        :param find_query: filter prepended as a ``$match`` stage when it is
            non-empty; also part of the cache key.
        :param projection_model: optional model used to build a trailing
            ``$project`` stage; also part of the cache key.
        :param ignore_cache: when true, bypass the document model's cache.
        :param pymongo_kwargs: extra keyword arguments forwarded verbatim to
            the collection's ``aggregate`` call.
        """
        self.aggregation_pipeline: List[Mapping[str, Any]] = (
            aggregation_pipeline
        )
        self.document_model = document_model
        self.projection_model = projection_model
        self.find_query = find_query
        # No session until one is attached (presumably through the
        # SessionMethods mixin — confirm); passed to ``aggregate`` as-is.
        self.session = None
        self.ignore_cache = ignore_cache
        self.pymongo_kwargs = pymongo_kwargs

    @property
    def _cache_key(self) -> str:
        """Key identifying this aggregation in the document model's cache."""
        projection = (
            get_projection(self.projection_model)
            if self.projection_model
            else None
        )
        # Key order is kept stable because create_key may serialize the
        # mapping in insertion order.
        return LRUCache.create_key(
            {
                "type": "Aggregation",
                "filter": self.find_query,
                "pipeline": self.aggregation_pipeline,
                "projection": projection,
            }
        )

    def _should_use_cache(self) -> bool:
        """Return True when model settings enable caching and this query
        has not opted out via ``ignore_cache``."""
        return bool(
            self.document_model.get_settings().use_cache
            and not self.ignore_cache
        )

    def _get_cache(self):
        """Return the cache entry for this query, or None when caching is
        disabled (the miss value is whatever the model's cache ``get``
        returns)."""
        if not self._should_use_cache():
            return None
        return self.document_model._cache.get(self._cache_key)  # type: ignore

    def _set_cache(self, data):
        """Store ``data`` under this query's cache key when caching is
        enabled; returns the cache's ``set`` result, otherwise None."""
        if not self._should_use_cache():
            return None
        return self.document_model._cache.set(self._cache_key, data)  # type: ignore

    def get_aggregation_pipeline(
        self,
    ) -> List[Mapping[str, Any]]:
        """
        Build the full pipeline sent to MongoDB.

        :return: a new list made of ``[$match]`` (only when ``find_query`` is
            non-empty), then the user's stages, then ``[$project]`` (only
            when a projection model is set and yields a projection).
        """
        match_pipeline: List[Mapping[str, Any]] = (
            [{"$match": self.find_query}] if self.find_query else []
        )
        projection_pipeline: List[Mapping[str, Any]] = []
        if self.projection_model:
            projection = get_projection(self.projection_model)
            # get_projection may yield None, meaning "no projection stage".
            if projection is not None:
                projection_pipeline = [{"$project": projection}]
        return match_pipeline + self.aggregation_pipeline + projection_pipeline

    async def get_cursor(self) -> AsyncCommandCursor:
        """Run the aggregation on the model's pymongo collection and return
        the resulting command cursor."""
        aggregation_pipeline = self.get_aggregation_pipeline()
        return await self.document_model.get_pymongo_collection().aggregate(
            aggregation_pipeline, session=self.session, **self.pymongo_kwargs
        )

    def get_projection_model(self) -> Optional[Type[BaseModel]]:
        """Return the projection model this query was built with, if any."""
        return self.projection_model
|