1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
|
import asyncio
from async_property.cached import AsyncCachedPropertyDescriptor
# Attribute name under which the tuple of (field_name, get_loader) pairs
# collected for each loader class is stored.
AWAIT_LOADER_ATTR: str = '_async_property_loaders'

# Short alias used to test whether an object's ``load`` attribute is an
# ``async def`` function.
# NOTE(review): asyncio.iscoroutinefunction is deprecated in newer Python
# releases (3.14+) in favour of inspect.iscoroutinefunction — consider
# migrating once legacy coroutine markers no longer need to be recognised.
is_coroutine = asyncio.iscoroutinefunction
def get_loaders(instance):
    """Return the ``(field_name, get_loader)`` pairs recorded on *instance*.

    *instance* may be a class or an object. The pairs are read from the
    attribute named by ``AWAIT_LOADER_ATTR``. When that attribute does not
    exist, an empty tuple is returned, so callers can iterate the result
    unconditionally.
    """
    try:
        return getattr(instance, AWAIT_LOADER_ATTR)
    except AttributeError:
        # Not an AwaitLoader class/instance: nothing to load.
        return ()
class AwaitLoaderMeta(type):
    """Metaclass that records the async cached property loaders of a class.

    While a class is being created, each ``AsyncCachedPropertyDescriptor``
    found in its namespace contributes a ``(field_name, get_loader)`` pair.
    Pairs already recorded on the bases are then added for any field name
    that has not been claimed yet. The collected pairs are stored as a
    tuple in the class namespace under ``AWAIT_LOADER_ATTR``.
    """

    def __new__(mcs, name, bases, attrs) -> type:
        # Loaders declared directly on the class being built are collected
        # first, so they take precedence over inherited ones of the same name.
        collected = {
            attr_name: attr_value.get_loader
            for attr_name, attr_value in attrs.items()
            if isinstance(attr_value, AsyncCachedPropertyDescriptor)
        }
        # Bases are scanned last-to-first and the first loader seen for a
        # field is kept, so a later base wins over an earlier one.
        # NOTE(review): for ``class C(A, B)`` this favours B, whereas normal
        # attribute lookup (MRO) favours A — confirm this is intended.
        # NOTE(review): only descriptors are collected above, so a plain
        # attribute shadowing an inherited async cached property does not
        # stop the base's loader from being inherited.
        for parent in bases[::-1]:
            for field_name, loader_factory in get_loaders(parent):
                collected.setdefault(field_name, loader_factory)
        attrs[AWAIT_LOADER_ATTR] = tuple(collected.items())
        return super().__new__(mcs, name, bases, attrs)
class AwaitLoader(metaclass=AwaitLoaderMeta):
    """Base class whose instances can be awaited to pre-load their data.

    ``await instance`` does three things in order:

    1. Runs the instance's own ``load`` method, if it is a coroutine function.
    2. Runs every async cached property loader recorded on the class
       concurrently.
    3. Evaluates to the instance itself.
    """

    def __await__(self):
        return self._load().__await__()

    async def _load(self):
        """
        Calls overridable async load method
        and then calls async property loaders.

        Returns:
            ``self``, so that ``obj = await SomeLoader(...)`` works.

        Raises:
            Any exception raised by ``load`` or by a property loader. The
            first loader failure is propagated to the awaiting caller
            instead of being silently discarded.
        """
        # Single lookup instead of hasattr() followed by two more accesses.
        load = getattr(self, 'load', None)
        if load is not None and is_coroutine(load):
            await load()
        loaders = get_loaders(self)
        if loaders:
            # asyncio.wait() never raises the tasks' exceptions, so a failing
            # loader used to be dropped silently and left the property
            # unloaded. gather() re-raises the first failure to the caller.
            # It also cancels the outstanding loaders if this coroutine is
            # cancelled.
            await asyncio.gather(*(
                get_loader(self)()
                for _, get_loader in loaders
            ))
        return self
|