1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
|
from collections import deque
from enum import IntEnum
# NOTE(review): `sha256` from .common (imported below) shadows this hashlib one.
from hashlib import sha256
from typing import Dict, List, Mapping

from .common import ByteStreamParser, sha256, write_varint
from .merkle import MerkleTree, element_hash
class ClientCommandCode(IntEnum):
    """Codes of the client commands the hardware wallet can request.

    The interpreter dispatches on the first byte of the hardware wallet's
    SW_INTERRUPTED_EXECUTION payload, matching it against these values.
    """

    YIELD = 0x10
    GET_PREIMAGE = 0x40
    GET_MERKLE_LEAF_PROOF = 0x41
    GET_MERKLE_LEAF_INDEX = 0x42
    GET_MORE_ELEMENTS = 0xA0
    # NOTE(review): the two values below do not fit in the single byte used for
    # command dispatch, so they can never be requested as client commands. Their
    # names suggest tags for data carried in YIELD results; they may have been
    # intended as module-level constants rather than enum members -- confirm
    # against callers before relying on either form.
    CCMD_YIELD_MUSIG_PUBNONCE_TAG = 0xFFFFFFFF
    CCMD_YIELD_MUSIG_PARTIALSIGNATURE_TAG = 0xFFFFFFFE
class ClientCommand:
    """Base class for a client-side command requested by the hardware wallet.

    Concrete commands override `code` (the command byte they answer to) and
    `execute` (which turns the raw request into the response bytes).
    """

    @property
    def code(self) -> int:
        """The command code, i.e. the first byte of the requests this command handles."""
        raise NotImplementedError("Subclasses should implement this method.")

    def execute(self, request: bytes) -> bytes:
        """Handle `request` (command byte included) and return the response payload."""
        raise NotImplementedError("Subclasses should implement this method.")
class YieldCommand(ClientCommand):
    """Records every value the hardware wallet sends with a YIELD command.

    The payload of each request (everything after the leading command byte) is
    appended to the caller-supplied `results` list; the response is always empty.
    """

    @property
    def code(self) -> int:
        return ClientCommandCode.YIELD

    def __init__(self, results: List[bytes]):
        self.results = results

    def execute(self, request: bytes) -> bytes:
        # Drop only the command byte; the rest is the yielded value.
        yielded_value = request[1:]
        self.results.append(yielded_value)
        return b""
class GetPreimageCommand(ClientCommand):
    """Answers GET_PREIMAGE requests from a table of known preimages.

    `known_preimages` maps a 32-byte hash to its preimage. Bytes of the preimage
    that do not fit in a single response are pushed, one byte per element, onto
    the shared `queue`, from which GET_MORE_ELEMENTS later serves them.
    """

    def __init__(self, known_preimages: Mapping[bytes, bytes], queue: "deque[bytes]"):
        self.queue = queue
        self.known_preimages = known_preimages

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_PREIMAGE

    def execute(self, request: bytes) -> bytes:
        """Return the preimage of the requested hash (or its first chunk).

        Request (after the command byte): a 0x00 byte followed by the 32-byte hash.
        Response: varint(preimage length), 1 byte chunk length, then the chunk.

        Raises:
            RuntimeError: if the leading byte is not 0, or the hash is unknown.
        """
        parser = ByteStreamParser(request[1:])

        if parser.read_bytes(1) != b'\0':
            raise RuntimeError("Unsupported request: the first byte should be 0")

        wanted_hash = parser.read_bytes(32)
        parser.assert_empty()

        if wanted_hash not in self.known_preimages:
            raise RuntimeError(f"Requested unknown preimage for: {wanted_hash.hex()}")

        preimage = self.known_preimages[wanted_hash]
        encoded_len = write_varint(len(preimage))

        # A single message carries at most 255 bytes: the varint length, one byte
        # for the inline chunk size, and then the inline chunk itself.
        inline_size = min(255 - len(encoded_len) - 1, len(preimage))

        # Anything beyond the inline chunk is queued as length-1 elements for
        # GET_MORE_ELEMENTS (an empty range leaves the queue untouched).
        self.queue.extend(
            preimage[pos: pos + 1] for pos in range(inline_size, len(preimage))
        )

        return (
            encoded_len
            + inline_size.to_bytes(1, byteorder="big")
            + preimage[:inline_size]
        )
