1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358
|
# ruff: noqa: PYI036, SLF001, ARG001
"""Utilities for async/sync interoperability in Advanced Alchemy.
This module provides utilities for converting between async and sync functions,
managing concurrency limits, and handling context managers. Used primarily
for adapter implementations that need to support both sync and async patterns.
"""
import asyncio
import concurrent.futures
import functools
import inspect
import sys
import threading
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Coroutine
from types import TracebackType
try:
import uvloop # pyright: ignore[reportMissingImports]
except ImportError:
uvloop = None # type: ignore[assignment,unused-ignore]
class _ThreadLocalState:
    """Typed per-thread flags consulted by the sync/async bridging helpers.

    The attribute is declared explicitly (and pinned with ``__slots__``) rather
    than being set dynamically on a ``threading.local`` so that the module stays
    compatible with MyPyC.
    """

    __slots__ = ("in_thread_consistent_context",)

    def __init__(self) -> None:
        # A freshly created state always starts outside any thread-consistent
        # context; the context-manager wrapper flips this on its worker thread.
        self.in_thread_consistent_context: bool = False
# Process-wide ``threading.local`` container; each thread lazily receives its own
# ``_ThreadLocalState`` under the ``state`` attribute, used to track whether that
# thread is currently inside a thread-consistent context.
_thread_local = threading.local()


def _get_thread_state() -> _ThreadLocalState:
    """Return the calling thread's ``_ThreadLocalState``, creating it on first use.

    Returns:
        The state object owned by the current thread.
    """
    state: Optional[_ThreadLocalState] = getattr(_thread_local, "state", None)
    if state is None:
        # First access from this thread: attach a fresh state object.
        state = _ThreadLocalState()
        _thread_local.state = state
    return state
ReturnT = TypeVar("ReturnT")
ParamSpecT = ParamSpec("ParamSpecT")
T = TypeVar("T")
class NoValue:
    """Marker type whose instances mean "no value was supplied"."""


# Module-wide sentinel instance used as a default for optional arguments.
NO_VALUE = NoValue()
class CapacityLimiter:
    """Bound how many operations may run at once, backed by an ``asyncio.Semaphore``.

    Use it either through :meth:`acquire` / :meth:`release` or as an
    ``async with`` context manager.
    """

    def __init__(self, total_tokens: int) -> None:
        """Create a limiter admitting up to ``total_tokens`` concurrent holders.

        Args:
            total_tokens: Maximum number of concurrent operations allowed
        """
        # The backing semaphore is built on first use (see ``_semaphore``), not here.
        self._semaphore_instance: Optional[asyncio.Semaphore] = None
        self._total_tokens = total_tokens

    @property
    def _semaphore(self) -> asyncio.Semaphore:
        """Return the backing semaphore, building it on first access.

        Construction is deferred to first use for Python 3.9 compatibility.
        """
        semaphore = self._semaphore_instance
        if semaphore is None:
            semaphore = asyncio.Semaphore(self._total_tokens)
            self._semaphore_instance = semaphore
        return semaphore

    async def acquire(self) -> None:
        """Wait for a free token and take it."""
        semaphore = self._semaphore
        await semaphore.acquire()

    def release(self) -> None:
        """Give a previously acquired token back."""
        semaphore = self._semaphore
        semaphore.release()

    @property
    def total_tokens(self) -> int:
        """Number of tokens available right now.

        Before the semaphore has been created this is the configured total;
        afterwards it is the semaphore's current free-token count.
        """
        semaphore = self._semaphore_instance
        return self._total_tokens if semaphore is None else semaphore._value

    @total_tokens.setter
    def total_tokens(self, value: int) -> None:
        # Drop the current semaphore so the next use rebuilds it with ``value`` tokens.
        # NOTE(review): tokens held at the time of the change are later released into
        # the *new* semaphore, and tasks already waiting on the old one are never
        # woken, so resizing looks safe only while the limiter is idle — confirm.
        self._semaphore_instance = None
        self._total_tokens = value

    async def __aenter__(self) -> None:
        """Take a token when entering ``async with``."""
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: "Optional[type[BaseException]]",
        exc_val: "Optional[BaseException]",
        exc_tb: "Optional[TracebackType]",
    ) -> None:
        """Return the token when leaving ``async with``; exceptions propagate."""
        self.release()


