1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
import asyncio
import struct
from unittest.mock import patch
# When support for cpython older than 3.11 is dropped
# async_timeout can be replaced with asyncio.timeout
from async_timeout import timeout as asyncio_timeout
from firebase_messaging.fcmpushclient import MCS_MESSAGE_TAG, MCS_VERSION, FcmPushClient
from firebase_messaging.proto.checkin_pb2 import AndroidCheckinResponse
from firebase_messaging.proto.mcs_pb2 import LoginResponse
class FakeMcsEndpoint:
    """In-memory stand-in for the server end of an MCS connection.

    ``open_connection`` returns a ``(reader, writer)`` pair, the same shape
    as ``asyncio.open_connection``. Messages and errors queued with
    ``put_message``/``put_error`` are what the client reads. ``get_message``
    decodes whatever the client has written.
    """

    def __init__(self):
        # NOTE(review): this class does not patch ``asyncio.open_connection``
        # itself (an earlier version did, with
        # ``side_effect=self.open_connection``); presumably the test harness
        # wires that up — confirm.
        self.client_writer = self.FakeWriter()
        self.client_reader = self.FakeReader()

    def close(self):
        """Tear down the endpoint. Currently a no-op."""

    async def open_connection(self, *_, **__):
        """Return a fresh ``(reader, writer)`` pair; all arguments are ignored."""
        # Queues should be created on the loop that will be accessing them,
        # so new reader/writer objects are built here instead of reusing the
        # ones created in __init__.
        self.client_writer = self.FakeWriter()
        self.client_reader = self.FakeReader()
        return self.client_reader, self.client_writer

    async def put_message(self, message):
        """Queue the protobuf ``message`` for the client to read."""
        await self.client_reader.put_message(message)

    async def put_error(self, error):
        """Queue ``error``; the client's read raises it when it reaches it."""
        await self.client_reader.put_error(error)

    async def get_message(self):
        """Decode and return the next message the client wrote."""
        return await self.client_writer.get_message()

    class FakeReader:
        """Read half handed to the client: serves queued bytes/errors in order."""

        def __init__(self):
            # Each item is either a single-byte ``bytes`` or an exception.
            self.queue = asyncio.Queue()
            # Keeps the bytes of one put_message contiguous in the queue.
            self.lock = asyncio.Lock()

        async def readexactly(self, size):
            """Return exactly ``size`` queued bytes, waiting until available.

            As with ``asyncio.StreamReader.readexactly``, a negative ``size``
            raises ``ValueError`` and ``size == 0`` returns ``b""``.
            If an exception is dequeued at any position in the read, it is
            raised as-is; previously only the first position was checked.
            """
            if size < 0:
                raise ValueError("readexactly size can not be less than zero")
            chunks = []
            for _ in range(size):
                item = await self.queue.get()
                if isinstance(item, BaseException):
                    raise item
                chunks.append(item)
            return b"".join(chunks)

        async def put_message(self, message):
            """Serialize ``message`` and queue its bytes one at a time."""
            # Only a LoginResponse is packed with the version byte.
            include_version = isinstance(message, LoginResponse)
            packet = FcmPushClient._make_packet(message, include_version)
            async with self.lock:
                for byte in packet:
                    await self.queue.put(bytes([byte]))

        async def put_error(self, error):
            """Queue ``error`` to be raised by a later ``readexactly``."""
            async with self.lock:
                await self.queue.put(error)

    class FakeWriter:
        """Write half handed to the client: records written bytes."""

        def __init__(self):
            # Each item is a single-byte ``bytes`` written by the client.
            self.queue = asyncio.Queue()
            # NOTE(review): ``buf`` is not used within this class; kept in
            # case tests elsewhere reference it.
            self.buf = ""
            # Serializes concurrent get_bytes callers.
            self.lock = asyncio.Lock()

        def write(self, buffer):
            """Queue every byte of ``buffer`` individually."""
            for byte in buffer:
                self.queue.put_nowait(bytes([byte]))

        async def drain(self):
            """No-op; writes are queued immediately."""

        def close(self):
            """No-op."""

        async def wait_closed(self):
            """No-op."""

        async def get_bytes(self, size):
            """Wait for and return the next ``size`` bytes the client wrote."""
            chunks = []
            async with self.lock:
                for _ in range(size):
                    chunks.append(await self.queue.get())
            return b"".join(chunks)

        async def get_message(self, timeout=2):
            """Decode the next MCS message the client wrote.

            Consumes an optional ``MCS_VERSION`` byte, a one-byte tag, a
            varint32 payload length, then the serialized payload. The class
            to parse into is looked up from the tag in ``MCS_MESSAGE_TAG``.

            Args:
                timeout: seconds allowed for the whole message to arrive.

            Raises:
                asyncio.TimeoutError: if the message is not complete in time.
                ValueError: if the tag has no entry in ``MCS_MESSAGE_TAG``.
            """
            async with asyncio_timeout(timeout):
                (first,) = struct.unpack("B", await self.get_bytes(1))
                if first == MCS_VERSION:
                    # The version byte precedes only the first message.
                    (tag,) = struct.unpack("B", await self.get_bytes(1))
                else:
                    tag = first
                size = await self._read_varint32()
                payload = await self.get_bytes(size)
                # Avoid bare next() here: StopIteration escaping a coroutine
                # is turned into an opaque RuntimeError (PEP 479).
                msg_class = next(
                    (cls for cls, t in MCS_MESSAGE_TAG.items() if t == tag),
                    None,
                )
                if msg_class is None:
                    raise ValueError(f"Unknown MCS message tag: {tag}")
                msg = msg_class()
                msg.ParseFromString(payload)
                return msg

        async def _read_varint32(self):
            """Read a protobuf base-128 varint from the written bytes.

            Each byte contributes its low 7 bits, least-significant group
            first; a set high bit (0x80) means another byte follows.
            """
            result = 0
            shift = 0
            while True:
                (byte,) = struct.unpack("B", await self.get_bytes(1))
                result |= (byte & 0x7F) << shift
                if not byte & 0x80:
                    return result
                shift += 7
|