File: protocol.py

package info (click to toggle)
pyro5 5.15-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,112 kB
  • sloc: python: 14,291; makefile: 163; sh: 66; javascript: 62
file content (201 lines) | stat: -rw-r--r-- 8,853 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""
The pyro wire protocol structures.

Pyro - Python Remote Objects.  Copyright by Irmen de Jong (irmen@razorvine.net).


Wire messages consist of a fixed size header, an optional set of annotation chunks,
and then the payload data. This module doesn't deal with the payload data:
(de)serialization and handling of that data is done elsewhere.
Annotation chunks are only parsed.

The header format is::

    0x00   4s  4   'PYRO' (message identifier)
    0x04   H   2   protocol version
    0x06   B   1   message type
    0x07   B   1   serializer id
    0x08   H   2   message flags
    0x0a   H   2   sequence number   (to identify proper request-reply sequencing)
    0x0c   I   4   data length   (max 4 Gb)
    0x10   I   4   annotations length (max 4 Gb, total of all chunks, 0 if no annotation chunks present)
    0x14   16s 16  correlation uuid
    0x24   H   2   (reserved)
    0x26   H   2   magic number 0x4dc5
    total size: 0x28 (40 bytes)

After the header, zero or more annotation chunks may follow, of the format::

    4s  4   annotation id (4 ASCII letters)
    I   4   chunk length  (max 4 Gb)
    B   x   annotation chunk databytes

