File: netstring.py

package info (click to toggle)
postfix-mta-sts-resolver 1.5.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 536 kB
  • sloc: python: 3,069; sh: 226; makefile: 47
file content (149 lines) | stat: -rw-r--r-- 4,078 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import ssl

# Single-byte tokens of the netstring wire format b"<len>:<data>,".
COLON = b':'
COMMA = b','
ZERO = b'0'
# Integer value of the ASCII digit zero; subtracting it from a digit byte's
# value yields the digit's numeric value.
ZERO_ORD = ZERO[0]

class NetstringException(Exception):
    """Common base class of the netstring exceptions defined below."""


class WantRead(NetstringException):
    """The input buffer ran dry; feed more data and retry the read."""


class InappropriateParserState(NetstringException):
    """An operation was requested while the parser cannot perform it."""


class ParseError(NetstringException):
    """Base class for errors caused by malformed netstring input."""


class IncompleteNetstring(ParseError):
    """The input ended in the middle of a netstring."""


class TooLong(ParseError):
    """The netstring length exceeds the configured maximum."""


class BadLength(ParseError):
    """The netstring length prefix is missing or contains a non-digit."""


class BadTerminator(ParseError):
    """The netstring payload is not followed by the ',' terminator."""


class SingleNetstringFetcher:
    """Incrementally decode one netstring b"<len>:<data>," from a buffer.

    Parsing is resumable: whenever the buffer runs dry, WantRead is raised
    and the same call can be repeated after more data has been written to
    the buffer.

    Params:

    incoming - buffer the netstring is read from (StreamReader passes an
               ssl.MemoryBIO); its read(n) must return b'' when it is empty.
    maxlen - maximal allowed netstring length; -1 disables the limit.
    """

    def __init__(self, incoming, maxlen=-1):
        self._incoming = incoming
        self._maxlen = maxlen
        # True once the ':' that ends the length prefix has been consumed.
        self._len_known = False
        # Length accumulated from the digits seen so far; after the ':' it
        # counts the payload bytes not yet returned. None until a digit
        # has been consumed.
        self._len = None
        # True once the trailing ',' has been consumed.
        self._done = False

    def done(self):
        """Return True once the payload and its ',' terminator were read."""
        return self._done

    def pending(self):
        """Return True once at least one length digit has been consumed.

        The result stays True after the netstring has been fully read.
        """
        return self._len is not None

    def read(self, nbytes=65536):
        """Return the next chunk of the netstring payload.

        Params:

        nbytes - maximal number of payload bytes to return; must be > 0.

        Returns up to *nbytes* payload bytes, or b'' once the payload and
        its ',' terminator have been consumed.

        Raises:

        WantRead - the buffer ran dry; feed more data and call again.
        BadLength - the length prefix is empty or has a non-digit byte.
        TooLong - the length exceeds *maxlen*.
        BadTerminator - the payload is not followed by ','.
        ValueError - *nbytes* is not positive.
        """
        if nbytes < 1:
            # read(0) on the buffer yields b'' (misreported as WantRead) and
            # a negative size drains the whole buffer, running past the end
            # of this netstring and corrupting the parser state.
            raise ValueError("nbytes must be a positive integer")
        if not self._len_known:
            self._read_length()
        if self._len:
            # Never request more than what is left of this payload so that
            # the bytes of the following netstring stay in the buffer.
            buf = self._incoming.read(min(nbytes, self._len))
            if not buf:
                raise WantRead()
            self._len -= len(buf)
            return buf
        if not self._done:
            self._read_terminator()
        return b''

    def _read_length(self):
        """Consume length digits up to and including ':' (resumable)."""
        while True:
            symbol = self._incoming.read(1)
            if not symbol:
                raise WantRead()
            if symbol == COLON:
                if self._len is None:
                    raise BadLength("No netstring length digits seen.")
                self._len_known = True
                return
            if not symbol.isdigit():
                raise BadLength("Non-digit symbol in netstring length.")
            # Note: leading zeros are accepted (e.g. b'03:abc,').
            val = ord(symbol) - ZERO_ORD
            self._len = val if self._len is None else self._len * 10 + val
            # Checked per digit so an oversized prefix is rejected early.
            if self._maxlen != -1 and self._len > self._maxlen:
                raise TooLong("Netstring length is over limit.")

    def _read_terminator(self):
        """Consume the ',' that must follow the payload."""
        symbol = self._incoming.read(1)
        if not symbol:
            raise WantRead()
        if symbol != COMMA:
            raise BadTerminator("Bad netstring terminator.")
        self._done = True


class StreamReader:
    """Async netstring protocol decoder whose interface resembles the BIO
    interface of ssl.SSLObject.

    next_string() returns a SingleNetstringFetcher that fetches the parts
    of one netstring.

    SingleNetstringFetcher.read() returns b'' at the end of a string and
    raises WantRead when the StreamReader has to be fed with more data.
    Parse errors are signalled by exceptions derived from ParseError."""

    def __init__(self, maxlen=-1):
        """Create a StreamReader instance.

        Params:

        maxlen - maximal allowed netstring length (-1 means no limit).
        """
        self._maxlen = maxlen
        self._incoming = ssl.MemoryBIO()
        # Fetcher handed out by the latest next_string() call, if any.
        self._fetcher = None

    def pending(self):
        """Report whether the latest fetcher has started on a netstring.

        Always False before next_string() has been called; otherwise the
        result of the current fetcher's pending().
        """
        current = self._fetcher
        if current is None:
            return False
        return current.pending()

    def feed(self, data):
        """Append *data* to the internal buffer."""
        self._incoming.write(data)

    def next_string(self):
        """Return a fetcher for the next netstring in the buffer.

        Raises InappropriateParserState if the previously returned fetcher
        has not consumed its netstring completely.
        """
        previous = self._fetcher
        if previous is None or previous.done():
            self._fetcher = SingleNetstringFetcher(self._incoming,
                                                   self._maxlen)
            return self._fetcher
        raise InappropriateParserState("next_string() invoked while "
                                       "previous fetcher is not exhausted")


def encode(data):
    """Wrap the bytes-like *data* into one netstring b"<len>:<data>,"."""
    length = len(data)
    return b"%d:%b," % (length, data)


def decode(data, maxlen=-1):
    """Decode every netstring contained in *data*.

    Params:

    data - bytes-like object holding zero or more netstrings.
    maxlen - maximal allowed length of each netstring; -1 means no limit.

    Yields the payload (bytes) of each netstring, in input order.

    Raises IncompleteNetstring if *data* ends in the middle of a netstring,
    and other ParseError subclasses (BadLength, TooLong, BadTerminator) if
    the input is malformed.
    """
    reader = StreamReader(maxlen)
    reader.feed(data)
    try:
        while True:
            fetcher = reader.next_string()
            # read() returns b'' once the current payload is exhausted.
            yield b''.join(iter(fetcher.read, b''))
    except WantRead:
        # Running out of data between netstrings is the normal end of
        # input; running out inside one means the input was truncated.
        if reader.pending():
            raise IncompleteNetstring(
                "Input ends on unfinished string.") from None