File: client_command.py

package info (click to toggle)
python-ledger-bitcoin 0.4.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 716 kB
  • sloc: python: 9,357; makefile: 2
file content (338 lines) | stat: -rw-r--r-- 11,432 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
from collections import deque
from enum import IntEnum
from hashlib import sha256  # NOTE(review): shadowed by the `sha256` imported from .common below
from typing import Dict, Iterable, List, Mapping

from .common import ByteStreamParser, sha256, write_varint
from .merkle import MerkleTree, element_hash


class ClientCommandCode(IntEnum):
    """Codes identifying the client-side commands.

    The hardware wallet places one of these values in the first byte of its request; the
    interpreter dispatches on it to pick the matching handler.
    """

    YIELD = 0x10                  # device hands a value to the client, nothing is sent back
    GET_PREIMAGE = 0x40           # look up the preimage of a 32-byte hash
    GET_MERKLE_LEAF_PROOF = 0x41  # leaf of a known Merkle tree plus its inclusion proof
    GET_MERKLE_LEAF_INDEX = 0x42  # position of a leaf hash inside a known Merkle tree
    GET_MORE_ELEMENTS = 0xA0      # drain data that did not fit in a previous response


# Tag values at the very top of the unsigned 32-bit range (2**32 - 1 and 2**32 - 2).
# Judging by their names they label YIELD payloads carrying MuSig public nonces and
# partial signatures; they are not used in this module section — confirm against the code
# that parses the yielded values.
CCMD_YIELD_MUSIG_PUBNONCE_TAG = 0xFFFF_FFFF
CCMD_YIELD_MUSIG_PARTIALSIGNATURE_TAG = 0xFFFF_FFFE


class ClientCommand:
    """Base class for the handler of a single client-side command.

    Concrete handlers override both ``code`` and ``execute``; invoking either one on this
    base class raises ``NotImplementedError``.
    """

    @property
    def code(self) -> int:
        """Command code (first byte of the device's request) served by this handler."""
        raise NotImplementedError("Subclasses should implement this method.")

    def execute(self, request: bytes) -> bytes:
        """Handle ``request`` (command-code byte included) and return the response bytes."""
        raise NotImplementedError("Subclasses should implement this method.")


class YieldCommand(ClientCommand):
    """Handler for YIELD: the hardware wallet hands a value over to the client.

    The value is everything after the command-code byte. It is appended to ``results``,
    and the reply sent back to the device carries no data.
    """

    def __init__(self, results: List[bytes]):
        # Stored by reference (not copied), so whoever passed the list in sees every append.
        self.results = results

    @property
    def code(self) -> int:
        return ClientCommandCode.YIELD

    def execute(self, request: bytes) -> bytes:
        payload = request[1:]  # drop the command-code byte, keep the rest verbatim
        self.results.append(payload)
        return bytes()


class GetPreimageCommand(ClientCommand):
    """Handler for GET_PREIMAGE: return the preimage of a hash from a table of known preimages.

    Request (after the command-code byte): one byte that must be 0 (anything else is
    rejected as unsupported), followed by a 32-byte hash.

    Response: ``write_varint(len(preimage)) || payload_len (1 byte) || payload``. The
    payload is the longest prefix of the preimage that keeps the response within 255 bytes.
    The bytes that do not fit are pushed onto the shared queue as 1-byte elements, to be
    fetched later with GET_MORE_ELEMENTS.
    """

    def __init__(self, known_preimages: Mapping[bytes, bytes], queue: "deque[bytes]"):
        self.queue = queue
        self.known_preimages = known_preimages

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_PREIMAGE

    def execute(self, request: bytes) -> bytes:
        parser = ByteStreamParser(request[1:])  # skip the command-code byte

        if parser.read_bytes(1) != b'\0':
            raise RuntimeError("Unsupported request: the first byte should be 0")

        requested_hash = parser.read_bytes(32)
        parser.assert_empty()

        if requested_hash not in self.known_preimages:
            raise RuntimeError(f"Requested unknown preimage for: {requested_hash.hex()}")

        preimage = self.known_preimages[requested_hash]
        encoded_len = write_varint(len(preimage))

        # Length varint + 1 size byte + payload must add up to at most 255 bytes.
        room = 255 - len(encoded_len) - 1
        n_inline = min(room, len(preimage))

        # Overflow is queued one byte per element; when everything fits the range is empty
        # and the queue is left untouched.
        self.queue.extend(preimage[i:i + 1] for i in range(n_inline, len(preimage)))

        return b"".join(
            (encoded_len, n_inline.to_bytes(1, byteorder="big"), preimage[:n_inline])
        )


