"""
Provide asynchronous programming utilities.
"""
from __future__ import annotations
from asyncio import TaskGroup, get_running_loop, run
from functools import wraps
from threading import Thread
from typing import (
Callable,
Awaitable,
TypeVar,
Generic,
cast,
ParamSpec,
Coroutine,
Any,
)
from betty.warnings import deprecated
P = ParamSpec("P")
T = TypeVar("T")
[docs]
async def gather(*coroutines: Coroutine[Any, None, T]) -> tuple[T, ...]:
"""
Gather multiple coroutines.
This is like Python's own ``asyncio.gather``, but with improved error handling.
"""
tasks = []
async with TaskGroup() as task_group:
for coroutine in coroutines:
tasks.append(task_group.create_task(coroutine))
return tuple(task.result() for task in tasks)
[docs]
@deprecated(
"This function is deprecated as of Betty 0.3.3, and will be removed in Betty 0.4.x. Instead, use `betty.asyncio.wait_to_thread()` or `asyncio.run()`."
)
def wait(f: Awaitable[T]) -> T:
"""
Wait for an awaitable, either in a new event loop or another thread.
"""
try:
loop = get_running_loop()
except RuntimeError:
loop = None
if loop:
return wait_to_thread(f)
else:
return run(
f, # type: ignore[arg-type]
)
[docs]
def wait_to_thread(f: Awaitable[T]) -> T:
"""
Wait for an awaitable in another thread.
"""
synced = _WaiterThread(f)
synced.start()
synced.join()
return synced.return_value
[docs]
@deprecated(
"This function is deprecated as of Betty 0.3.3, and will be removed in Betty 0.4.x. Instead, use `betty.asyncio.wait_to_thread()` or `asyncio.run()`."
)
def sync(f: Callable[P, Awaitable[T]]) -> Callable[P, T]:
"""
Decorate an asynchronous callable to become synchronous.
"""
@wraps(f)
def _synced(*args: P.args, **kwargs: P.kwargs) -> T:
return wait(f(*args, **kwargs))
return _synced
class _WaiterThread(Thread, Generic[T]):
def __init__(self, awaitable: Awaitable[T]):
super().__init__()
self._awaitable = awaitable
self._return_value: T | None = None
self._e: BaseException | None = None
@property
def return_value(self) -> T:
if self._e:
raise self._e
return cast(T, self._return_value)
def run(self) -> None:
run(self._run())
async def _run(self) -> None:
try:
self._return_value = await self._awaitable
except BaseException as e: # noqa: B036
# Store the exception, so it can be reraised when the calling thread
# gets self.return_value.
self._e = e