1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
|
# Copyright Amethyst Reese
# Licensed under the MIT license
from collections.abc import Coroutine, Generator
from contextlib import AbstractAsyncContextManager
from functools import wraps
from typing import Any, Callable, TypeVar
from .cursor import Cursor
# Type of the value produced by the wrapped coroutine.
_T = TypeVar("_T")


class Result(AbstractAsyncContextManager[_T], Coroutine[Any, Any, _T]):
    """Wrap a coroutine so it can be awaited or used with ``async with``.

    Awaiting a ``Result`` (or driving it through ``send``/``throw``/``close``)
    simply delegates to the wrapped coroutine.  Entering it with
    ``async with`` awaits the coroutine and yields its value; on exit, the
    value is closed if (and only if) it is a ``Cursor``.
    """

    __slots__ = ("_coro", "_obj")

    def __init__(self, coro: Coroutine[Any, Any, _T]):
        """Store *coro*; it is not started until awaited or entered."""
        self._coro = coro
        # Annotation only: no value is stored here.  ``_obj`` is assigned by
        # ``__aenter__`` once the wrapped coroutine has produced its result.
        self._obj: _T

    def send(self, value: Any) -> Any:
        """Forward *value* to the wrapped coroutine and return what it yields.

        Exceptions raised by the wrapped coroutine's ``send`` (including
        ``StopIteration`` carrying the final result) propagate unchanged.
        """
        return self._coro.send(value)

    def throw(self, typ: Any, val: Any = None, tb: Any = None) -> Any:
        """Raise an exception inside the wrapped coroutine.

        Only the arguments that were actually supplied are forwarded, so a
        single exception instance is passed through as a one-argument call.
        NOTE: the ``(typ, val, tb)`` form is forwarded as given; Python 3.12+
        deprecates that signature on coroutines.
        """
        if val is None:
            return self._coro.throw(typ)
        if tb is None:
            return self._coro.throw(typ, val)
        return self._coro.throw(typ, val, tb)

    def close(self) -> None:
        """Close the wrapped coroutine."""
        return self._coro.close()

    def __await__(self) -> Generator[Any, None, _T]:
        """Return the wrapped coroutine's await iterator."""
        return self._coro.__await__()

    async def __aenter__(self) -> _T:
        """Await the wrapped coroutine, remember its result, and return it."""
        self._obj = await self._coro
        return self._obj

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the entered object if it is a ``Cursor``.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.
        """
        if isinstance(self._obj, Cursor):
            await self._obj.close()
def contextmanager(
    method: Callable[..., Coroutine[Any, Any, _T]]
) -> Callable[..., Result[_T]]:
    """Decorate a coroutine method so each call returns a ``Result``.

    The decorated method is invoked with the same ``self`` and arguments, and
    the coroutine it produces is wrapped in ``Result`` instead of being
    returned directly.  Metadata of *method* is copied onto the replacement
    via ``functools.wraps``.
    """

    def _as_result(self, *args, **kwargs) -> Result[_T]:
        # Calling ``method`` only creates the coroutine; nothing runs yet.
        pending = method(self, *args, **kwargs)
        return Result(pending)

    return wraps(method)(_as_result)
|