File: test_future.py

package info (click to toggle)
pyzmq 27.1.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,984 kB
  • sloc: python: 15,189; ansic: 285; makefile: 169; sh: 85
file content (354 lines) | stat: -rw-r--r-- 10,613 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
# Copyright (c) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.

import json
import os
import sys
from datetime import timedelta

import pytest

gen = pytest.importorskip('tornado.gen')

from tornado.ioloop import IOLoop

import zmq
from zmq.eventloop import future
from zmq_test_utils import BaseZMQTestCase


class TestFutureSocket(BaseZMQTestCase):
    Context = future.Context

    def setUp(self):
        """Create a dedicated tornado ``IOLoop`` for the test, then run the base setUp."""
        # make_current=False: the loop is created without being made the
        # current loop; the tests drive it explicitly via self.loop.run_sync().
        self.loop = IOLoop(make_current=False)
        super().setUp()

    def tearDown(self):
        """Run the base tearDown, then close the per-test loop if it is still set."""
        super().tearDown()
        loop = self.loop
        if not loop:
            # a test (e.g. test_close_all_fds) already closed the loop and
            # cleared self.loop, so there is nothing left to close
            return
        loop.close(all_fds=True)

    def test_socket_class(self):
        """Sockets created from ``self.context`` must be ``future.Socket`` instances."""
        s = self.context.socket(zmq.PUSH)
        try:
            assert isinstance(s, future.Socket)
        finally:
            # Close even when the assertion fails, so a failing check does not
            # leave an open socket behind for the rest of the teardown.
            s.close()

    def test_instance_subclass_first(self):
        """``instance()`` returns the right class when the future Context asks first.

        ``self.Context.instance()`` is called before ``zmq.Context.instance()``;
        each call must hand back an object of exactly the class it was called on.
        """
        future_ctx = self.Context.instance()
        base_ctx = zmq.Context.instance()
        # terminate both contexts (base one first) before checking the types
        for ctx in (base_ctx, future_ctx):
            ctx.term()
        assert type(base_ctx) is zmq.Context
        assert type(future_ctx) is self.Context

    def test_instance_subclass_second(self):
        """``instance()`` returns the right class when the future Context asks second.

        ``zmq.Context.instance()`` is called before ``self.Context.instance()``;
        each call must hand back an object of exactly the class it was called on.
        """
        base_ctx = zmq.Context.instance()
        future_ctx = self.Context.instance()
        # terminate both contexts (base one first) before checking the types
        for ctx in (base_ctx, future_ctx):
            ctx.term()
        assert type(base_ctx) is zmq.Context
        assert type(future_ctx) is self.Context

    def test_recv_multipart(self):
        """A pending ``recv_multipart`` future resolves to the frames sent afterwards."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = pull.recv_multipart()
            # nothing has been sent yet, so the future must still be waiting
            assert not pending.done()
            await push.send(b"hi")
            frames = await pending
            assert frames == [b'hi']

        self.loop.run_sync(scenario)

    def test_recv(self):
        """Two queued ``recv`` futures each receive one frame, in request order."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            first = pull.recv()
            second = pull.recv()
            # neither future may complete before anything is sent
            assert not (first.done() or second.done())
            await push.send_multipart([b"hi", b"there"])
            last_frame = await second
            # once the second recv has finished, the first must have too
            assert first.done()
            assert first.result() == b'hi'
            assert last_frame == b'there'

        self.loop.run_sync(scenario)

    def test_recv_cancel(self):
        """A cancelled ``recv`` does not consume the message meant for a later recv."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            dropped = pull.recv()
            pending = pull.recv_multipart()
            # cancel the first request; the second one must be unaffected
            assert dropped.cancel()
            assert dropped.done()
            assert not pending.done()
            await push.send_multipart([b"hi", b"there"])
            frames = await pending
            assert dropped.cancelled()
            assert pending.done()
            # the whole message goes to the surviving recv_multipart
            assert frames == [b'hi', b'there']

        self.loop.run_sync(scenario)

    @pytest.mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
    def test_recv_timeout(self):
        """A recv started under a short ``rcvtimeo`` fails with ``zmq.Again``.

        A second recv, started after ``rcvtimeo`` was raised, must still
        receive the message that is sent once the first one has timed out.
        """

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pull.rcvtimeo = 100
            short_wait = pull.recv()
            pull.rcvtimeo = 1000
            long_wait = pull.recv_multipart()
            # nothing is sent while waiting here, so the first recv times out
            with pytest.raises(zmq.Again):
                await short_wait
            await push.send_multipart([b"hi", b"there"])
            frames = await long_wait
            assert long_wait.done()
            assert frames == [b'hi', b'there']

        self.loop.run_sync(scenario)

    @pytest.mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
    def test_send_timeout(self):
        """``send`` with ``sndtimeo`` set raises ``zmq.Again`` when it cannot complete."""

        async def scenario():
            # the socket is never bound or connected in this test
            push = self.socket(zmq.PUSH)
            push.sndtimeo = 100
            with pytest.raises(zmq.Again):
                await push.send(b"not going anywhere")

        self.loop.run_sync(scenario)

    def test_send_noblock(self):
        """``send`` with ``zmq.NOBLOCK`` raises ``zmq.Again`` when it cannot complete."""

        async def scenario():
            # the socket is never bound or connected in this test
            push = self.socket(zmq.PUSH)
            with pytest.raises(zmq.Again):
                await push.send(b"not going anywhere", flags=zmq.NOBLOCK)

        self.loop.run_sync(scenario)

    def test_send_multipart_noblock(self):
        """``send_multipart`` with ``zmq.NOBLOCK`` raises ``zmq.Again`` when it cannot complete."""

        async def scenario():
            # the socket is never bound or connected in this test
            push = self.socket(zmq.PUSH)
            with pytest.raises(zmq.Again):
                await push.send_multipart([b"not going anywhere"], flags=zmq.NOBLOCK)

        self.loop.run_sync(scenario)

    def test_recv_string(self):
        """``send_string`` / ``recv_string`` round-trip a non-ASCII string."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = pull.recv_string()
            assert not pending.done()
            text = 'πøøπ'
            await push.send_string(text)
            received = await pending
            # both the future's stored result and the awaited value must match
            assert pending.done()
            assert pending.result() == text
            assert received == text

        self.loop.run_sync(scenario)

    def test_recv_json(self):
        """``send_json`` / ``recv_json`` round-trip a small dict."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = pull.recv_json()
            assert not pending.done()
            payload = {'a': 5}
            await push.send_json(payload)
            received = await pending
            # both the future's stored result and the awaited value must match
            assert pending.done()
            assert pending.result() == payload
            assert received == payload

        self.loop.run_sync(scenario)

    def test_recv_json_cancelled(self):
        """A cancelled ``recv_json`` must not swallow the next incoming message."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            cancelled = pull.recv_json()
            assert not cancelled.done()
            cancelled.cancel()
            # cycle eventloop to allow cancel events to fire
            await gen.sleep(0)
            payload = {'a': 5}
            await push.send_json(payload)
            with pytest.raises(future.CancelledError):
                await cancelled
            assert cancelled.done()
            # give it a chance to incorrectly consume the event
            ready = await pull.poll(timeout=5)
            assert ready
            await gen.sleep(0)
            # make sure cancelled recv didn't eat up event
            received = await gen.with_timeout(
                timedelta(seconds=5), pull.recv_json()
            )
            assert received == payload

        self.loop.run_sync(scenario)

    def test_recv_pyobj(self):
        """``send_pyobj`` / ``recv_pyobj`` round-trip a small dict."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
            pending = pull.recv_pyobj()
            assert not pending.done()
            payload = {'a': 5}
            await push.send_pyobj(payload)
            received = await pending
            # both the future's stored result and the awaited value must match
            assert pending.done()
            assert pending.result() == payload
            assert received == payload

        self.loop.run_sync(scenario)

    def test_custom_serialize(self):
        """``send_serialized`` / ``recv_serialized`` work with user-supplied codecs.

        The codec used here puts any identity frames first and the
        JSON-encoded ``content`` in the final frame.
        """

        def serialize(msg):
            # identity frames (if any) first, JSON body last
            return [
                *msg.get('identities', []),
                json.dumps(msg['content']).encode('utf8'),
            ]

        def deserialize(frames):
            # everything before the last frame is treated as identities
            return {
                'identities': frames[:-1],
                'content': json.loads(frames[-1].decode('utf8')),
            }

        async def scenario():
            dealer, router = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)

            request = {
                'content': {
                    'a': 5,
                    'b': 'bee',
                }
            }
            await dealer.send_serialized(request, serialize)
            at_router = await router.recv_serialized(deserialize)
            assert at_router['content'] == request['content']
            # the ROUTER side must see at least one identity frame
            assert at_router['identities']
            # bounce back, tests identities
            await router.send_serialized(at_router, serialize)
            at_dealer = await dealer.recv_serialized(deserialize)
            assert at_dealer['content'] == request['content']
            # back on the DEALER side no identity frames may remain
            assert not at_dealer['identities']

        self.loop.run_sync(scenario)

    def test_custom_serialize_error(self):
        """Errors raised by a custom (de)serializer reach the awaiting caller."""

        async def test():
            a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)

            # json.dumps cannot encode a module object, so "serializing" the
            # json module itself must fail with TypeError
            with pytest.raises(TypeError):
                await a.send_serialized(json, json.dumps)

            await a.send(b"not json")
            # json.loads is expected to raise TypeError here; presumably
            # because recv_serialized hands it the received frames as a list
            # rather than a str/bytes object
            with pytest.raises(TypeError):
                await b.recv_serialized(json.loads)

        self.loop.run_sync(test)

    def test_poll(self):
        """``Socket.poll`` futures: zero timeout, expired timeout, and POLLIN."""

        async def scenario():
            push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)

            # timeout=0 must produce an already-finished future with no events
            immediate = pull.poll(timeout=0)
            assert immediate.done()
            assert immediate.result() == 0

            # a short timeout with nothing to read resolves to 0 after waiting
            short_wait = pull.poll(timeout=1)
            assert not short_wait.done()
            events = await short_wait
            assert events == 0

            # a message sent while polling must wake the poll with POLLIN
            long_wait = pull.poll(timeout=1000)
            assert not long_wait.done()
            await push.send_multipart([b"hi", b"there"])
            events = await long_wait
            assert events == zmq.POLLIN
            frames = await pull.recv_multipart()
            assert frames == [b'hi', b'there']

        self.loop.run_sync(scenario)

    @pytest.mark.skipif(
        sys.platform.startswith('win'), reason='Windows unsupported socket type'
    )
    def test_poll_base_socket(self):
        """``future.Poller`` can await events on sockets from a plain ``zmq.Context``."""

        async def scenario():
            base_ctx = zmq.Context()
            endpoint = 'inproc://test'
            push = base_ctx.socket(zmq.PUSH)
            pull = base_ctx.socket(zmq.PULL)
            # track both sockets on the test case's socket list
            self.sockets.extend((push, pull))
            push.bind(endpoint)
            pull.connect(endpoint)

            poller = future.Poller()
            poller.register(pull, zmq.POLLIN)

            pending = poller.poll(timeout=1000)
            assert not pending.done()
            # these are non-future sockets, so send/recv are called without await
            push.send_multipart([b'hi', b'there'])
            events = await pending
            assert events == [(pull, zmq.POLLIN)]
            frames = pull.recv_multipart()
            assert frames == [b'hi', b'there']
            for sock in (push, pull):
                sock.close()
            base_ctx.term()

        self.loop.run_sync(scenario)

    def test_close_all_fds(self):
        """Closing the loop with ``all_fds=True`` leaves the socket closed."""
        pub = self.socket(zmq.PUB)

        async def touch_loop():
            # presumably associates the socket with the running loop
            pub._get_loop()

        loop = self.loop
        loop.run_sync(touch_loop)
        loop.close(all_fds=True)
        # avoid second close later
        self.loop = None
        assert pub.closed

    @pytest.mark.skipif(
        sys.platform.startswith('win'),
        reason='Windows does not support polling on files',
    )
    def test_poll_raw(self):
        """``future.Poller`` reports readiness of plain file objects (pipe ends).

        The write end of a fresh pipe must be reported with POLLOUT while the
        still-empty read end is not reported at all; after one byte has been
        written, the read end must be reported with POLLIN.
        """

        async def test():
            p = future.Poller()
            # make a pipe; the ``with`` block closes both ends even when an
            # assertion below fails, so a failing run does not leak the
            # underlying file descriptors
            read_fd, write_fd = os.pipe()
            with os.fdopen(read_fd, 'rb') as r, os.fdopen(write_fd, 'wb') as w:
                # POLLOUT: only the write end is ready on an empty pipe
                p.register(r, zmq.POLLIN)
                p.register(w, zmq.POLLOUT)
                evts = await p.poll(timeout=1)
                # results are looked up by file descriptor number
                evts = dict(evts)
                assert r.fileno() not in evts
                assert w.fileno() in evts
                assert evts[w.fileno()] == zmq.POLLOUT

                # POLLIN: stop watching the writer, then make the reader ready
                p.unregister(w)
                w.write(b'x')
                # flush so the byte actually reaches the pipe before polling
                w.flush()
                evts = await p.poll(timeout=1000)
                evts = dict(evts)
                assert r.fileno() in evts
                assert evts[r.fileno()] == zmq.POLLIN
                assert r.read(1) == b'x'

        self.loop.run_sync(test)