File: fakes.py

package info (click to toggle)
python-firebase-messaging 0.4.4-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 652 kB
  • sloc: python: 1,454; makefile: 14
file content (127 lines) | stat: -rw-r--r-- 4,213 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import asyncio
import struct
from unittest.mock import patch

# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout

from firebase_messaging.fcmpushclient import MCS_MESSAGE_TAG, MCS_VERSION, FcmPushClient
from firebase_messaging.proto.checkin_pb2 import AndroidCheckinResponse
from firebase_messaging.proto.mcs_pb2 import LoginResponse


class FakeMcsEndpoint:
    """In-memory stand-in for the FCM MCS server endpoint used by tests.

    ``open_connection`` accepts the same call shape as
    ``asyncio.open_connection`` (all arguments are ignored) and returns a
    ``(FakeReader, FakeWriter)`` pair.  Tests feed server-to-client traffic
    with ``put_message`` / ``put_error`` and read what the client wrote with
    ``get_message``.
    """

    def __init__(self):
        # NOTE(review): a disabled variant started
        # ``patch("asyncio.open_connection", side_effect=self.open_connection,
        # autospec=True)`` here and stopped it in ``close``; presumably the
        # test fixture now routes ``asyncio.open_connection`` itself — confirm.
        self.client_writer = self.FakeWriter()
        self.client_reader = self.FakeReader()

    def close(self):
        """Release endpoint resources (currently there is nothing to release)."""

    async def open_connection(self, *_, **__):
        """Return a fresh ``(reader, writer)`` pair; all arguments are ignored."""
        # Queues should be created on the loop that will be accessing them
        self.client_writer = self.FakeWriter()
        self.client_reader = self.FakeReader()
        return self.client_reader, self.client_writer

    async def put_message(self, message):
        """Queue ``message`` for the client to read."""
        await self.client_reader.put_message(message)

    async def put_error(self, error):
        """Queue ``error`` to be raised from the client's next read."""
        await self.client_reader.put_error(error)

    async def get_message(self):
        """Return the next message the client wrote (see FakeWriter.get_message)."""
        return await self.client_writer.get_message()

    class FakeReader:
        """Fake ``asyncio.StreamReader``; each queue item is one byte or an exception."""

        def __init__(self):
            self.queue = asyncio.Queue()
            # Serializes producers so a packet's bytes are enqueued without
            # interleaving with another put.
            self.lock = asyncio.Lock()

        async def readexactly(self, size):
            """Return exactly ``size`` bytes taken from the queue.

            If an exception object is dequeued before ``size`` bytes have been
            collected, it is raised (wherever it appears in the read), which
            lets tests simulate a connection failure mid-read.
            """
            chunks = []
            for _ in range(size):
                item = await self.queue.get()
                if isinstance(item, BaseException):
                    raise item
                chunks.append(item)
            return b"".join(chunks)

        async def put_message(self, message):
            """Encode ``message`` as an MCS packet and enqueue it byte by byte.

            A ``LoginResponse`` is encoded with the version prefix (presumably
            because it is the first server-to-client message — confirm).
            """
            include_version = isinstance(message, LoginResponse)
            packet = FcmPushClient._make_packet(message, include_version)
            async with self.lock:
                for p in packet:
                    await self.queue.put(bytes([p]))

        async def put_error(self, error):
            """Enqueue ``error`` so that the read which dequeues it raises it."""
            async with self.lock:
                await self.queue.put(error)

    class FakeWriter:
        """Fake ``asyncio.StreamWriter`` that records every written byte in a queue."""

        def __init__(self):
            self.queue = asyncio.Queue()
            # Not used by this class; kept in case external code inspects it.
            self.buf = ""
            # Serializes consumers so each get_bytes call receives contiguous bytes.
            self.lock = asyncio.Lock()

        def write(self, buffer):
            """Enqueue each byte of ``buffer`` as a separate one-byte ``bytes``."""
            for i in buffer:
                self.queue.put_nowait(bytes([i]))

        async def drain(self):
            pass

        def close(self):
            pass

        async def wait_closed(self):
            pass

        async def get_bytes(self, size):
            """Wait for and return the next ``size`` bytes written by the client."""
            async with self.lock:
                chunks = []
                for _ in range(size):
                    chunks.append(await self.queue.get())
                return b"".join(chunks)

        async def get_message(self, timeout=2):
            """Decode and return the next MCS message written by the client.

            :param timeout: seconds allowed for the whole message to arrive.
            :raises ValueError: if the tag byte matches no MCS_MESSAGE_TAG entry.
            """
            async with asyncio_timeout(timeout):
                (tag,) = struct.unpack("B", await self.get_bytes(1))
                if tag == MCS_VERSION:  # first message is prefixed with the version
                    (tag,) = struct.unpack("B", await self.get_bytes(1))
                size = await self._read_varint32()
                msgstr = await self.get_bytes(size)
                # A default avoids StopIteration escaping the coroutine, which
                # Python would turn into an opaque RuntimeError.
                msg_class = next(
                    (c for c, t in MCS_MESSAGE_TAG.items() if t == tag), None
                )
                if msg_class is None:
                    raise ValueError(f"Unknown MCS message tag: {tag}")
                msg = msg_class()
                msg.ParseFromString(msgstr)
                return msg

        # protobuf variable length integers are encoded in base 128: each byte
        # carries 7 bits of the value (least-significant group first) and the
        # msb is set when more bytes follow.
        async def _read_varint32(self):
            """Read a protobuf varint32 from the written bytes.

            :raises ValueError: if the encoding runs past 5 bytes, the most a
                32-bit varint can occupy.
            """
            res = 0
            for shift in range(0, 35, 7):  # at most 5 bytes for 32 bits
                (b,) = struct.unpack("B", await self.get_bytes(1))
                res |= (b & 0x7F) << shift
                if not b & 0x80:
                    return res
            raise ValueError("varint32 is longer than 5 bytes")