1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
import ssl
# Syntax bytes of a netstring: b"<decimal length>:<payload>,".
COLON, COMMA, ZERO = b':', b',', b'0'
# Integer code of the ASCII digit '0' (indexing bytes yields an int).
ZERO_ORD = ZERO[0]
class NetstringException(Exception):
    """Root of every exception raised by this netstring module."""


class WantRead(NetstringException):
    """The input buffer ran dry; feed more data and retry the read."""


class InappropriateParserState(NetstringException):
    """An operation was requested that the parser cannot perform right now,
    e.g. starting a new netstring before the previous one was consumed."""


class ParseError(NetstringException):
    """Base class for errors caused by malformed netstring input."""


class IncompleteNetstring(ParseError):
    """The input ended in the middle of a netstring."""


class TooLong(ParseError):
    """The declared netstring length exceeds the configured maximum."""


class BadLength(ParseError):
    """The length prefix is malformed (no digits, or a non-digit byte)."""


class BadTerminator(ParseError):
    """The byte following the payload is not the ',' terminator."""
class SingleNetstringFetcher:
    """Incrementally decode a single netstring from a shared input buffer.

    Instances are normally created by ``StreamReader.next_string()``.  Every
    call to :meth:`read` returns the next chunk of the payload, ``b''`` once
    the payload and its ``,`` terminator have been consumed, or raises
    ``WantRead`` when the buffer runs dry first.  All parsing state lives on
    the instance, so after ``WantRead`` the caller can feed more data and
    simply call :meth:`read` again.
    """

    def __init__(self, incoming, maxlen=-1):
        """Create a fetcher.

        Params:
            incoming - buffer to read from; only ``incoming.read(n)`` is used
                (``StreamReader`` passes an ``ssl.MemoryBIO``).
            maxlen - maximal allowed payload length, or -1 for no limit.
        """
        self._incoming = incoming
        self._maxlen = maxlen
        # True once the ':' ending the length prefix has been consumed.
        self._len_known = False
        # None until the first length digit is seen; afterwards the number of
        # payload bytes still to be returned.
        self._len = None
        # True once the ',' terminator has been consumed.
        self._done = False

    def done(self):
        """Return True once the whole netstring, terminator included, is consumed."""
        return self._done

    def pending(self):
        """Return True if parsing of this netstring started but has not finished."""
        # _len stays at 0 (not None) after a completed string, so _done must be
        # checked too; otherwise a finished string would look still pending.
        return self._len is not None and not self._done

    def read(self, nbytes=65536):
        """Return the next chunk of the netstring payload.

        Params:
            nbytes - maximal chunk size; a value <= 0 means "no limit", i.e.
                return as much of the remaining payload as is buffered.

        Returns a non-empty payload chunk, or b'' once the payload and its
        ',' terminator have been consumed (and on any later call).

        Raises:
            WantRead - the buffer ran out before the next chunk or the
                terminator could be read.
            BadLength - ':' seen before any digit, or a non-digit in the length.
            TooLong - the declared length exceeds maxlen.
            BadTerminator - the byte after the payload is not ','.
        """
        if not self._len_known:
            self._read_length()
        if self._len:
            return self._read_payload(nbytes)
        if not self._done:
            self._read_terminator()
        return b''

    def _read_length(self):
        """Consume the decimal length prefix up to and including ':'."""
        while True:
            symbol = self._incoming.read(1)
            if not symbol:
                raise WantRead()
            if symbol == COLON:
                if self._len is None:
                    raise BadLength("No netstring length digits seen.")
                self._len_known = True
                return
            if not symbol.isdigit():
                raise BadLength("Non-digit symbol in netstring length.")
            # NOTE(review): leading zeros (e.g. b'03:') are accepted, although
            # the netstring spec prohibits them; confirm lenient parsing is wanted.
            val = ord(symbol) - ZERO_ORD
            self._len = val if self._len is None else self._len * 10 + val
            # Checked per digit, so an oversized length is rejected before
            # the rest of the prefix even arrives.
            if self._maxlen != -1 and self._len > self._maxlen:
                raise TooLong("Netstring length is over limit.")

    def _read_payload(self, nbytes):
        """Return up to ``nbytes`` of the remaining payload (all of it if nbytes <= 0)."""
        # Never ask the buffer for more than this netstring's remaining bytes:
        # a negative size would make MemoryBIO.read() return everything,
        # including bytes of the following netstring.
        limit = self._len if nbytes <= 0 else min(nbytes, self._len)
        buf = self._incoming.read(limit)
        if not buf:
            raise WantRead()
        self._len -= len(buf)
        return buf

    def _read_terminator(self):
        """Consume the ',' that must follow the payload."""
        symbol = self._incoming.read(1)
        if not symbol:
            raise WantRead()
        if symbol != COMMA:
            raise BadTerminator("Bad netstring terminator.")
        self._done = True
class StreamReader:
    """Incremental netstring decoder whose interface is similar to the
    ssl.SSLObject / MemoryBIO one.

    Raw bytes are pushed in with feed(); next_string() returns a
    SingleNetstringFetcher that pulls the parts of one netstring out of the
    internal buffer. SingleNetstringFetcher.read() returns b'' at the end of
    the string and raises WantRead when more data has to be fed into the
    StreamReader first. Parsing errors are signaled with exceptions
    subclassing ParseError.
    """

    def __init__(self, maxlen=-1):
        """Create a StreamReader.

        Params:
            maxlen - maximal allowed netstring length (-1 means no limit).
        """
        self._fetcher = None               # fetcher for the current netstring
        self._incoming = ssl.MemoryBIO()   # bytes fed but not yet consumed
        self._maxlen = maxlen

    def pending(self):
        """Return whether the current fetcher reports an unfinished netstring."""
        current = self._fetcher
        if current is None:
            return False
        return current.pending()

    def feed(self, data):
        """Append raw bytes to the internal buffer."""
        self._incoming.write(data)

    def next_string(self):
        """Begin decoding the next netstring and return its fetcher.

        Raises InappropriateParserState while the previous fetcher has not
        consumed its terminator -- this includes a fetcher that stopped with
        a parse error, so such a stream cannot be resumed.
        """
        previous = self._fetcher
        if previous is not None and not previous.done():
            raise InappropriateParserState(
                "next_string() invoked while previous fetcher is not exhausted")
        fresh = SingleNetstringFetcher(self._incoming, self._maxlen)
        self._fetcher = fresh
        return fresh
def encode(data):
    """Return ``data`` wrapped as a netstring: b"<length>:<data>,".

    ``data`` must be bytes-like; its length is written in decimal.
    """
    size = len(data)
    return b'%d:%s,' % (size, data)
def decode(data, maxlen=-1):
    """Decode every netstring in ``data``, yielding each payload as bytes.

    This is a generator: payloads are produced one at a time, and a parse
    error is raised only when the offending netstring is reached.

    Params:
        data - bytes-like object holding zero or more complete netstrings.
        maxlen - maximal allowed netstring length, or -1 (the default) for
            no limit; forwarded to StreamReader.

    Raises:
        IncompleteNetstring - the input ends in the middle of a netstring.
        BadLength / TooLong / BadTerminator - malformed input.
    """
    reader = StreamReader(maxlen)
    reader.feed(data)
    try:
        while True:
            string_reader = reader.next_string()
            # read() returns b'' exactly when the netstring is complete.
            yield b''.join(iter(string_reader.read, b''))
    except WantRead:
        # Running out of input is the normal way this loop ends; it is only
        # an error when it happens part-way through a netstring.
        if reader.pending():
            raise IncompleteNetstring("Input ends on unfinished string.") from None
|