class GetMerkleLeafProofCommand(ClientCommand):
    """Handler for GET_MERKLE_LEAF_PROOF: return a leaf of a known Merkle tree and its proof.

    Request (after the command-code byte): 32-byte Merkle root || varint tree size ||
    varint leaf index.

    Response: leaf || total proof length (1 byte) || number of proof elements included in
    this response (1 byte) || those proof elements. Proof elements that do not fit in the
    255-byte response are pushed onto the shared queue for GET_MORE_ELEMENTS.
    """

    def __init__(self, known_trees: Mapping[bytes, MerkleTree], queue: "deque[bytes]"):
        self.queue = queue
        self.known_trees = known_trees

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_MERKLE_LEAF_PROOF

    def execute(self, request: bytes) -> bytes:
        parser = ByteStreamParser(request[1:])  # skip the command-code byte
        root = parser.read_bytes(32)
        tree_size = parser.read_varint()
        leaf_index = parser.read_varint()
        parser.assert_empty()

        if root not in self.known_trees:
            raise ValueError(f"Unknown Merkle root: {root.hex()}.")

        tree: MerkleTree = self.known_trees[root]

        # The index must be in range and the device's tree size must match ours.
        if not (leaf_index < tree_size and len(tree) == tree_size):
            raise ValueError("Invalid index or tree size.")

        # Leftovers from an earlier command must be drained before a new proof is queued.
        if self.queue:
            raise RuntimeError(
                "This command should not execute when the queue is not empty."
            )

        proof = tree.prove_leaf(leaf_index)

        # 255 bytes minus the (assumed 32-byte) leaf and the two count bytes leaves room for
        # (255 - 32 - 1 - 1) // 32 == 6 proof elements of 32 bytes.
        n_inline = min((255 - 32 - 1 - 1) // 32, len(proof))

        # Everything past the inline elements goes to the queue (no-op when all of it fits).
        self.queue.extend(proof[n_inline:])

        parts = [
            tree.get(leaf_index),
            len(proof).to_bytes(1, byteorder="big"),
            n_inline.to_bytes(1, byteorder="big"),
        ]
        parts.extend(proof[:n_inline])
        return b"".join(parts)


class GetMerkleLeafIndexCommand(ClientCommand):
    """Handler for GET_MERKLE_LEAF_INDEX: locate a leaf hash inside a known Merkle tree.

    Request (after the command-code byte): 32-byte Merkle root || 32-byte leaf hash.
    Response: a "found" byte (1 or 0) followed by the varint-encoded leaf index, which is
    0 when the leaf is not in the tree.
    """

    def __init__(self, known_trees: Mapping[bytes, MerkleTree]):
        self.known_trees = known_trees

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_MERKLE_LEAF_INDEX

    def execute(self, request: bytes) -> bytes:
        parser = ByteStreamParser(request[1:])  # skip the command-code byte
        root = parser.read_bytes(32)
        leaf_hash = parser.read_bytes(32)
        parser.assert_empty()

        if root not in self.known_trees:
            raise ValueError(f"Unknown Merkle root: {root.hex()}.")

        try:
            index = self.known_trees[root].leaf_index(leaf_hash)
        except ValueError:
            # A ValueError from leaf_index is treated as "leaf not in the tree".
            found_flag, index = b"\x00", 0
        else:
            found_flag = b"\x01"

        return found_flag + write_varint(index)


class GetMoreElementsCommand(ClientCommand):
    """Handler for GET_MORE_ELEMENTS: drain the shared queue of overflow data.

    All queued elements must have the same byte length. As many as fit are popped from the
    front of the queue and returned as
    ``count (1 byte) || element_len (1 byte) || element_1 || ... || element_count``,
    which keeps the whole response at most 255 bytes long.
    """

    def __init__(self, queue: "deque[bytes]"):
        self.queue = queue

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_MORE_ELEMENTS

    def execute(self, request: bytes) -> bytes:
        # The request consists of the command-code byte only.
        if len(request) != 1:
            raise ValueError("Wrong request length.")

        if not self.queue:
            raise ValueError("No elements to get.")

        element_len = len(self.queue[0])
        if not all(len(item) == element_len for item in self.queue):
            raise ValueError(
                "The queue contains elements of different byte length, which is not expected."
            )

        # The two header bytes leave 253 bytes for the elements themselves.
        payload = bytearray()
        count = 0
        while self.queue and len(payload) + element_len <= 253:
            payload.extend(self.queue.popleft())
            count += 1

        header = count.to_bytes(1, byteorder="big") + element_len.to_bytes(1, byteorder="big")
        return header + bytes(payload)


class ClientCommandInterpreter:
    """Interpreter for the client-side commands.

    This class has methods to keep track of:
    - known preimages
    - known Merkle trees from lists of elements

    Moreover, it contains the state that is relevant for the interpreted client side commands:
    - a queue of bytes that contains any bytes that could not fit in a response from the
      GET_PREIMAGE client command (when a preimage is too long to fit in a single message) or the
      GET_MERKLE_LEAF_PROOF command (which returns a Merkle proof, which might be too long to fit
      in a single message). The data in the queue is returned in one (or more) successive
      GET_MORE_ELEMENTS commands from the hardware wallet.

    Finally, it keeps track of the yielded values (that is, the values sent from the hardware
    wallet with a YIELD client command).

    Attributes
    ----------
    known_preimages: Dict[bytes, bytes]
        Preimages keyed by `sha256(preimage)`; filled by `add_known_preimage`,
        `add_known_list` and `add_known_mapping`.
    known_trees: Dict[bytes, MerkleTree]
        Merkle trees keyed by their root; filled by `add_known_list` and `add_known_mapping`.
    yielded: list[bytes]
        A list of all the values sent by the hardware wallet with a YIELD client command during
        the processing of an APDU. This class only appends to it and never clears it.
    commands: Dict[int, ClientCommand]
        The command handlers, keyed by their command code.
    """

    def __init__(self):
        # Mutable dicts: the add_known_* methods write into them, while the command handlers
        # below hold references to these very objects and only read from them.
        self.known_preimages: Dict[bytes, bytes] = {}
        self.known_trees: Dict[bytes, MerkleTree] = {}

        self.yielded: List[bytes] = []

        # One queue shared by the producers of overflow data (GET_PREIMAGE,
        # GET_MERKLE_LEAF_PROOF) and its consumer (GET_MORE_ELEMENTS).
        queue: "deque[bytes]" = deque()

        commands = [
            YieldCommand(self.yielded),
            GetPreimageCommand(self.known_preimages, queue),
            GetMerkleLeafIndexCommand(self.known_trees),
            GetMerkleLeafProofCommand(self.known_trees, queue),
            GetMoreElementsCommand(queue),
        ]

        self.commands = {cmd.code: cmd for cmd in commands}

    def execute(self, hw_response: bytes) -> bytes:
        """Interprets the client command requested by the hardware wallet, returning the appropriate
        response and updating the client interpreter's internal state if needed.

        Parameters
        ----------
        hw_response : bytes
            The data content of the SW_INTERRUPTED_EXECUTION sent by the hardware wallet; its
            first byte is the client command code.

        Returns
        -------
        bytes
            The result of the execution of the appropriate client side command, containing the response
            to be sent via INS_CONTINUE.

        Raises
        ------
        RuntimeError
            If `hw_response` is empty or its command code is not handled. Exceptions raised by the
            individual command handlers propagate unchanged.
        """

        if len(hw_response) == 0:
            raise RuntimeError(
                "Unexpected empty SW_INTERRUPTED_EXECUTION response from hardware wallet."
            )

        cmd_code = hw_response[0]
        if cmd_code not in self.commands:
            raise RuntimeError(
                "Unexpected command code: 0x{:02X}".format(cmd_code)
            )

        return self.commands[cmd_code].execute(hw_response)

    def add_known_preimage(self, element: bytes) -> None:
        """Adds a preimage to the list of known preimages.

        The client must respond with `element` when a GET_PREIMAGE command is sent with
        `sha256(element)` in its request.

        Parameters
        ----------
        element : bytes
            An array of bytes whose preimage must be known to the client during an APDU execution.
        """

        self.known_preimages[sha256(element)] = element

    def add_known_list(self, elements: Iterable[bytes]) -> None:
        """Adds a known Merkleized list.

        Builds the Merkle tree of `elements`, and adds it to the Merkle trees known to the client
        (mapped by its Merkle root `mt_root`).
        Moreover, adds all the leaves (after adding the b'\\0' prefix) to the list of known
        preimages.

        If `el` is one of `elements`, the client must respond with b'\\0' + `el` when a GET_PREIMAGE
        client command is sent with `sha256(b'\\0' + el)`.
        Moreover, the commands GET_MERKLE_LEAF_INDEX and GET_MERKLE_LEAF_PROOF must correctly answer
        queries relative to the Merkle tree whose root is `mt_root`.

        Parameters
        ----------
        elements : Iterable[bytes]
            The leaves of the Merkle tree, in order. Any iterable is accepted (including a
            one-shot iterator such as a generator), since it is copied into a list first.
        """

        # `elements` is traversed twice below; materialize it so that a one-shot iterator
        # does not silently produce an empty Merkle tree on the second pass.
        elements = list(elements)

        for el in elements:
            self.add_known_preimage(b"\x00" + el)

        mt = MerkleTree(element_hash(el) for el in elements)

        self.known_trees[mt.root] = mt

    def add_known_mapping(self, mapping: Mapping[bytes, bytes]) -> None:
        """Adds the Merkle trees of keys, and the Merkle tree of values (ordered by key)
        of a mapping of bytes to bytes.

        Adds the Merkle tree of the list of keys, and the Merkle tree of the list of corresponding
        values, with the same semantics as the `add_known_list` applied separately to the two lists.

        Parameters
        ----------
        mapping : Mapping[bytes, bytes]
            A mapping whose keys and values are `bytes`.
        """

        # Keys are unique, so sorting the items orders them by key alone.
        items_sorted = sorted(mapping.items())

        keys = [key for key, _ in items_sorted]
        values = [value for _, value in items_sorted]
        self.add_known_list(keys)
        self.add_known_list(values)