1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
|
"""Patches for async socket functionality."""
from contextlib import asynccontextmanager
from unittest.mock import patch
try:
    # Prefer the standard-library implementation when unittest.mock has one.
    from unittest.mock import AsyncMock
except ImportError:
    from unittest.mock import MagicMock

    class AsyncMock(MagicMock):
        """Fallback stand-in used when unittest.mock has no AsyncMock.

        Behaves like a MagicMock whose call is a coroutine: awaiting the
        call yields whatever the underlying MagicMock call returns.
        """

        async def __call__(self, *args, **kwargs):
            """Record the call on the MagicMock and return its result."""
            result = super().__call__(*args, **kwargs)
            return result
def async_mock_open(read_data=""):
    """Build an async replacement for an ``open``-style context manager.

    The returned callable accepts (and ignores) any arguments and is used
    as ``async with opener(...) as f``.  The yielded fake file supports
    awaitable ``read`` and ``write``.

    Args:
        read_data: ``str`` or ``bytes`` served by the fake file's ``read``.

    Returns:
        An async context manager factory.  Its ``written`` attribute holds
        everything written through the most recently opened fake file; it
        starts out as an empty value of the same type as ``read_data`` and
        is reset to that every time the mock is opened.
    """
    # Empty str/bytes matching the type of ``read_data``.
    empty = read_data[:0]

    class AsyncMockFile:
        """In-memory fake file with awaitable ``read`` / ``write``."""

        def __init__(self, read_data):
            self.read_data = read_data
            # Each open starts with a clean record of written data.
            _async_mock_open.written = empty

        async def read(self, size=-1):
            """Consume and return up to ``size`` items of the remaining data.

            ``None`` or any negative ``size`` reads everything that is
            left, mirroring the behaviour of real file objects.
            """
            if size is None or size < 0:
                size = len(self.read_data)
            # Slicing past the end is safe, so no explicit min() is needed.
            ret = self.read_data[:size]
            self.read_data = self.read_data[size:]
            return ret

        async def write(self, b):
            """Append ``b`` to the ``written`` record on the opener."""
            if _async_mock_open.written:
                _async_mock_open.written += b
            else:
                # Nothing recorded yet: assign instead of concatenating so
                # the first write may differ in type from ``read_data``
                # (e.g. bytes written to a mock created with the "" default).
                _async_mock_open.written = b

    @asynccontextmanager
    async def _async_mock_open(*args, **kwargs):
        yield AsyncMockFile(read_data)

    # Make ``written`` readable even if the mock is never opened.
    _async_mock_open.written = empty
    return _async_mock_open
class FakeStreamWriter:
    """Stream-writer stub whose methods all do nothing and return None."""

    def close(self):
        """Do nothing."""

    async def wait_closed(self):
        """Do nothing when awaited."""

    def write(self, data):
        """Discard ``data``."""

    async def drain(self):
        """Do nothing when awaited."""
class FakeStreamReader:
    """Stream-reader stub that always yields the same canned payload."""

    async def read(self, numbytes):
        """Return the fixed payload; ``numbytes`` is ignored."""
        payload = b'TEST'
        return payload
def async_patch(*args, **kwargs):
    """Return ``patch(...)`` configured to create an ``AsyncMock``.

    All positional and keyword arguments are forwarded to ``patch``
    unchanged, with ``new_callable=AsyncMock`` added.
    """
    patcher = patch(*args, new_callable=AsyncMock, **kwargs)
    return patcher
|