After that, the actual payload data bytes follow.
"""

import struct
import logging
import zlib
import uuid
from . import config, errors
from .callcontext import current_context


log = logging.getLogger("Pyro5.protocol")

# values for the 'message type' byte of the header
MSG_CONNECT = 1
MSG_CONNECTOK = 2
MSG_CONNECTFAIL = 3
MSG_INVOKE = 4
MSG_RESULT = 5
MSG_PING = 6

# bit masks for the 'message flags' field of the header (one bit per flag)
FLAGS_EXCEPTION = 0x01
FLAGS_COMPRESSED = 0x02    # compress the data, but not the annotations (if you need that, do it yourself)
FLAGS_ONEWAY = 0x04
FLAGS_BATCH = 0x08
FLAGS_ITEMSTREAMRESULT = 0x10
FLAGS_KEEPSERIALIZED = 0x20
FLAGS_CORR_ID = 0x40

# wire protocol version. Note that if this gets updated, Pyrolite might need an update too.
PROTOCOL_VERSION = 502
_magic_number = 0x4dc5
# network byte order; field layout is described in the module docstring
_header_format = '!4sHBBHHII16sHH'
_header_size = struct.Struct(_header_format).size
# big-endian 2-byte encodings, used to check raw header bytes without unpacking
_magic_number_bytes = struct.pack("!H", _magic_number)
_protocol_version_bytes = struct.pack("!H", PROTOCOL_VERSION)
# all-zero 16 byte correlation id, used when no correlation id is set
_empty_correlation_id = bytes(16)


class SendingMessage:
    """
    Wire protocol message that will be sent.

    The constructor builds the complete wire representation and stores it in ``data``
    (header bytes, followed by the annotation chunks, followed by the payload).
    Other attributes: ``type``, ``seq``, ``serializer_id``, ``flags`` and ``corr_id``
    hold the values that were packed into the header.
    """

    def __init__(self, msgtype, flags, seq, serializer_id, payload, annotations=None):
        """
        Builds the wire data for a message.

        :param msgtype: message type (one of the MSG_* values)
        :param flags: message flags (FLAGS_* bits); FLAGS_COMPRESSED and FLAGS_CORR_ID are
            determined here, a FLAGS_COMPRESSED bit passed in by the caller is cleared
        :param seq: sequence number
        :param serializer_id: serializer id
        :param payload: the payload data (bytes-like); compressed with zlib when
            config.COMPRESSION is set and the payload is longer than 100 bytes
        :param annotations: optional dict mapping a 4-character ascii id to bytes-like chunk data
        :raises errors.ProtocolError: if an annotation is invalid, or if payload plus
            annotations are larger than config.MAX_MESSAGE_SIZE
        """
        self.type = msgtype
        self.seq = seq
        self.serializer_id = serializer_id
        # validate and pack the annotations first, so that invalid annotations are always
        # reported as a ProtocolError before any size calculation touches them
        annotation_data, annotations_size = self._pack_annotations(annotations or {})
        flags &= ~FLAGS_COMPRESSED
        if config.COMPRESSION and len(payload) > 100:
            payload = zlib.compress(payload, 4)
            flags |= FLAGS_COMPRESSED
        total_size = len(payload) + annotations_size
        if total_size > config.MAX_MESSAGE_SIZE:
            raise errors.ProtocolError("message too large ({:d}, max={:d})".format(total_size, config.MAX_MESSAGE_SIZE))
        if current_context.correlation_id:
            flags |= FLAGS_CORR_ID
            self.corr_id = current_context.correlation_id.bytes
        else:
            self.corr_id = _empty_correlation_id
        # store the flags only now, so the attribute equals what is packed in the header
        # (including FLAGS_CORR_ID)
        self.flags = flags
        header_data = struct.pack(_header_format, b"PYRO", PROTOCOL_VERSION, msgtype, serializer_id, flags, seq,
                                  len(payload), annotations_size, self.corr_id, 0, _magic_number)
        self.data = header_data + annotation_data + payload

    @staticmethod
    def _pack_annotations(annotations):
        """Validates the annotations; returns (packed annotation chunks, their total size in bytes)."""
        chunks = []
        total_size = 0
        for key, value in annotations.items():
            if len(key) != 4:
                raise errors.ProtocolError("annotation identifier must be 4 ascii characters")
            try:
                key_bytes = key.encode("ascii")
            except UnicodeEncodeError as x:
                raise errors.ProtocolError("annotation identifier must be 4 ascii characters") from x
            if not isinstance(value, (bytes, bytearray, memoryview)):
                raise errors.ProtocolError("annotation data must be bytes, bytearray, or memoryview", type(value))
            # use the size in bytes: len() of a memoryview counts items, which differs
            # from the number of bytes written when the item size is larger than 1
            chunk_size = memoryview(value).nbytes
            chunks.append(struct.pack("!4sI", key_bytes, chunk_size))
            chunks.append(value)    # note: annotations are not compressed by Pyro
            total_size += 8 + chunk_size    # 8 = size of the chunk header (id + length)
        return b"".join(chunks), total_size

    def __repr__(self):
        return "<{:s}.{:s} at 0x{:x}; type={:d} flags={:d} seq={:d} size={:d}>" \
            .format(self.__module__, self.__class__.__name__, id(self), self.type, self.flags, self.seq, len(self.data))

    @staticmethod
    def ping(pyroConnection):
        """Convenience method to send a 'ping' message and wait for the 'pong' response"""
        ping = SendingMessage(MSG_PING, 0, 0, 42, b"ping")
        pyroConnection.send(ping.data)
        # the reply is read and checked to be a MSG_PING message; its contents are discarded
        recv_stub(pyroConnection, [MSG_PING])


class ReceivingMessage:
    """
    Wire protocol message that was received.

    The header fields are available as ``type``, ``serializer_id``, ``flags``, ``seq``,
    ``data_size``, ``annotations_size`` and ``corr_id``. After the payload has been added,
    ``annotations`` maps annotation id to chunk data and ``data`` holds the payload data.
    """
    def __init__(self, header, payload=None):
        """
        Parses a message from the given header.

        :param header: the raw header bytes, in the layout of the header format
        :param payload: optional annotation chunks + payload data; if given it is
            processed immediately via add_payload
        :raises errors.ProtocolError: if the tag, protocol version or magic number is wrong,
            or if the sizes in the header exceed config.MAX_MESSAGE_SIZE
        """
        tag, ver, self.type, self.serializer_id, self.flags, self.seq, self.data_size, \
            self.annotations_size, self.corr_id,  _, magic = struct.unpack(_header_format, header)
        if tag != b"PYRO" or ver != PROTOCOL_VERSION or magic != _magic_number:
            raise errors.ProtocolError("invalid message or protocol version")
        if self.data_size+self.annotations_size > config.MAX_MESSAGE_SIZE:
            raise errors.ProtocolError("message too large ({:d}, max={:d})"
                                       .format(self.data_size+self.annotations_size, config.MAX_MESSAGE_SIZE))
        self.data = None
        self.annotations = {}
        if payload is not None:
            self.add_payload(payload)

    def __repr__(self):
        return "<{:s}.{:s} at 0x{:x}; type={:d} flags={:d} seq={:d} size={:d}>" \
            .format(self.__module__, self.__class__.__name__, id(self), self.type, self.flags, self.seq, len(self.data or ""))

    @staticmethod
    def validate(data):
        """
        Checks if the message data looks like a valid Pyro message, if not, raise an error.

        Only the parts that are present in data are checked, so this can be used on a
        partially received header (at least 4 bytes are required).
        """
        ld = len(data)
        if ld < 4:
            raise ValueError("data must be at least 4 bytes to be able to identify")
        if not data.startswith(b"PYRO"):
            raise errors.ProtocolError("invalid data")
        if ld >= 6 and data[4:6] != _protocol_version_bytes:
            raise errors.ProtocolError("invalid protocol version: {:d}".format(int.from_bytes(data[4:6], "big")))
        # the magic number occupies the last two bytes of the header (offset 0x26)
        if ld >= _header_size and data[38:40] != _magic_number_bytes:
            raise errors.ProtocolError("invalid magic number")

    def add_payload(self, payload):
        """
        Parses (annotations processing) and adds payload data to a received message.

        :param payload: the annotation chunks followed by the payload data; its length
            must equal the annotations size plus data size from the header
        :raises errors.ProtocolError: if the length doesn't match the header, or if the
            annotation chunks don't fit exactly in the announced annotations size
        """
        assert not self.data    # internal sanity check: payload data is only added once
        if len(payload) != self.data_size + self.annotations_size:
            raise errors.ProtocolError("payload length doesn't match message header")
        if self.annotations_size:
            payload = memoryview(payload)  # avoid copying
            self.annotations = {}
            i = 0
            while i < self.annotations_size:
                # The chunk sizes come from the wire, so they are checked explicitly
                # (not with assert, which is removed when running python with -O).
                if i + 8 > self.annotations_size:
                    raise errors.ProtocolError("annotation chunk header doesn't fit in annotations size")
                annotation_id = bytes(payload[i:i+4]).decode("ascii")
                length = int.from_bytes(payload[i+4:i+8], "big")
                chunk_end = i + 8 + length
                if chunk_end > self.annotations_size:
                    raise errors.ProtocolError("annotation chunk length exceeds annotations size")
                self.annotations[annotation_id] = payload[i+8:chunk_end]     # note: it stores a memoryview!
                i = chunk_end
            # the checks above guarantee that the loop ended exactly at annotations_size
            self.data = payload[self.annotations_size:]
        else:
            self.data = payload
        if self.flags & FLAGS_COMPRESSED:
            # NOTE(review): the size limit in __init__ applies to the compressed size from
            # the header; the decompressed data is not limited here.
            self.data = zlib.decompress(self.data)
            self.flags &= ~FLAGS_COMPRESSED
            self.data_size = len(self.data)


def log_wiredata(logger, text, msg):
    """
    logs all the given properties of the wire message in the given logger

    The message is logged at debug level. Nothing is done (in particular the message
    data is not copied) when the logger is not enabled for debug level.

    :param logger: the logger to log to
    :param text: text that is put in front of the message properties
    :param msg: the wire message; its type, flags, serializer_id, seq and data are logged,
        annotations and corr_id are optional attributes
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    num_anns = len(msg.annotations) if hasattr(msg, "annotations") else 0
    corr_bytes = bytes(msg.corr_id) if hasattr(msg, "corr_id") else _empty_correlation_id
    corr_id = uuid.UUID(bytes=corr_bytes)
    # data can be None (a received message without payload added yet); bytes(None) would fail
    data = None if msg.data is None else bytes(msg.data)
    # pass the values as arguments so that the logging module does the formatting
    logger.debug("%s: msgtype=%d flags=0x%x ser=%d seq=%d num_annotations=%s corr_id=%s\ndata=%r",
                 text, msg.type, msg.flags, msg.serializer_id, msg.seq, num_anns, corr_id, data)


def recv_stub(connection, accepted_msgtypes=None):
    """
    Reads one pyro message from the given connection and returns it.

    First the header is read and checked, then the annotation chunks and the
    payload data are read and added to the message.
    If accepted_msgtypes is given (a sequence of message types; None=any) and the
    received message type is not in it, a ProtocolError is raised that has the
    message (without payload) attached as its pyroMsg attribute.
    """
    # TODO decouple i/o from actual protocol logic, so that the protocol can be easily unit tested
    prefix = connection.recv(6)  # 'PYRO' + 2 bytes protocol version
    ReceivingMessage.validate(prefix)
    remainder = connection.recv(_header_size - 6)
    message = ReceivingMessage(prefix + remainder)
    if accepted_msgtypes and message.type not in accepted_msgtypes:
        expected = ",".join(map(str, accepted_msgtypes))
        err = "invalid msg type {:d} received (expected: {:s})".format(message.type, expected)
        log.error(err)
        failure = errors.ProtocolError(err)
        failure.pyroMsg = message
        raise failure
    body_size = message.annotations_size + message.data_size
    message.add_payload(connection.recv(body_size))
    return message