File: test_message.py

package info (click to toggle)
pyro4 4.82-2
  • links: PTS
  • area: main
  • in suites: bookworm
  • size: 2,528 kB
  • sloc: python: 17,736; makefile: 169; sh: 113; javascript: 62
file content (224 lines) | stat: -rw-r--r-- 9,808 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
Tests for the Pyro wire protocol message.

Pyro - Python Remote Objects.  Copyright by Irmen de Jong (irmen@razorvine.net).
"""

import hashlib
import hmac
import unittest
import zlib
import Pyro4.message
import Pyro4.constants
import Pyro4.util
import Pyro4.errors
from Pyro4.message import Message
from Pyro4.configuration import config
from testsupport import ConnectionMock


def pyrohmac(data, hmac_key, annotations=None):
    """
    Compute the reference HMAC-SHA1 digest that the tests compare against.

    The digest covers ``data`` followed by the value of every annotation
    except the one stored under the "HMAC" key (which holds the digest itself).
    Annotation values are fed in the mapping's own iteration order.

    :param data: payload bytes the mac is initialised with
    :param hmac_key: key bytes for the HMAC
    :param annotations: optional mapping of annotation id to bytes value; None means no annotations
    :return: the digest as bytes
    """
    mac = hmac.new(hmac_key, data, digestmod=hashlib.sha1)
    # None replaces the former shared mutable {} default; a missing mapping contributes nothing.
    for key, value in (annotations or {}).items():
        if key != "HMAC":
            mac.update(value)
    return mac.digest()


class MessageTestsHmac(unittest.TestCase):
    """
    Tests for Message objects that are created and/or received with an hmac key.
    """

    def setUp(self):
        # serializer chosen by the active configuration; the tests only use its serializer_id
        self.ser = Pyro4.util.get_serializer(config.SERIALIZER)

    def testMessage(self):
        # construction, header round-trips through from_header, and total encoded sizes
        Message(99, b"", self.ser.serializer_id, 0, 0, hmac_key=b"secret")  # doesn't check msg type here
        self.assertRaises(Pyro4.errors.ProtocolError, Message.from_header, "FOOBAR")
        msg = Message(Pyro4.message.MSG_CONNECT, b"hello", self.ser.serializer_id, 0, 0, hmac_key=b"secret")
        self.assertEqual(Pyro4.message.MSG_CONNECT, msg.type)
        self.assertEqual(5, msg.data_size)
        self.assertEqual(b"hello", msg.data)
        # HMAC annotation chunk: 4 (id) + 2 (presumably the length field) + 20 (sha1 digest)
        self.assertEqual(4 + 2 + 20, msg.annotations_size)
        mac = pyrohmac(b"hello", b"secret", msg.annotations)
        self.assertDictEqual({"HMAC": mac}, msg.annotations)

        # only the first 24 bytes of the encoded message are handed to from_header
        hdr = msg.to_bytes()[:24]
        msg = Message.from_header(hdr)
        self.assertEqual(Pyro4.message.MSG_CONNECT, msg.type)
        self.assertEqual(4 + 2 + 20, msg.annotations_size)
        self.assertEqual(5, msg.data_size)

        hdr = Message(Pyro4.message.MSG_RESULT, b"", self.ser.serializer_id, 0, 0, hmac_key=b"secret").to_bytes()[:24]
        msg = Message.from_header(hdr)
        self.assertEqual(Pyro4.message.MSG_RESULT, msg.type)
        self.assertEqual(4 + 2 + 20, msg.annotations_size)
        self.assertEqual(0, msg.data_size)

        hdr = Message(Pyro4.message.MSG_RESULT, b"hello", 12345, 60006, 30003, hmac_key=b"secret").to_bytes()[:24]
        msg = Message.from_header(hdr)
        self.assertEqual(Pyro4.message.MSG_RESULT, msg.type)
        self.assertEqual(60006, msg.flags)
        self.assertEqual(5, msg.data_size)
        self.assertEqual(12345, msg.serializer_id)
        self.assertEqual(30003, msg.seq)

        # empty data: 50 bytes total = 24 leading bytes + the 26 byte HMAC annotation chunk
        msg = Message(255, b"", self.ser.serializer_id, 0, 255, hmac_key=b"secret").to_bytes()
        self.assertEqual(50, len(msg))
        msg = Message(1, b"", self.ser.serializer_id, 0, 255, hmac_key=b"secret").to_bytes()
        self.assertEqual(50, len(msg))
        msg = Message(1, b"", self.ser.serializer_id, flags=253, seq=254, hmac_key=b"secret").to_bytes()
        self.assertEqual(50, len(msg))

        # compression is the job of the code supplying the data, so Message should leave the data untouched
        data = b"x" * 1000
        msg = Message(Pyro4.message.MSG_INVOKE, data, self.ser.serializer_id, 0, 0, hmac_key=b"secret").to_bytes()
        msg2 = Message(Pyro4.message.MSG_INVOKE, data, self.ser.serializer_id, Pyro4.message.FLAGS_COMPRESSED, 0, hmac_key=b"secret").to_bytes()
        self.assertEqual(len(msg), len(msg2))

    def testMessageHeaderDatasize(self):
        # a large data_size value must survive the to_bytes / from_header round-trip
        msg = Message(Pyro4.message.MSG_RESULT, b"hello", 12345, 60006, 30003, hmac_key=b"secret")
        msg.data_size = 0x12345678  # hack it to a large value to see if it comes back ok
        hdr = msg.to_bytes()[:24]
        msg = Message.from_header(hdr)
        self.assertEqual(Pyro4.message.MSG_RESULT, msg.type)
        self.assertEqual(60006, msg.flags)
        self.assertEqual(0x12345678, msg.data_size)
        self.assertEqual(12345, msg.serializer_id)
        self.assertEqual(30003, msg.seq)

    def testAnnotations(self):
        # a custom annotation is stored next to the HMAC annotation and counted in the sizes
        annotations = {"TEST": b"abcde"}
        msg = Message(Pyro4.message.MSG_CONNECT, b"hello", self.ser.serializer_id, 0, 0, annotations, b"secret")
        data = msg.to_bytes()
        # HMAC chunk (4 + 2 + 20) followed by the TEST chunk (4 + 2 + 5 value bytes)
        annotations_size = 4 + 2 + 20 + 4 + 2 + 5
        self.assertEqual(msg.header_size + 5 + annotations_size, len(data))
        self.assertEqual(annotations_size, msg.annotations_size)
        self.assertEqual(2, len(msg.annotations))
        self.assertEqual(b"abcde", msg.annotations["TEST"])
        mac = pyrohmac(b"hello", b"secret", annotations)
        self.assertEqual(mac, msg.annotations["HMAC"])

    def testAnnotationsIdLength4(self):
        # annotation ids of any length other than 4 must be rejected, either on construction or on encoding
        with self.assertRaises(Pyro4.errors.ProtocolError):  # id too long
            msg = Message(Pyro4.message.MSG_CONNECT, b"hello", self.ser.serializer_id, 0, 0, {"TOOLONG": b"abcde"}, b"secret")
            msg.to_bytes()
        with self.assertRaises(Pyro4.errors.ProtocolError):  # id too short
            msg = Message(Pyro4.message.MSG_CONNECT, b"hello", self.ser.serializer_id, 0, 0, {"QQ": b"abcde"}, b"secret")
            msg.to_bytes()

    def testRecvAnnotations(self):
        # a message with annotations sent through the mock connection is read back completely by recv
        annotations = {"TEST": b"abcde"}
        msg = Message(Pyro4.message.MSG_CONNECT, b"hello", self.ser.serializer_id, 0, 0, annotations, b"secret")
        c = ConnectionMock()
        c.send(msg.to_bytes())
        msg = Message.recv(c, hmac_key=b"secret")
        self.assertEqual(0, len(c.received))
        self.assertEqual(5, msg.data_size)
        self.assertEqual(b"hello", msg.data)
        self.assertEqual(b"abcde", msg.annotations["TEST"])
        self.assertIn("HMAC", msg.annotations)

    def testProtocolVersion(self):
        # a header encoded with a bogus protocol version must be refused by from_header
        version = Pyro4.constants.PROTOCOL_VERSION
        Pyro4.constants.PROTOCOL_VERSION = 0  # fake invalid protocol version number
        try:
            msg = Message(Pyro4.message.MSG_RESULT, b"", self.ser.serializer_id, 0, 1, hmac_key=b"secret").to_bytes()
        finally:
            # always restore the real value: leaving the module constant at 0 would break every later test
            Pyro4.constants.PROTOCOL_VERSION = version
        self.assertRaises(Pyro4.errors.ProtocolError, Message.from_header, msg)

    def testHmac(self):
        # recv must reject a missing or wrong hmac key and accept the right one
        data = Message(Pyro4.message.MSG_RESULT, b"test", 42, 0, 1, hmac_key=b"test key").to_bytes()
        # test checking of different hmacs
        c = ConnectionMock(data)
        with self.assertRaises(Pyro4.errors.SecurityError) as ex:
            Message.recv(c, hmac_key=None)
        self.assertIn("hmac key config", str(ex.exception))
        c = ConnectionMock(data)
        with self.assertRaises(Pyro4.errors.SecurityError) as ex:
            Message.recv(c, hmac_key=b"T3ST-K3Y")
        self.assertIn("hmac", str(ex.exception))
        # test that it works again when providing the correct key
        c = ConnectionMock(data)
        msg = Message.recv(c, hmac_key=b"test key")
        self.assertEqual(b"test key", msg.hmac_key)

    def testHmacMethod(self):
        # hmac() yields a digest when a key is set, and raises TypeError when the message has no key
        data = Message(Pyro4.message.MSG_RESULT, b"test", 42, 0, 1, hmac_key=b"test key")
        digest = data.hmac()
        self.assertGreater(len(digest), 10)
        data = Message(Pyro4.message.MSG_RESULT, b"test", 42, 0, 1)
        with self.assertRaises(TypeError):
            data.hmac()

    def testSecureCompare(self):
        # secure_compare works for both str and bytes pairs, and raises TypeError on mixed argument types
        self.assertFalse(Pyro4.message.secure_compare("apple", "banana"))
        self.assertFalse(Pyro4.message.secure_compare(b"apple", b"banana"))
        self.assertTrue(Pyro4.message.secure_compare("apple", "apple"))
        self.assertTrue(Pyro4.message.secure_compare(b"apple", b"apple"))
        with self.assertRaises(TypeError):
            Pyro4.message.secure_compare(999, "typemismatch")

    def testChecksum(self):
        # zeroing the last two header bytes must make recv fail with a checksum error
        msg = Message(Pyro4.message.MSG_RESULT, b"test", 42, 0, 1, hmac_key=b"secret")
        c = ConnectionMock()
        c.send(msg.to_bytes())
        # corrupt the checksum bytes
        data = c.received
        data = data[:msg.header_size - 2] + b'\x00\x00' + data[msg.header_size:]
        c = ConnectionMock(data)
        with self.assertRaises(Pyro4.errors.ProtocolError) as ex:
            Message.recv(c)
        self.assertIn("checksum", str(ex.exception))

    def testCompression(self):
        # decompress_if_needed restores the original data, clears the flags and grows data_size
        data = b"The quick brown fox jumps over the lazy dog."*10
        compressed_data = zlib.compress(data)
        flags = Pyro4.message.FLAGS_COMPRESSED
        msg = Message(Pyro4.message.MSG_INVOKE, compressed_data, 42, flags, 1, hmac_key=b"secret")
        self.assertNotEqual(data, msg.data)
        data_size = msg.data_size
        self.assertLess(data_size, len(data))
        msg.decompress_if_needed()
        self.assertEqual(data, msg.data)
        self.assertEqual(0, msg.flags)
        self.assertGreater(msg.data_size, data_size)


class MessageTestsNoHmac(unittest.TestCase):
    """
    Tests for Message objects that are created without an hmac key.
    """

    def testRecvNoAnnotations(self):
        # a plain message pushed through the mock connection comes back without any annotations
        outgoing = Message(Pyro4.message.MSG_CONNECT, b"hello", 42, 0, 0)
        conn = ConnectionMock()
        conn.send(outgoing.to_bytes())
        incoming = Message.recv(conn)
        self.assertEqual(0, len(conn.received))
        self.assertEqual(5, incoming.data_size)
        self.assertEqual(b"hello", incoming.data)
        self.assertEqual(0, incoming.annotations_size)
        self.assertEqual(0, len(incoming.annotations))

    def testMaxDataSize(self):
        # to_bytes accepts sizes up to 0x7fffffff and raises ValueError for anything outside 0..2Gb
        expected_error = "invalid message size (outside range 0..2Gb)"
        message = Message(Pyro4.message.MSG_CONNECT, b"hello", 42, 0, 0)
        message.data_size = 0x7fffffff  # still within 32 bits signed limits
        message.to_bytes()
        # first the overflow case (Pyro has a 2 gigabyte message size limitation), then a negative size
        for invalid_size in (0x80000000, -42):
            message.data_size = invalid_size
            with self.assertRaises(ValueError) as caught:
                message.to_bytes()
            self.assertEqual(expected_error, str(caught.exception))



# run the tests in this module when the file is executed directly as a script
if __name__ == "__main__":
    unittest.main()