# Shared limiter used by ``async_`` when the caller does not pass one.
_default_limiter = CapacityLimiter(15)
def run_(async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]") -> "Callable[ParamSpecT, ReturnT]":
    """Convert an async function to a blocking function using asyncio.run().

    Each call of the returned wrapper behaves as follows:

    * If no event loop is running in the calling thread, the coroutine is run
      with ``asyncio.run()`` (after ``uvloop.install()`` when uvloop is
      importable and the platform is not Windows).
    * If an event loop is already running in the calling thread,
      ``asyncio.run()`` cannot be used there, so the coroutine is run in its own
      fresh event loop on a single worker thread.  The caller, and therefore the
      running loop, blocks until it completes.

    Args:
        async_function: The async function to convert.

    Returns:
        A blocking function that runs the async function.
    """

    @functools.wraps(async_function)
    def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        partial_f = functools.partial(async_function, *args, **kwargs)
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread, so it is safe to start one here.
            if uvloop and sys.platform != "win32":
                uvloop.install()  # pyright: ignore[reportUnknownMemberType]
            return asyncio.run(partial_f())
        # get_running_loop() only succeeds while a loop is running in this thread,
        # where asyncio.run() would raise.  Run the coroutine in a separate loop on
        # one dedicated worker thread instead; a single job needs only one worker.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, partial_f()).result()

    return wrapper
def await_(
    async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]", raise_sync_error: bool = True
) -> "Callable[ParamSpecT, ReturnT]":
    """Convert an async function to a blocking one, running in the main async loop.

    The returned wrapper inspects the calling thread on every call:

    * No running event loop: raise ``RuntimeError`` if ``raise_sync_error`` is
      True, otherwise run the coroutine with ``asyncio.run()``.
    * Running loop with a current task: schedule the coroutine as a task on that
      loop and manually step the loop (private ``_run_once``) until it is done.
    * Running loop without a current task: submit the coroutine with
      ``asyncio.run_coroutine_threadsafe()`` and block on its result.

    Args:
        async_function: The async function to convert.
        raise_sync_error: If False, runs in a new event loop if no loop is present.
            If True (default), raises RuntimeError if no loop is running.

    Returns:
        A blocking function that runs the async function.

    Raises:
        RuntimeError: Raised by the wrapper when no event loop is running and
            ``raise_sync_error`` is True.
    """

    @functools.wraps(async_function)
    def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        partial_f = functools.partial(async_function, *args, **kwargs)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread.
            if raise_sync_error:
                msg = "await_ called without a running event loop and raise_sync_error=True"
                raise RuntimeError(msg) from None
            return asyncio.run(partial_f())
        else:
            # NOTE(review): get_running_loop() only returns a loop that is running in
            # *this* thread, so this check appears to always be True and the
            # "non-running loop" fallback at the end looks unreachable — confirm.
            if loop.is_running():
                try:
                    current_task = asyncio.current_task(loop=loop)
                except RuntimeError:
                    current_task = None
                if current_task is not None:
                    # Workaround for sync-over-async calls made from within a running
                    # loop: schedule the coroutine as a task, then manually step the
                    # loop via the private `_run_once` until the task is resolved.
                    # NOTE(review): on stock asyncio, stepping a new task while
                    # `current_task` is still executing may be refused ("Cannot enter
                    # into task ... while another task ... is being executed"), leaving
                    # `task` pending; this path may rely on a reentrant/patched loop.
                    # Also, if the loop stops first, task.result() raises — verify.
                    task = asyncio.ensure_future(partial_f(), loop=loop)
                    while not task.done() and loop.is_running():
                        loop._run_once()  # type: ignore[attr-defined]
                    return task.result()
                # NOTE(review): we are on the loop's own thread here (get_running_loop()
                # succeeded), so blocking on future.result() waits for a coroutine that
                # this same, now-blocked loop must run — this looks like a deadlock; verify.
                future = asyncio.run_coroutine_threadsafe(partial_f(), loop)
                return future.result()
            if raise_sync_error:
                msg = "await_ found a non-running loop via get_running_loop()"
                raise RuntimeError(msg)
            return asyncio.run(partial_f())

    return wrapper
