1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
|
from typing import BinaryIO, Tuple, Union
def encode_to_null_terminated(string: str, codec: str = "utf-8") -> bytes:
    """
    Encodes `string` with `codec` and makes sure the result ends with a
    null byte. A null byte is appended only when the encoded data does not
    already end with one.

    >>> encode_to_null_terminated("las files are cool")
    b'las files are cool\\x00'
    >>> encode_to_null_terminated("")
    b'\\x00'
    """
    encoded = string.encode(codec)
    # Already terminated: hand the encoded data back untouched.
    if encoded.endswith(b"\0"):
        return encoded
    return encoded + b"\0"
def read_string(
    stream: BinaryIO, length: int, encoding: str = "ascii"
) -> Union[str, bytes]:
    """
    Reads up to `length` bytes from the stream, and tries to decode them.

    The data is cut at the first null byte (if any) before decoding, so
    null padding and anything following it is discarded.

    If the decoding succeeds, returns the `str`. Otherwise the raw bytes
    (already cut at the first null byte) are returned.
    """
    raw_string = stream.read(length)

    # Keep only what precedes the first null byte; when there is no null
    # byte, partition leaves the data unchanged.
    raw_string, _, _ = raw_string.partition(b"\0")

    try:
        return raw_string.decode(encoding)
    except UnicodeDecodeError:
        # Best effort: let the caller deal with the undecodable bytes.
        return raw_string
def write_as_c_string(
    stream: BinaryIO,
    string: Union[str, bytes],
    max_length: int,
    encoding: str = "ascii",
    encoding_errors: str = "strict",
) -> bool:
    """
    Writes the string or bytes to the stream as a fixed-size 'C' string.

    A 'C' string is null terminated, so the null terminator is part of
    the written data.

    The data is converted with `get_bytes_from_string` and then sized
    with `null_pad_bytes`, so that `max_length` bytes are written: the
    input is either null padded or truncated to fit.

    Returns the truncation flag reported by `null_pad_bytes`.
    """
    payload, truncated = null_pad_bytes(
        get_bytes_from_string(string, encoding, encoding_errors),
        max_length,
        null_terminate=True,
    )
    stream.write(payload)
    return truncated
def write_string(
    stream: BinaryIO,
    string: Union[str, bytes],
    max_length: int,
    encoding: str = "ascii",
    encoding_errors: str = "strict",
) -> bool:
    """
    Writes the string or bytes to the stream as a fixed-size field.

    Unlike a 'C' string, the written data is not guaranteed to be null
    terminated: the input may fill the whole field.

    The data is converted with `get_bytes_from_string` and then sized
    with `null_pad_bytes`, so that `max_length` bytes are written: the
    input is either null padded or truncated to fit.

    Returns the truncation flag reported by `null_pad_bytes`.
    """
    payload, truncated = null_pad_bytes(
        get_bytes_from_string(string, encoding, encoding_errors),
        max_length,
        null_terminate=False,
    )
    stream.write(payload)
    return truncated
def get_bytes_from_string(
    string: Union[str, bytes], encoding: str, encoding_errors: str
) -> bytes:
    """
    Returns `string` as bytes.

    A `str` is encoded using `encoding` and `encoding_errors`.
    Bytes are returned unchanged, but are first decoded with the same
    `encoding` and `encoding_errors` so that data which is invalid for
    the encoding raises just like an invalid `str` would.
    """
    if not isinstance(string, str):
        # The decoded value is irrelevant, the call is only made for the
        # error it raises on invalid data.
        string.decode(encoding, errors=encoding_errors)
        return string
    return string.encode(encoding, errors=encoding_errors)
def null_pad_bytes(
    raw_bytes: bytes, max_length: int, null_terminate: bool = True
) -> Tuple[bytes, bool]:
    """
    Returns a byte string of `max_length` bytes, together with a flag
    telling whether some of the input data had to be dropped.

    The input is first cut at its first null byte (if any).
    If the input bytes is shorter then the output will be null padded.
    If the input bytes is longer it will be truncated.

    If null_terminate is True, then the last byte is guaranteed to be a null
    byte (and the out string still has `max_length` bytes).

    >>> null_pad_bytes(b'abcd', 5)
    (b'abcd\\x00', False)

    Input has 4 bytes, and the output must be 4 bytes long,
    but since null_terminate is True it is guaranteed to be null terminated:
    the last byte will be truncated.

    >>> null_pad_bytes(b'abcd', 4)
    (b'abc\\x00', True)

    Same setup, but don't null terminate.

    >>> null_pad_bytes(b'abcd', 4, null_terminate=False)
    (b'abcd', False)

    >>> null_pad_bytes(b'abcdef', 4)
    (b'abc\\x00', True)

    >>> null_pad_bytes(b'abcdef', 4, null_terminate=False)
    (b'abcd', True)

    >>> null_pad_bytes(b'abcd', 10)
    (b'abcd\\x00\\x00\\x00\\x00\\x00\\x00', False)

    >>> null_pad_bytes(b'abcdabcd', 5)
    (b'abcd\\x00', True)

    Trailing null bytes in the input are padding, dropping them is not
    a truncation. Data found after a null byte is lost, which is.

    >>> null_pad_bytes(b'abcde\\x00', 8)
    (b'abcde\\x00\\x00\\x00', False)

    >>> null_pad_bytes(b'abcde\\x00\\x00', 8)
    (b'abcde\\x00\\x00\\x00', False)

    >>> null_pad_bytes(b'abcde\\x00z', 8)
    (b'abcde\\x00\\x00\\x00', True)

    The output never exceeds `max_length`, even when there is no room
    for the null terminator.

    >>> null_pad_bytes(b'abcd', 0)
    (b'', True)
    """
    was_truncated = False

    null_pos = raw_bytes.find(b"\0")
    if null_pos != -1:
        # Only non-null bytes located after the first null byte are actual
        # data being lost; a run of trailing nulls is just padding.
        was_truncated = bool(raw_bytes[null_pos:].strip(b"\0"))
        raw_bytes = raw_bytes[:null_pos]

    # Number of bytes available for the data itself, one byte is reserved
    # for the terminator when null terminating.
    payload_length = max_length - 1 if null_terminate else max_length
    if len(raw_bytes) > payload_length:
        # Clamp to 0: a negative slice end would count from the end of the
        # data and keep more than `max_length` bytes.
        raw_bytes = raw_bytes[: max(payload_length, 0)]
        was_truncated = True

    # This will effectively null pad
    return raw_bytes.ljust(max_length, b"\0"), was_truncated
|