File: aggregation.py

package info (click to toggle)
python-beanie 2.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,484 kB
  • sloc: python: 14,427; makefile: 6; sh: 6
file content (105 lines) | stat: -rw-r--r-- 3,313 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import (
    TYPE_CHECKING,
    Any,
    Generic,
    List,
    Mapping,
    Optional,
    Type,
    TypeVar,
)

from pydantic import BaseModel
from pymongo.asynchronous.command_cursor import AsyncCommandCursor

from beanie.odm.cache import LRUCache
from beanie.odm.interfaces.clone import CloneInterface
from beanie.odm.interfaces.session import SessionMethods
from beanie.odm.queries.cursor import BaseCursorQuery
from beanie.odm.utils.projection import get_projection

if TYPE_CHECKING:
    from beanie.odm.documents import DocType

# Type variable for AggregationQuery's result type; it is used to
# parameterize the BaseCursorQuery base class of AggregationQuery.
AggregationProjectionType = TypeVar("AggregationProjectionType")


class AggregationQuery(
    Generic[AggregationProjectionType],
    BaseCursorQuery[AggregationProjectionType],
    SessionMethods,
    CloneInterface,
):
    """
    Aggregation Query

    Runs a MongoDB aggregation over ``document_model``'s collection. The
    effective pipeline is an optional leading ``$match`` stage built from
    ``find_query``, then the user-supplied ``aggregation_pipeline``, then an
    optional trailing ``$project`` stage derived from ``projection_model``.
    """

    def __init__(
        self,
        document_model: Type["DocType"],
        aggregation_pipeline: List[Mapping[str, Any]],
        find_query: Mapping[str, Any],
        projection_model: Optional[Type[BaseModel]] = None,
        ignore_cache: bool = False,
        **pymongo_kwargs: Any,
    ):
        """
        :param document_model: document class whose collection is aggregated
        :param aggregation_pipeline: pipeline stages supplied by the caller;
            stored by reference
        :param find_query: filter placed in a leading ``$match`` stage when
            it is non-empty
        :param projection_model: optional model used to build a trailing
            ``$project`` stage
        :param ignore_cache: when True, the document model's cache is neither
            read nor written
        :param pymongo_kwargs: extra keyword arguments forwarded to the
            collection's ``aggregate`` call
        """
        self.aggregation_pipeline: List[Mapping[str, Any]] = (
            aggregation_pipeline
        )
        self.document_model = document_model
        self.projection_model = projection_model
        self.find_query = find_query
        # No session until one is attached explicitly.
        self.session = None
        self.ignore_cache = ignore_cache
        self.pymongo_kwargs = pymongo_kwargs

    @property
    def _cache_key(self) -> str:
        """
        Key identifying this query in the document model's cache.

        Combines the filter, the pipeline, the projection and the pymongo
        keyword arguments. The kwargs are part of the key because options
        such as ``let`` or ``collation`` change the aggregation output, so
        queries that differ only in them must not share a cache entry.
        """
        return LRUCache.create_key(
            {
                "type": "Aggregation",
                "filter": self.find_query,
                "pipeline": self.aggregation_pipeline,
                "projection": get_projection(self.projection_model)
                if self.projection_model
                else None,
                # Sorted by option name so that keyword order at the call
                # site does not produce distinct keys for the same query.
                "pymongo_kwargs": sorted(self.pymongo_kwargs.items()),
            }
        )

    def _get_cache(self) -> Any:
        """
        Look up this query's result in the document model's cache.

        Returns None without touching the cache when caching is disabled in
        the document settings or ``ignore_cache`` is set; otherwise returns
        whatever ``_cache.get`` yields for this query's key.
        """
        if (
            self.document_model.get_settings().use_cache
            and self.ignore_cache is False
        ):
            return self.document_model._cache.get(self._cache_key)  # type: ignore
        else:
            return None

    def _set_cache(self, data: Any):
        """
        Store ``data`` under this query's key in the document model's cache.

        Does nothing when caching is disabled in the document settings or
        ``ignore_cache`` is set.
        """
        if (
            self.document_model.get_settings().use_cache
            and self.ignore_cache is False
        ):
            return self.document_model._cache.set(self._cache_key, data)  # type: ignore

    def get_aggregation_pipeline(
        self,
    ) -> List[Mapping[str, Any]]:
        """
        Build the full pipeline sent to MongoDB.

        :return: a new list made of an optional ``$match`` stage (only when
            ``find_query`` is non-empty), the stored ``aggregation_pipeline``
            stages, and an optional ``$project`` stage (only when a
            projection model is set and yields a projection). The stored
            pipeline list itself is not modified.
        """
        match_pipeline: List[Mapping[str, Any]] = (
            [{"$match": self.find_query}] if self.find_query else []
        )
        projection_pipeline: List[Mapping[str, Any]] = []
        if self.projection_model:
            projection = get_projection(self.projection_model)
            if projection is not None:
                projection_pipeline = [{"$project": projection}]
        return match_pipeline + self.aggregation_pipeline + projection_pipeline

    async def get_cursor(self) -> AsyncCommandCursor:
        """
        Run the aggregation and return pymongo's async command cursor.

        The pipeline is rebuilt on every call. The attached session (if any)
        and ``pymongo_kwargs`` are passed through to ``aggregate``.
        """
        aggregation_pipeline = self.get_aggregation_pipeline()
        return await self.document_model.get_pymongo_collection().aggregate(
            aggregation_pipeline, session=self.session, **self.pymongo_kwargs
        )

    def get_projection_model(self) -> Optional[Type[BaseModel]]:
        """
        :return: the projection model given at construction, or None
        """
        return self.projection_model