def async_(
    function: "Callable[ParamSpecT, ReturnT]", *, limiter: "Optional[CapacityLimiter]" = None
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
    """Wrap a blocking callable so it can be awaited without blocking the event loop.

    Every call is dispatched to a worker thread through ``asyncio.to_thread()``
    while a token is held from ``limiter``, or from the module-wide default
    limiter when ``limiter`` is not given.

    Args:
        function: The blocking function to convert.
        limiter: Limit the total number of threads.

    Returns:
        An async function that runs the original function in a thread.
    """

    @functools.wraps(function)
    async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        bound_call = functools.partial(function, *args, **kwargs)
        active_limiter = limiter or _default_limiter
        async with active_limiter:
            # Hold the token for exactly as long as the thread is busy.
            return await asyncio.to_thread(bound_call)

    return wrapper
def ensure_async_(
    function: "Callable[ParamSpecT, Union[Awaitable[ReturnT], ReturnT]]",
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
    """Convert a function to an async one if it is not already.

    Coroutine functions are returned unchanged.  Any other callable is wrapped
    in an ``async def`` that calls it directly on the event loop's thread (it is
    *not* offloaded to a worker thread).  If that call returns an awaitable, the
    awaitable is awaited and its result returned.

    Args:
        function: The function to convert.

    Returns:
        An async function that runs the original function.
    """
    if inspect.iscoroutinefunction(function):
        return function

    @functools.wraps(function)
    async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        result = function(*args, **kwargs)
        # A plain callable may still hand back an awaitable (e.g. a sync function
        # returning a coroutine); resolve it before returning.
        if inspect.isawaitable(result):
            return await result
        # The call has already completed synchronously, so there is nothing left
        # to offload to a thread; just return the value.
        return result

    return wrapper
class _ContextManagerWrapper(Generic[T]):
    """Expose a synchronous context manager through the async context manager protocol.

    ``__enter__`` and ``__exit__`` are both run through one dedicated
    single-worker thread pool, so they execute on the same thread.  While the
    context is entered, that worker thread's ``in_thread_consistent_context``
    flag is set.
    """

    def __init__(self, cm: AbstractContextManager[T]) -> None:
        self._cm = cm
        # Created by __aenter__; shut down by __aexit__ or by a failed __aenter__.
        self._executor: Optional[concurrent.futures.ThreadPoolExecutor] = None

    async def __aenter__(self) -> T:
        """Run the wrapped ``__enter__`` on a dedicated worker thread and return its value."""
        loop = asyncio.get_running_loop()
        # A single worker guarantees that __enter__ and __exit__ share one thread.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self._executor = executor

        def _enter_with_flag() -> T:
            # Mark the worker thread as being inside a thread-consistent context.
            state = _get_thread_state()
            state.in_thread_consistent_context = True
            return self._cm.__enter__()

        try:
            return await loop.run_in_executor(executor, _enter_with_flag)
        except BaseException:
            # `async with` never calls __aexit__ when __aenter__ fails, so the
            # worker thread would otherwise be leaked; release it here.
            executor.shutdown(wait=False)
            self._executor = None
            raise

    async def __aexit__(
        self,
        exc_type: "Optional[type[BaseException]]",
        exc_val: "Optional[BaseException]",
        exc_tb: "Optional[TracebackType]",
    ) -> "Optional[bool]":
        """Run the wrapped ``__exit__`` on the same worker thread used by ``__aenter__``.

        Returns:
            Whatever the wrapped ``__exit__`` returns (a truthy value suppresses
            the in-flight exception).
        """
        executor = self._executor
        if executor is None:
            # No executor from __aenter__ (e.g. __aexit__ called without a
            # successful __aenter__): fall back to an arbitrary worker thread.
            return await asyncio.to_thread(self._cm.__exit__, exc_type, exc_val, exc_tb)
        loop = asyncio.get_running_loop()

        def _exit_with_flag_clear() -> "Optional[bool]":
            try:
                return self._cm.__exit__(exc_type, exc_val, exc_tb)
            finally:
                # Leave the thread-consistent context even if __exit__ raises.
                state = _get_thread_state()
                state.in_thread_consistent_context = False

        try:
            return await loop.run_in_executor(executor, _exit_with_flag_clear)
        finally:
            # Let the worker thread wind down without blocking the event loop.
            executor.shutdown(wait=False)
            self._executor = None
def with_ensure_async_(
    obj: "Union[AbstractContextManager[T], AbstractAsyncContextManager[T]]",
) -> "AbstractAsyncContextManager[T]":
    """Return an async context manager for ``obj``, adapting it only when needed.

    Anything recognised as a synchronous context manager, meaning
    ``isinstance(obj, AbstractContextManager)`` holds, is wrapped so that
    ``async with`` runs its enter/exit off the event loop.  This check comes
    first, so an object offering both protocols takes the synchronous path.
    Everything else is returned unchanged.

    Args:
        obj: The context manager to convert.

    Returns:
        An async context manager that runs the original context manager.
    """
    if not isinstance(obj, AbstractContextManager):
        # Assumed to already support ``async with``.
        return obj
    adapted = _ContextManagerWrapper(obj)
    return cast("AbstractAsyncContextManager[T]", adapted)
async def get_next(iterable: Any, default: Any = NO_VALUE, *args: Any) -> Any:  # pragma: no cover
    """Return the next item from an async iterator.

    Similar to the builtin ``anext(iterator, default)``.

    Args:
        iterable: An async iterator, i.e. an object with an ``__anext__`` method
            (async generators qualify).
        default: Value returned instead of raising once the iterator is
            exhausted.  Leave it unset (a ``NoValue`` instance) to get
            ``StopAsyncIteration`` instead.
        *args: Extra positional arguments; accepted and ignored.

    Returns:
        The next value of the iterable, or ``default`` if the iterator is
        exhausted and a default was supplied.

    Raises:
        StopAsyncIteration: The iterator is exhausted and no default was supplied.
        AttributeError: ``iterable`` has no ``__anext__`` method, i.e. it is not
            an async iterator.
    """
    try:
        return await iterable.__anext__()
    except StopAsyncIteration as exc:
        # Any NoValue instance (not only NO_VALUE) means "no default supplied".
        if not isinstance(default, NoValue):
            return default
        raise StopAsyncIteration from exc
|