File: state_machine.py

package info (click to toggle)
libmongocrypt 1.17.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,572 kB
  • sloc: ansic: 70,067; python: 4,547; cpp: 615; sh: 460; makefile: 44; awk: 8
file content (153 lines) | stat: -rw-r--r-- 5,087 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod

from pymongocrypt.asynchronous.credentials import _ask_for_kms_credentials
from pymongocrypt.binding import lib
from pymongocrypt.compat import ABC
from pymongocrypt.errors import MongoCryptError


class AsyncMongoCryptCallback(ABC):
    """Callback ABC to perform I/O on behalf of libmongocrypt.

    Every method except :meth:`bson_encode` is declared ``async``;
    :meth:`bson_encode` is a plain synchronous method.
    """

    @abstractmethod
    async def kms_request(self, kms_context):
        """Complete a KMS request.

        :Parameters:
          - `kms_context`: A :class:`MongoCryptKmsContext`.

        :Returns:
          None
        """

    @abstractmethod
    async def collection_info(self, database, filter):
        """Get the collection info for a namespace.

        The returned collection info is passed to libmongocrypt which reads
        the JSON schema.

        :Parameters:
          - `database`: The database on which to run listCollections.
          - `filter`: The filter to pass to listCollections.

        :Returns:
          Either the first document or a list of all documents from the
          listCollections command response, as BSON. A falsy return value
          means no collection info is supplied to libmongocrypt.
        """

    @abstractmethod
    async def mark_command(self, database, cmd):
        """Mark a command for encryption.

        :Parameters:
          - `database`: The database on which to run this command.
          - `cmd`: The BSON command to run.

        :Returns:
          The marked command response from mongocryptd.
        """

    @abstractmethod
    async def fetch_keys(self, filter):
        """Yield one or more keys from the key vault.

        The object returned by calling this method is consumed directly with
        ``async for``, so implementations should be async generator
        functions (an ``async def`` body containing ``yield``). A coroutine
        that returns a list, or a synchronous generator, cannot be iterated
        that way.

        :Parameters:
          - `filter`: The filter to pass to find.

        :Returns:
          An async iterator which yields the requested keys from the key
          vault.
        """

    @abstractmethod
    async def insert_data_key(self, data_key):
        """Insert a data key into the key vault.

        :Parameters:
          - `data_key`: The data key document to insert.

        :Returns:
          The _id of the inserted data key document.
        """

    @abstractmethod
    def bson_encode(self, doc):
        """Encode a document to BSON.

        Unlike the other callbacks this method is synchronous: it is called
        without ``await``.

        A document can be any mapping type (like :class:`dict`).

        :Parameters:
          - `doc`: mapping type representing a document

        :Returns:
          The encoded BSON bytes.
        """

    @abstractmethod
    async def close(self):
        """Release resources."""


async def run_state_machine(ctx, callback):
    """Drive the libmongocrypt state machine of ``ctx`` until it finishes.

    Each pass re-reads ``ctx.state`` and services the pending request through
    ``callback`` (collection info, mongocryptd markings, key-vault keys, KMS
    requests, or KMS credentials), looping until a terminal state is reached.

    :Parameters:
      - `ctx`: A :class:`MongoCryptContext`.
      - `callback`: A :class:`AsyncMongoCryptCallback`.

    :Returns:
      The result of ``ctx.finish()`` when the context becomes ready, or
      ``None`` when the context reports it is done.

    :Raises:
      Whatever ``ctx._raise_from_status()`` raises in the error state, and
      :class:`MongoCryptError` for an unrecognized state (also raised if
      ``_raise_from_status()`` returns without raising).
    """
    while True:
        current = ctx.state

        # Terminal states end the loop one way or another.
        if current == lib.MONGOCRYPT_CTX_ERROR:
            ctx._raise_from_status()
            # Only reached when the status check did not raise.
            raise MongoCryptError(f"unknown state: {current}")
        if current == lib.MONGOCRYPT_CTX_READY:
            return ctx.finish()
        if current == lib.MONGOCRYPT_CTX_DONE:
            return None

        # Work states: satisfy the request, then loop to re-read the state.
        if current == lib.MONGOCRYPT_CTX_NEED_MONGO_COLLINFO:
            collinfo_query = ctx.mongo_operation()
            infos = await callback.collection_info(ctx.database, collinfo_query)
            if infos:
                # The callback may hand back one document or a list of them.
                docs = infos if isinstance(infos, list) else (infos,)
                for doc in docs:
                    ctx.add_mongo_operation_result(doc)
            ctx.complete_mongo_operation()
        elif current == lib.MONGOCRYPT_CTX_NEED_MONGO_MARKINGS:
            command = ctx.mongo_operation()
            marked = await callback.mark_command(ctx.database, command)
            ctx.add_mongo_operation_result(marked)
            ctx.complete_mongo_operation()
        elif current == lib.MONGOCRYPT_CTX_NEED_MONGO_KEYS:
            keys_query = ctx.mongo_operation()
            async for key_doc in callback.fetch_keys(keys_query):
                ctx.add_mongo_operation_result(key_doc)
            ctx.complete_mongo_operation()
        elif current == lib.MONGOCRYPT_CTX_NEED_KMS:
            for request in ctx.kms_contexts():
                # The KMS context is released when the with-block exits.
                with request:
                    await callback.kms_request(request)
            ctx.complete_kms()
        elif current == lib.MONGOCRYPT_CTX_NEED_KMS_CREDENTIALS:
            credentials = await _ask_for_kms_credentials(ctx.kms_providers)
            encoded = callback.bson_encode(credentials)
            ctx.provide_kms_providers(encoded)
        else:
            raise MongoCryptError(f"unknown state: {current}")