class GetMerkleLeafProofCommand(ClientCommand):
    """Answers GET_MERKLE_LEAF_PROOF requests for the known Merkle trees.

    The response carries the requested leaf, the total proof length and as many
    proof elements as fit in one message; the remaining proof elements are pushed
    onto the shared `queue` and delivered by GET_MORE_ELEMENTS.
    """

    def __init__(self, known_trees: Mapping[bytes, MerkleTree], queue: "deque[bytes]"):
        self.queue = queue
        self.known_trees = known_trees

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_MERKLE_LEAF_PROOF

    def execute(self, request: bytes) -> bytes:
        """Return a leaf of a known tree together with (the start of) its Merkle proof.

        Request (after the command byte): 32-byte root, varint tree size, varint leaf index.
        Response: the leaf, 1 byte total proof length, 1 byte number of proof
        elements included, then those proof elements.

        Raises:
            ValueError: if the root is unknown, or the index/tree size do not match the tree.
            RuntimeError: if the queue still holds elements from a previous command.
        """
        req = ByteStreamParser(request[1:])

        root = req.read_bytes(32)
        tree_size = req.read_varint()
        leaf_index = req.read_varint()
        req.assert_empty()

        if root not in self.known_trees:
            raise ValueError(f"Unknown Merkle root: {root.hex()}.")

        mt: MerkleTree = self.known_trees[root]

        if leaf_index >= tree_size or len(mt) != tree_size:
            raise ValueError("Invalid index or tree size.")

        # Leftover data from another command would be mixed into this proof.
        if len(self.queue) != 0:
            raise RuntimeError(
                "This command should not execute when the queue is not empty."
            )

        proof = mt.prove_leaf(leaf_index)

        # Keep the response within 255 bytes: 32 bytes for the leaf and two length
        # bytes leave 255 - 32 - 1 - 1 = 221 bytes, i.e. room for 6 elements of 32 bytes.
        n_response_elements = min((255 - 32 - 1 - 1) // 32, len(proof))

        # Queue the proof elements that do not fit in this response
        # (an empty slice leaves the queue untouched).
        self.queue.extend(proof[n_response_elements:])

        return b"".join(
            [
                mt.get(leaf_index),
                len(proof).to_bytes(1, byteorder="big"),
                n_response_elements.to_bytes(1, byteorder="big"),
                *proof[:n_response_elements],
            ]
        )
class GetMerkleLeafIndexCommand(ClientCommand):
    """Answers GET_MERKLE_LEAF_INDEX requests: locates a leaf hash in a known tree."""

    def __init__(self, known_trees: Mapping[bytes, MerkleTree]):
        self.known_trees = known_trees

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_MERKLE_LEAF_INDEX

    def execute(self, request: bytes) -> bytes:
        """Return a found flag (1 byte) followed by the varint-encoded leaf index.

        Request (after the command byte): 32-byte root, then the 32-byte leaf hash.
        A ValueError from the tree's `leaf_index` lookup is reported as "not found"
        (flag 0, index 0).

        Raises:
            ValueError: if the Merkle root is not known.
        """
        stream = ByteStreamParser(request[1:])
        root = stream.read_bytes(32)
        leaf_hash = stream.read_bytes(32)
        stream.assert_empty()

        if root not in self.known_trees:
            raise ValueError(f"Unknown Merkle root: {root.hex()}.")

        try:
            leaf_index = self.known_trees[root].leaf_index(leaf_hash)
        except ValueError:
            found, leaf_index = 0, 0
        else:
            found = 1

        return bytes([found]) + write_varint(leaf_index)
class GetMoreElementsCommand(ClientCommand):
    """Drains the shared queue filled by commands whose data overflowed a response.

    All queued elements must share the same byte length; each response packs as
    many of them as fit in 255 bytes.
    """

    def __init__(self, queue: "deque[bytes]"):
        self.queue = queue

    @property
    def code(self) -> int:
        return ClientCommandCode.GET_MORE_ELEMENTS

    def execute(self, request: bytes) -> bytes:
        """Return: 1 byte element count, 1 byte element length, then the elements.

        Raises:
            ValueError: if the request has extra bytes, the queue is empty, or the
                queued elements have different lengths.
        """
        if len(request) != 1:
            raise ValueError("Wrong request length.")
        if not self.queue:
            raise ValueError("No elements to get.")

        element_len = len(self.queue[0])
        if not all(len(item) == element_len for item in self.queue):
            raise ValueError(
                "The queue contains elements of different byte length, which is not expected."
            )

        # Two header bytes leave 253 bytes of payload in a 255-byte response.
        chunk = bytearray()
        count = 0
        while self.queue and len(chunk) + element_len <= 253:
            chunk += self.queue.popleft()
            count += 1

        header = count.to_bytes(1, byteorder="big") + element_len.to_bytes(1, byteorder="big")
        return header + bytes(chunk)
class ClientCommandInterpreter:
    """Interpreter for the client-side commands.

    This class has methods to keep track of:
    - known preimages
    - known Merkle trees from lists of elements

    Moreover, it contains the state that is relevant for the interpreted client-side commands:
    - a queue of bytes that contains any bytes that could not fit in a response from the
      GET_PREIMAGE client command (when a preimage is too long to fit in a single message) or the
      GET_MERKLE_LEAF_PROOF command (which returns a Merkle proof, which might be too long to fit
      in a single message). The data in the queue is returned in one (or more) successive
      GET_MORE_ELEMENTS commands from the hardware wallet.

    Finally, it keeps track of the yielded values (that is, the values sent from the hardware
    wallet with a YIELD client command).

    Attributes
    ----------
    yielded: list[bytes]
        A list of all the values sent by the hardware wallet with a YIELD client command during
        the processing of an APDU.
    """

    def __init__(self):
        # These containers are handed to the command objects by reference, so entries added
        # later through the add_known_* methods are seen by the commands. They are mutated,
        # hence Dict rather than the read-only Mapping in the annotations.
        self.known_preimages: Dict[bytes, bytes] = {}
        self.known_trees: Dict[bytes, MerkleTree] = {}
        self.yielded: List[bytes] = []

        # One queue shared by the commands whose data may overflow a single response
        # (GET_PREIMAGE, GET_MERKLE_LEAF_PROOF) and by GET_MORE_ELEMENTS, which drains it.
        queue: "deque[bytes]" = deque()

        commands = [
            YieldCommand(self.yielded),
            GetPreimageCommand(self.known_preimages, queue),
            GetMerkleLeafIndexCommand(self.known_trees),
            GetMerkleLeafProofCommand(self.known_trees, queue),
            GetMoreElementsCommand(queue),
        ]

        # Dispatch table: command code -> command object.
        self.commands = {cmd.code: cmd for cmd in commands}

    def execute(self, hw_response: bytes) -> bytes:
        """Interprets the client command requested by the hardware wallet, returning the appropriate
        response and updating the client interpreter's internal state if needed.

        Parameters
        ----------
        hw_response : bytes
            The data content of the SW_INTERRUPTED_EXECUTION sent by the hardware wallet.

        Returns
        -------
        bytes
            The result of the execution of the appropriate client side command, containing the response
            to be sent via INS_CONTINUE.

        Raises
        ------
        RuntimeError
            If `hw_response` is empty or its first byte is not a known command code. Any exception
            raised by the selected command is propagated unchanged.
        """
        if len(hw_response) == 0:
            raise RuntimeError(
                "Unexpected empty SW_INTERRUPTED_EXECUTION response from hardware wallet."
            )

        cmd_code = hw_response[0]
        if cmd_code not in self.commands:
            raise RuntimeError(
                "Unexpected command code: 0x{:02X}".format(cmd_code)
            )

        return self.commands[cmd_code].execute(hw_response)

    def add_known_preimage(self, element: bytes) -> None:
        """Adds a preimage to the list of known preimages.

        The client must respond with `element` when a GET_PREIMAGE command is sent with
        `sha256(element)` in its request.

        Parameters
        ----------
        element : bytes
            An array of bytes whose preimage must be known to the client during an APDU execution.
        """
        self.known_preimages[sha256(element)] = element

    def add_known_list(self, elements: List[bytes]) -> None:
        """Adds a known Merkleized list.

        Builds the Merkle tree of `elements`, and adds it to the Merkle trees known to the client
        (mapped by its Merkle root `mt.root`).

        Moreover, adds all the leaves (after adding the b'\\0' prefix) to the list of known
        preimages. If `el` is one of `elements`, the client must respond with b'\\0' + `el` when a
        GET_PREIMAGE client command is sent with `sha256(b'\\0' + el)`.

        Moreover, the commands GET_MERKLE_LEAF_INDEX and GET_MERKLE_LEAF_PROOF must correctly answer
        queries relative to the Merkle tree whose root is `mt.root`.

        Parameters
        ----------
        elements : List[bytes]
            A list of `bytes` corresponding to the leaves of the Merkle tree.
        """
        for el in elements:
            self.add_known_preimage(b"\x00" + el)

        mt = MerkleTree(element_hash(el) for el in elements)

        self.known_trees[mt.root] = mt

    def add_known_mapping(self, mapping: Mapping[bytes, bytes]) -> None:
        """Adds the Merkle trees of keys, and the Merkle tree of values (ordered by key)
        of a mapping of bytes to bytes.

        Adds the Merkle tree of the list of keys, and the Merkle tree of the list of corresponding
        values, with the same semantics as the `add_known_list` applied separately to the two lists.

        Parameters
        ----------
        mapping : Mapping[bytes, bytes]
            A mapping whose keys and values are `bytes`.
        """
        # Keys are unique, so sorting the (key, value) pairs orders them by key only.
        items_sorted = sorted(mapping.items())

        keys = [key for key, _ in items_sorted]
        values = [value for _, value in items_sorted]

        self.add_known_list(keys)
        self.add_known_list(values)
|