1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
import contextlib
import inject
from test import BaseTestInject
class Destroyable:
    """Test double that records whether it has been torn down.

    ``started`` is ``True`` from construction until ``destroy`` is called,
    after which it is ``False``.
    """

    def __init__(self) -> None:
        # A freshly constructed object counts as live.
        self.started: bool = True

    def destroy(self) -> None:
        """Mark the object as torn down by clearing ``started``."""
        self.started = False
class MockFile(Destroyable):
    """Stand-in for a file resource; adds nothing to ``Destroyable``."""
class MockConnection(Destroyable):
    """Stand-in for a connection resource; adds nothing to ``Destroyable``."""
class MockFoo(Destroyable):
    """Extra destroyable stand-in type; adds nothing to ``Destroyable``."""
@contextlib.contextmanager
def get_file_sync():
    """Provide a fresh ``MockFile`` for the duration of a ``with`` block.

    Yields:
        A newly constructed ``MockFile``.

    The object is destroyed when the block exits, including when the body of
    the ``with`` statement raises.
    """
    obj = MockFile()
    try:
        yield obj
    finally:
        # Tear down on normal exit and on exceptions raised by the caller.
        obj.destroy()
@contextlib.contextmanager
def get_conn_sync():
    """Provide a fresh ``MockConnection`` for the duration of a ``with`` block.

    Yields:
        A newly constructed ``MockConnection``.

    The object is destroyed when the block exits, including when the body of
    the ``with`` statement raises.
    """
    obj = MockConnection()
    try:
        yield obj
    finally:
        # Tear down on normal exit and on exceptions raised by the caller.
        obj.destroy()
@contextlib.contextmanager
def get_foo_sync():
    """Provide a fresh ``MockFoo`` for the duration of a ``with`` block.

    Yields:
        A newly constructed ``MockFoo``.

    The object is destroyed when the block exits, including when the body of
    the ``with`` statement raises.
    """
    obj = MockFoo()
    try:
        yield obj
    finally:
        # Tear down on normal exit and on exceptions raised by the caller.
        obj.destroy()
@contextlib.asynccontextmanager
async def get_file_async():
    """Provide a fresh ``MockFile`` for the duration of an ``async with`` block.

    Yields:
        A newly constructed ``MockFile``.

    The object is destroyed when the block exits, including when the body of
    the ``async with`` statement raises.
    """
    obj = MockFile()
    try:
        yield obj
    finally:
        # Tear down on normal exit and on exceptions raised by the caller.
        obj.destroy()
@contextlib.asynccontextmanager
async def get_conn_async():
    """Provide a fresh ``MockConnection`` for the duration of an ``async with`` block.

    Yields:
        A newly constructed ``MockConnection``.

    The object is destroyed when the block exits, including when the body of
    the ``async with`` statement raises.
    """
    obj = MockConnection()
    try:
        yield obj
    finally:
        # Tear down on normal exit and on exceptions raised by the caller.
        obj.destroy()
class TestContextManagerFunctional(BaseTestInject):
    """Functional tests for providers written as context managers.

    Each test asserts that injected resources are still ``started`` while the
    decorated function runs and are no longer ``started`` once it has returned.
    """

    def test_provider_as_context_manager_sync(self):
        """Sync context-manager providers feeding a plain function."""

        def bindings(binder):
            binder.bind_to_provider(MockFile, get_file_sync)
            binder.bind(int, 100)
            binder.bind_to_provider(str, lambda: "Hello")
            binder.bind_to_provider(MockConnection, get_conn_sync)

        inject.configure(bindings)

        @inject.autoparams()
        def consumer(conn: MockConnection, name: str, f: MockFile, number: int):
            # Inside the call the context-managed resources are still live.
            assert f.started
            assert conn.started
            # Plain bindings are injected next to the context-managed ones.
            assert (name, number) == ("Hello", 100)
            return f, conn

        file_after, conn_after = consumer()

        # After the call has returned, both resources have been destroyed.
        assert not file_after.started
        assert not conn_after.started

    def test_provider_as_context_manager_async(self):
        """Async and sync context-manager providers feeding a coroutine."""

        def bindings(binder):
            binder.bind_to_provider(MockFile, get_file_async)
            binder.bind(int, 100)
            binder.bind_to_provider(str, lambda: "Hello")
            binder.bind_to_provider(MockConnection, get_conn_async)
            binder.bind_to_provider(MockFoo, get_foo_sync)

        inject.configure(bindings)

        @inject.autoparams()
        async def consumer(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo):
            # Inside the coroutine every context-managed resource is still live.
            assert f.started
            assert conn.started
            assert foo.started
            # Plain bindings are injected next to the context-managed ones.
            assert (name, number) == ("Hello", 100)
            return f, conn, foo

        file_after, conn_after, foo_after = self.run_async(consumer())

        # After the coroutine has finished, all three resources have been destroyed.
        assert not file_after.started
        assert not conn_after.started
        assert not foo_after.started
|