1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
try:
from contextlib import asynccontextmanager
except ImportError:
asynccontextmanager = lambda func: func
from unittest.mock import patch
from adb_shell import constants
from adb_shell.adb_message import AdbMessage, unpack
from adb_shell.transport.tcp_transport_async import TcpTransportAsync
try:
    from unittest.mock import AsyncMock
except ImportError:
    # `unittest.mock` has no `AsyncMock` here, so build one on top of `MagicMock`
    from unittest.mock import MagicMock

    class AsyncMock(MagicMock):
        """A `MagicMock` whose call is a coroutine, so the result must be awaited."""

        async def __call__(self, *positional, **named):
            """Forward the call to `MagicMock.__call__` and return its result."""
            parent_call = super().__call__
            return parent_call(*positional, **named)
def async_mock_open(read_data=""):
    """Build an async stand-in for ``open`` that is backed by in-memory data.

    Parameters
    ----------
    read_data : str, bytes
        The content that the fake file hands out through ``read``

    Returns
    -------
    _async_mock_open : callable
        A factory (decorated with ``asynccontextmanager``) that ignores its
        arguments and yields a fresh fake file object.  Data written through a
        fake file is stored on the factory's ``written`` attribute, which is
        reset to an empty value every time a new fake file is created.

    """
    class AsyncMockFile:
        """In-memory fake of an asynchronous file object."""

        def __init__(self, read_data):
            self.read_data = read_data
            # An empty slice keeps the type (``str`` vs. ``bytes``) of ``read_data``
            _async_mock_open.written = read_data[:0]

        async def read(self, size=-1):
            """Return up to ``size`` items of the remaining data and consume them.

            As with real file objects, ``None`` or any negative ``size`` means
            "read everything that is left".
            """
            if size is None or size < 0:
                size = len(self.read_data)

            # Slicing past the end is safe, so no explicit clamping is needed
            chunk = self.read_data[:size]
            self.read_data = self.read_data[size:]
            return chunk

        async def write(self, b):
            """Append ``b`` to the factory's ``written`` attribute."""
            if _async_mock_open.written:
                _async_mock_open.written += b
            else:
                # Nothing written yet: replace rather than concatenate, so the
                # type of ``b`` wins over the type of the empty placeholder
                _async_mock_open.written = b

        def fileno(self):
            """Return a fixed, fake file descriptor number."""
            return 123

    @asynccontextmanager
    async def _async_mock_open(*args, **kwargs):
        # The arguments are accepted only for signature compatibility
        yield AsyncMockFile(read_data)

    return _async_mock_open
class FakeStreamWriter:
    """Stream writer stub: every method accepts its input and does nothing."""

    def close(self):
        """Do nothing."""

    async def wait_closed(self):
        """Do nothing (awaitable)."""

    def write(self, data):
        """Discard ``data``."""

    async def drain(self):
        """Do nothing (awaitable)."""
class FakeStreamReader:
    """Stream reader stub whose reads always produce the same canned bytes."""

    async def read(self, numbytes):
        """Return ``b'TEST'`` regardless of ``numbytes``."""
        canned_reply = b'TEST'
        return canned_reply
class FakeTcpTransportAsync(TcpTransportAsync):
    """A `TcpTransportAsync` whose I/O is replaced by in-memory byte buffers.

    Reads are served from ``bulk_read_data`` and writes are appended to
    ``bulk_write_data``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bulk_write_data = b''
        self.bulk_read_data = b''

    async def close(self):
        """Mark the transport as closed by clearing the reader and writer."""
        self._reader = self._writer = None

    async def connect(self, transport_timeout_s=None):
        """Mark the transport as connected; ``transport_timeout_s`` is ignored."""
        self._reader = self._writer = True

    async def bulk_read(self, numbytes, transport_timeout_s=None):
        """Pop and return up to ``numbytes`` bytes from ``bulk_read_data``.

        A single read never returns more than ``constants.MAX_ADB_DATA`` bytes.
        """
        if numbytes > constants.MAX_ADB_DATA:
            limit = constants.MAX_ADB_DATA
        else:
            limit = numbytes

        pending = self.bulk_read_data
        chunk, self.bulk_read_data = pending[:limit], pending[limit:]
        return chunk

    async def bulk_write(self, data, transport_timeout_s=None):
        """Append ``data`` to ``bulk_write_data`` and return its length."""
        self.bulk_write_data += data
        return len(data)
# `TcpTransportAsync` patch: while active, the `TcpTransportAsync` name in
# `adb_shell.adb_device_async` is replaced with `FakeTcpTransportAsync`
PATCH_TCP_TRANSPORT_ASYNC = patch('adb_shell.adb_device_async.TcpTransportAsync', FakeTcpTransportAsync)
def async_patch(*args, **kwargs):
    """Call `patch` with the given arguments, using `AsyncMock` as ``new_callable``.

    Returns whatever `patch` returns.
    """
    patcher = patch(*args, new_callable=AsyncMock, **kwargs)
    return patcher
|