1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
|
# (c) 2018 Mantas Mikulėnas <grawity@gmail.com>
# (c) 2024 E. Castedo Ellerman <castedo@castedo.com>
# Released under the MIT License (https://spdx.org/licenses/MIT)
# fmt: off
from __future__ import annotations
import io
import struct
from typing import Any, BinaryIO, TYPE_CHECKING, cast
if TYPE_CHECKING:
    # Buffer types accepted wherever this module takes raw bytes. The alias is
    # only created for static type checkers; at runtime the name never exists,
    # which is safe because `from __future__ import annotations` keeps every
    # annotation that mentions it unevaluated.
    BytesLike = (
        bytes
        | bytearray
        | memoryview
    )
class SshReader:
    """Parse SSH wire-format values (RFC 4251 section 5) from a byte stream.

    All read_ methods may raise ValueError, including when the input ends
    before a complete value has been read.
    """

    # Upper bound on a single underlying read() call, so that a hostile 32-bit
    # length prefix cannot force one huge allocation before EOF is noticed.
    _MAX_CHUNK = 1 << 16

    def __init__(self, ins: BinaryIO | BytesLike):
        """Wrap *ins*; bytes-like input is first wrapped in an ``io.BytesIO``."""
        if isinstance(ins, (bytes, bytearray, memoryview)):
            ins = io.BytesIO(ins)
        self.input_fh = ins

    @staticmethod
    def from_bytes(buf: BytesLike) -> SshReader:
        """Return a reader over the in-memory buffer *buf*."""
        return SshReader(buf)

    def read(self, length: int | None = -1) -> bytes:
        """Read raw bytes from the input.

        A positive *length* returns exactly that many bytes, looping over
        short reads from the underlying stream. ``0`` returns ``b""``. A
        negative *length* (default ``-1``) returns everything that remains.
        ``None`` is passed through to the stream's ``read`` unchecked.

        Raises:
            ValueError: the input ends before *length* bytes are available,
                or, for a negative *length*, nothing at all remains.
        """
        if length is None:
            return self.input_fh.read(length)
        if length < 0:
            buf = self.input_fh.read(length)
            if not buf:
                raise ValueError("Unexpected end of input.")
            return buf
        chunks: list[bytes] = []
        remaining = length
        while remaining > 0:
            chunk = self.input_fh.read(min(remaining, self._MAX_CHUNK))
            if not chunk:
                # Truncated input: never hand back a partial value.
                raise ValueError("Unexpected end of input.")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def read_byte(self) -> int:
        """Read one unsigned byte."""
        return cast(int, self._read_and_unpack(1, "!B"))

    def read_uint32(self) -> int:
        """Read a big-endian unsigned 32-bit integer."""
        return cast(int, self._read_and_unpack(4, "!L"))

    def read_bool(self) -> bool:
        """Read one byte as a boolean (any non-zero byte is True)."""
        return cast(bool, self._read_and_unpack(1, "!?"))

    def read_string(self) -> bytes:
        """Read a uint32 length prefix followed by exactly that many bytes."""
        length = self.read_uint32()
        return self.read(length)

    def read_string_pkt(self) -> SshReader:
        """Read a string and return a new reader over its contents."""
        return SshReader(self.read_string())

    def read_mpint(self) -> int:
        """Read an mpint (a string holding a big-endian integer).

        The bytes are decoded as unsigned. NOTE(review): RFC 4251 defines
        mpint as two's complement, so a negative mpint would come back as a
        large positive number here; confirm callers only expect values >= 0.
        """
        buf = self.read_string()
        return int.from_bytes(buf, byteorder="big", signed=False)

    def _read_and_unpack(self, length: int, frmt: str) -> Any:
        """Read *length* bytes and return the first value unpacked with *frmt*."""
        try:
            return struct.unpack(frmt, self.read(length))[0]
        except struct.error as ex:
            raise ValueError from ex
def ssh_read_string_pair(buf: BinaryIO | BytesLike) -> tuple[bytes, bytes]:
    """Return the first two SSH strings found in *buf*, in order.

    Any data following the second string is left unread. Raises ValueError
    (from the reader) when *buf* does not hold two complete strings.
    """
    reader = SshReader(buf)
    first = reader.read_string()
    second = reader.read_string()
    return first, second
class SshWriter:
    """Serialize SSH wire-format values (RFC 4251 section 5) to a stream.

    Each write_ method returns whatever ``output_fh.write`` returns (the
    number of bytes written, for ``io.BytesIO``).
    """

    def __init__(self, output_fh: io.BytesIO):
        self.output_fh = output_fh

    def write(self, b: BytesLike) -> int:
        """Write *b* verbatim, with no length prefix."""
        return self.output_fh.write(b)

    def flush(self) -> None:
        """Flush the underlying output stream."""
        self.output_fh.flush()

    def write_byte(self, val: int) -> int:
        """Write *val* as one unsigned byte (struct.error if out of range)."""
        return self.write(struct.pack("!B", val))

    def write_uint32(self, val: int) -> int:
        """Write *val* as a big-endian unsigned 32-bit integer."""
        return self.write(struct.pack("!L", val))

    def write_bool(self, val: bool) -> int:
        """Write *val* as a single 0/1 byte."""
        return self.write(struct.pack("!?", val))

    def write_string(self, val: BytesLike) -> int:
        """Write a uint32 byte-length prefix followed by the bytes of *val*."""
        # Flatten to raw bytes first so the prefix counts bytes, not items:
        # len() of a memoryview with itemsize > 1 would under-count.
        data = memoryview(val).tobytes()
        return self.write(struct.pack("!L", len(data)) + data)

    def write_mpint(self, val: int) -> int:
        """Write a non-negative integer as an RFC 4251 mpint.

        Zero is encoded as the empty string. Positive values use the minimal
        big-endian form, plus a leading 0x00 byte whenever the top bit of the
        first byte would otherwise be set (so it is not read as negative).

        Raises:
            OverflowError: *val* is negative.
        """
        if val == 0:
            length = 0
        else:
            # One extra bit for the (zero) sign bit, rounded up to whole bytes.
            length = (val.bit_length() + 8) // 8
        buf = val.to_bytes(length, "big", signed=False)
        return self.write_string(buf)
|