1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045
|
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import abc
import asyncio
import collections
import inspect
import itertools
import logging
import typing
from contextlib import asynccontextmanager
from typing import (
Any,
AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
Collection,
Coroutine,
Generator,
Generic,
Hashable,
Iterable,
Literal,
TypeVar,
overload,
)
import attr
from typing_extensions import Concatenate, ParamSpec, Unpack
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_coroutine_in_background,
run_in_background,
)
from synapse.util.clock import Clock
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic result type, used by the observable-deferred classes and by
# `timeout_deferred` further down this module.
_T = TypeVar("_T")
class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
    """Abstract base class defining the consumer interface of ObservableDeferred"""

    # This base class declares no instance attributes of its own.
    __slots__ = ()

    @abc.abstractmethod
    def observe(self) -> "defer.Deferred[_T]":
        """Add a new observer for this ObservableDeferred

        This returns a brand new deferred that is resolved when the underlying
        deferred is resolved. Interacting with the returned deferred does not
        affect the underlying deferred.

        Note that the returned Deferred doesn't follow the Synapse logcontext rules -
        you will probably want to `make_deferred_yieldable` it.

        Returns:
            A new `Deferred` for the caller to wait on.
        """
        ...
class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
    """A wrapper around a deferred that can hand out observer deferreds.

    Observer deferreds are resolved with the wrapped deferred's result but take
    no part in the wrapped deferred's own callback chain.

    When `consumeErrors` is true, a failure of the wrapped deferred is swallowed
    once it has been relayed to the observers; otherwise it carries on down the
    wrapped deferred's callback chain.

    Cancelling or otherwise resolving an observer has no effect on the original
    ObservableDeferred.

    NB: nothing is done about logcontexts here. In general you should probably
    `make_deferred_yieldable` whatever `observe` returns, and make sure that the
    wrapped deferred runs its callbacks in the sentinel logcontext.
    """

    __slots__ = ["_deferred", "_observers", "_result"]

    _deferred: "defer.Deferred[_T]"
    _observers: list["defer.Deferred[_T]"] | tuple[()]
    _result: None | tuple[Literal[True], _T] | tuple[Literal[False], Failure]

    def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
        # Attribute assignment on this class is forwarded to the wrapped deferred
        # (see `__setattr__`), so our own slots are filled in via `object`.
        set_slot = object.__setattr__
        set_slot(self, "_deferred", deferred)
        set_slot(self, "_result", None)
        set_slot(self, "_observers", [])

        def take_observers() -> "Collection[defer.Deferred[_T]]":
            # `_result` has been set by the time this runs, so `observe` will
            # never append again and the list can be swapped for an empty tuple.
            waiting = self._observers
            set_slot(self, "_observers", ())
            return waiting

        def callback(r: _T) -> _T:
            set_slot(self, "_result", (True, r))
            for observer in take_observers():
                try:
                    observer.callback(r)
                except Exception as e:
                    logger.exception(
                        "%r threw an exception on .callback(%r), ignoring...",
                        observer,
                        r,
                        exc_info=e,
                    )
            return r

        def errback(f: Failure) -> Failure | None:
            set_slot(self, "_result", (False, f))
            for observer in take_observers():
                # A bit of magic so that stack traces propagate correctly when
                # one of the observer deferreds is `await`ed.
                f.value.__failure__ = f
                try:
                    observer.errback(f)
                except Exception as e:
                    logger.exception(
                        "%r threw an exception on .errback(%r), ignoring...",
                        observer,
                        f,
                        exc_info=e,
                    )
            # Returning None stops the failure travelling further down the
            # wrapped deferred's chain.
            return None if consumeErrors else f

        deferred.addCallbacks(callback, errback)

    def observe(self) -> "defer.Deferred[_T]":
        """Return a fresh deferred that mirrors the wrapped deferred's outcome.

        If the wrapped deferred has already resolved, the returned deferred is
        already resolved with the same result or failure. Interacting with the
        returned deferred does not affect the wrapped deferred.
        """
        outcome = self._result
        if outcome is None:
            assert isinstance(self._observers, list)
            d: "defer.Deferred[_T]" = defer.Deferred()
            self._observers.append(d)
            return d

        if outcome[0]:
            return defer.succeed(outcome[1])
        return defer.fail(outcome[1])

    def observers(self) -> "Collection[defer.Deferred[_T]]":
        """Return the observer deferreds that are still waiting for a result."""
        return self._observers

    def has_called(self) -> bool:
        """Whether the wrapped deferred has produced a result or a failure."""
        return self._result is not None

    def has_succeeded(self) -> bool:
        """Whether the wrapped deferred has resolved successfully."""
        outcome = self._result
        return outcome is not None and outcome[0] is True

    def get_result(self) -> _T | Failure:
        """Return the recorded result (or `Failure`) of the wrapped deferred.

        Raises:
            ValueError: if the wrapped deferred has not resolved yet.
        """
        outcome = self._result
        if outcome is None:
            raise ValueError(f"{self!r} has no result yet")
        return outcome[1]

    def __getattr__(self, name: str) -> Any:
        # Anything we do not define ourselves is looked up on the wrapped deferred.
        return getattr(self._deferred, name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute writes are forwarded to the wrapped deferred.
        setattr(self._deferred, name, value)

    def __repr__(self) -> str:
        return (
            f"<ObservableDeferred object at {id(self)}, "
            f"result={self._result!r}, _deferred={self._deferred!r}>"
        )
# Generic item/result type shared by several of the helpers in this module.
T = TypeVar("T")
async def concurrently_execute(
    func: Callable[[T], Any],
    args: Iterable[T],
    limit: int,
    delay_cancellation: bool = False,
) -> None:
    """Executes the function with each argument concurrently while limiting
    the number of concurrent executions.

    At most `limit` workers are started. Each worker calls `func` with one
    argument, waits for it to finish, then pulls the next argument from the
    shared iterator until it is exhausted.

    Args:
        func: Function to execute, should return a deferred or coroutine.
        args: List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit: Maximum number of concurrent executions.
        delay_cancellation: Whether to delay cancellation until after the invocations
            have finished.

    Returns:
        None, when all function invocations have finished. The return values
        from those functions are discarded.

    Raises:
        Whatever the first failing invocation raised, as unwrapped by
        `yieldable_gather_results` / `yieldable_gather_results_delaying_cancellation`.
    """
    it = iter(args)

    async def _concurrently_execute_inner(value: T) -> None:
        while True:
            await maybe_awaitable(func(value))
            # Only the `next()` call is guarded, so that a `StopIteration` leaking
            # out of `func` is not mistaken for "no more work" and silently
            # swallowed.
            try:
                value = next(it)
            except StopIteration:
                return

    # We use `itertools.islice` to handle the case where the number of args is
    # less than the limit, avoiding needlessly spawning unnecessary background
    # tasks. The slice is lazy and shares `it` with the workers, so no argument
    # is handed out twice.
    initial_values = itertools.islice(it, limit)

    if delay_cancellation:
        await yieldable_gather_results_delaying_cancellation(
            _concurrently_execute_inner,
            initial_values,
        )
    else:
        await yieldable_gather_results(
            _concurrently_execute_inner,
            initial_values,
        )
# `P` captures the extra positional/keyword arguments that the gather helpers
# forward to the wrapped function; `R` is a generic return/result type.
P = ParamSpec("P")
R = TypeVar("R")
async def yieldable_gather_results(
    func: Callable[Concatenate[T, P], Awaitable[R]],
    iter: Iterable[T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> list[R]:
    """Call `func` once per item of `iter`, running all of the calls concurrently.

    Args:
        func: Function to execute that returns a Deferred
        iter: An iterable whose items are each passed to `func` as its first
            argument
        *args: Extra positional arguments passed to every call to func
        **kwargs: Extra keyword arguments passed to every call to func

    Returns:
        A list containing the results of the function

    Raises:
        The exception from the first call to fail, unwrapped from the
        `FirstError` that `defer.gatherResults` reports.
    """
    try:
        in_flight = [run_in_background(func, item, *args, **kwargs) for item in iter]
        gathered = defer.gatherResults(in_flight, consumeErrors=True)
        return await make_deferred_yieldable(gathered)
    except defer.FirstError as first_error:
        # Unwrap the error from defer.gatherResults.
        #
        # The traceback on the re-raised exception only includes func() etc if
        # the 'await' happened before the exception was thrown - ie if the failure
        # happened *asynchronously* - otherwise Twisted throws the traceback away,
        # since it could be large.
        #
        # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
        # we could throw Twisted into the fires of Mordor.
        #
        # Exception chaining is suppressed, because the FirstError doesn't tell us
        # anything very interesting.
        original = first_error.subFailure.value
        assert isinstance(original, BaseException)
        raise original from None
async def yieldable_gather_results_delaying_cancellation(
    func: Callable[Concatenate[T, P], Awaitable[R]],
    iter: Iterable[T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> list[R]:
    """Call `func` once per item of `iter`, running all of the calls concurrently.

    Behaves like `yieldable_gather_results`, except that cancellation is held
    back until all of the results have been gathered.

    Args:
        func: Function to execute that returns a Deferred
        iter: An iterable whose items are each passed to `func` as its first
            argument
        *args: Extra positional arguments passed to every call to func
        **kwargs: Extra keyword arguments passed to every call to func

    Returns:
        A list containing the results of the function

    Raises:
        The exception from the first call to fail, unwrapped from the
        `FirstError` that `defer.gatherResults` reports.
    """
    try:
        in_flight = [run_in_background(func, item, *args, **kwargs) for item in iter]
        gathered = defer.gatherResults(in_flight, consumeErrors=True)
        return await make_deferred_yieldable(delay_cancellation(gathered))
    except defer.FirstError as first_error:
        # Re-raise the underlying exception without chaining it to the FirstError.
        original = first_error.subFailure.value
        assert isinstance(original, BaseException)
        raise original from None
# One type variable per position, used by the overloads of `gather_results` and
# `gather_optional_coroutines` to type heterogeneous tuples of results.
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
@overload
def gather_results(
    deferredList: tuple[()],
    consumeErrors: bool = ...,
) -> "defer.Deferred[tuple[()]]": ...


@overload
def gather_results(
    deferredList: tuple["defer.Deferred[T1]"],
    consumeErrors: bool = ...,
) -> "defer.Deferred[tuple[T1]]": ...


@overload
def gather_results(
    deferredList: tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
    consumeErrors: bool = ...,
) -> "defer.Deferred[tuple[T1, T2]]": ...


@overload
def gather_results(
    deferredList: tuple[
        "defer.Deferred[T1]",
        "defer.Deferred[T2]",
        "defer.Deferred[T3]",
    ],
    consumeErrors: bool = ...,
) -> "defer.Deferred[tuple[T1, T2, T3]]": ...


@overload
def gather_results(
    deferredList: tuple[
        "defer.Deferred[T1]",
        "defer.Deferred[T2]",
        "defer.Deferred[T3]",
        "defer.Deferred[T4]",
    ],
    consumeErrors: bool = ...,
) -> "defer.Deferred[tuple[T1, T2, T3, T4]]": ...


def gather_results(  # type: ignore[misc]
    deferredList: tuple["defer.Deferred[T1]", ...],
    consumeErrors: bool = False,
) -> "defer.Deferred[tuple[T1, ...]]":
    """Merge a tuple of `Deferred`s into one `Deferred` that yields a tuple.

    This is `defer.gatherResults` with type annotations (via the overloads
    above) that cope with heterogeneous collections of `Deferred`s.

    Args:
        deferredList: The `Deferred`s to wait for.
        consumeErrors: Passed straight through to `defer.gatherResults`.

    Returns:
        A `Deferred` which resolves to the gathered results, converted to a tuple.
    """
    # The `type: ignore[misc]` on the signature silences
    # "Overloaded function implementation cannot produce return type of signature 1/2/3"
    return defer.gatherResults(deferredList, consumeErrors=consumeErrors).addCallback(
        tuple
    )
@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None]],
) -> tuple[T1 | None]: ...


@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[
        tuple[
            Coroutine[Any, Any, T1] | None,
            Coroutine[Any, Any, T2] | None,
        ]
    ],
) -> tuple[T1 | None, T2 | None]: ...


@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[
        tuple[
            Coroutine[Any, Any, T1] | None,
            Coroutine[Any, Any, T2] | None,
            Coroutine[Any, Any, T3] | None,
        ]
    ],
) -> tuple[T1 | None, T2 | None, T3 | None]: ...


@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[
        tuple[
            Coroutine[Any, Any, T1] | None,
            Coroutine[Any, Any, T2] | None,
            Coroutine[Any, Any, T3] | None,
            Coroutine[Any, Any, T4] | None,
        ]
    ],
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None]: ...


@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[
        tuple[
            Coroutine[Any, Any, T1] | None,
            Coroutine[Any, Any, T2] | None,
            Coroutine[Any, Any, T3] | None,
            Coroutine[Any, Any, T4] | None,
            Coroutine[Any, Any, T5] | None,
        ]
    ],
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None]: ...


@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[
        tuple[
            Coroutine[Any, Any, T1] | None,
            Coroutine[Any, Any, T2] | None,
            Coroutine[Any, Any, T3] | None,
            Coroutine[Any, Any, T4] | None,
            Coroutine[Any, Any, T5] | None,
            Coroutine[Any, Any, T6] | None,
        ]
    ],
) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ...


async def gather_optional_coroutines(
    *coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]],
) -> tuple[T1 | None, ...]:
    """Wait on several coroutines at once, some of which may be absent.

    The result is a tuple with one entry per argument, in argument order: the
    coroutine's return value, or `None` where `None` was passed instead of a
    coroutine.

    Note: the type checker needs an explicit overload for each distinct number
    of arguments. If you hit type errors, you are probably passing more
    coroutines than there are overloads, and need to add another one above.

    Raises:
        The exception from the first coroutine to fail, unwrapped from the
        `FirstError` that `defer.gatherResults` reports.
    """
    try:
        # Only the coroutines that were actually supplied are started.
        started = [
            run_coroutine_in_background(coroutine)
            for coroutine in coroutines
            if coroutine is not None
        ]
        gathered = await make_deferred_yieldable(
            defer.gatherResults(started, consumeErrors=True)
        )

        # Put the results back into argument order, with `None` standing in for
        # every argument that was not a coroutine.
        remaining = iter(gathered)
        return tuple(
            None if coroutine is None else next(remaining) for coroutine in coroutines
        )
    except defer.FirstError as first_error:
        # Unwrap the error from defer.gatherResults.
        #
        # The traceback on the re-raised exception only includes func() etc if
        # the 'await' happened before the exception was thrown - ie if the failure
        # happened *asynchronously* - otherwise Twisted throws the traceback away,
        # since it could be large.
        #
        # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
        # we could throw Twisted into the fires of Mordor.
        #
        # Exception chaining is suppressed, because the FirstError doesn't tell us
        # anything very interesting.
        original = first_error.subFailure.value
        assert isinstance(original, BaseException)
        raise original from None
@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
    """Per-key state kept by `Linearizer`."""

    # The number of things executing.
    count: int
    # Deferreds for the things blocked from executing. Used as an ordered set
    # (the values are always 1): `Linearizer._release_lock` pops from the front,
    # so waiters are released in the order in which they queued.
    deferreds: typing.OrderedDict["defer.Deferred[None]", Literal[1]]
class Linearizer:
    """Limits concurrent access to resources based on a key. Useful to ensure
    only a few things happen at a time on a given resource.

    Example:

        async with limiter.queue("test_key"):
            # do some work.

    """

    def __init__(
        self,
        name: str,
        clock: Clock,
        max_count: int = 1,
    ):
        """
        Args:
            name: Label for this linearizer. It is included in the log messages
                emitted when locks are acquired, waited for and released.
            max_count: The maximum number of concurrent accesses
            clock: (ideally, the homeserver clock `hs.get_clock()`)
        """
        self.name = name
        self.max_count = max_count
        self._clock = clock

        # key_to_defer is a map from the key to a _LinearizerEntry.
        self.key_to_defer: dict[Hashable, _LinearizerEntry] = {}

    def is_queued(self, key: Hashable) -> bool:
        """Checks whether there is a process queued up waiting"""
        entry = self.key_to_defer.get(key)
        if not entry:
            # No entry so nothing is waiting.
            return False

        # There are waiting deferreds only if the OrderedDict of deferreds is
        # non-empty.
        return bool(entry.deferreds)

    def queue(self, key: Hashable) -> AsyncContextManager[None]:
        """Returns an async context manager which holds the lock for `key`.

        The lock is acquired (waiting if necessary) on entry and released on
        exit, including when the body raises.
        """

        @asynccontextmanager
        async def _ctx_manager() -> AsyncIterator[None]:
            entry = await self._acquire_lock(key)
            try:
                yield
            finally:
                self._release_lock(key, entry)

        return _ctx_manager()

    async def _acquire_lock(self, key: Hashable) -> _LinearizerEntry:
        """Acquires a linearizer lock, waiting if necessary.

        Returns once we have secured the lock.

        Args:
            key: The key to lock on.

        Returns:
            The entry for `key`, to be handed back to `_release_lock`.

        Raises:
            Any exception raised while waiting in the queue (such as a
            `CancelledError`) is re-raised after we have removed ourselves from
            the queue.
        """
        entry = self.key_to_defer.setdefault(
            key, _LinearizerEntry(0, collections.OrderedDict())
        )

        if entry.count < self.max_count:
            # The number of things executing is less than the maximum.
            logger.debug(
                "Acquired uncontended linearizer lock %r for key %r", self.name, key
            )
            entry.count += 1
            return entry

        # Otherwise, the number of things executing is at the maximum and we have to
        # add a deferred to the list of blocked items.
        # When one of the things currently executing finishes it will callback
        # this item so that it can continue executing.
        logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)

        new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
        entry.deferreds[new_defer] = 1

        try:
            await new_defer
        except Exception as e:
            logger.info("defer %r got err %r", new_defer, e)
            if isinstance(e, CancelledError):
                logger.debug(
                    "Cancelling wait for linearizer lock %r for key %r",
                    self.name,
                    key,
                )
            else:
                logger.warning(
                    "Unexpected exception waiting for linearizer lock %r for key %r",
                    self.name,
                    key,
                )

            # we just have to take ourselves back out of the queue.
            del entry.deferreds[new_defer]
            raise

        logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
        entry.count += 1

        # if the code holding the lock completes synchronously, then it
        # will recursively run the next claimant on the list. That can
        # relatively rapidly lead to stack exhaustion. This is essentially
        # the same problem as http://twistedmatrix.com/trac/ticket/9304.
        #
        # In order to break the cycle, we add a cheeky sleep(0) here to
        # ensure that we fall back to the reactor between each iteration.
        #
        # This needs to happen while we hold the lock. We could put it on the
        # exit path, but that would slow down the uncontended case.
        try:
            await self._clock.sleep(0)
        except CancelledError:
            # We already hold the lock at this point, so it has to be handed on
            # before the cancellation propagates.
            self._release_lock(key, entry)
            raise

        return entry

    def _release_lock(self, key: Hashable, entry: _LinearizerEntry) -> None:
        """Releases a held linearizer lock.

        Args:
            key: The key that was locked.
            entry: The entry returned by `_acquire_lock` for `key`.
        """
        logger.debug("Releasing linearizer lock %r for key %r", self.name, key)

        # We've finished executing so check if there are any things
        # blocked waiting to execute and start one of them
        entry.count -= 1

        if entry.deferreds:
            # Wake the waiter that has been queued the longest.
            (next_def, _) = entry.deferreds.popitem(last=False)

            # we need to run the next thing in the sentinel context.
            with PreserveLoggingContext():
                next_def.callback(None)
        elif entry.count == 0:
            # We were the last thing for this key: remove it from the
            # map.
            del self.key_to_defer[key]
class ReadWriteLock:
    """An async read write lock.

    Example:

        async with read_write_lock.read("test_key"):
            # do some work
    """

    # HOW IT WORKS
    #
    # For each key we remember the deferreds of the most recently queued readers
    # and writer. Each of those deferreds is resolved when its owner releases
    # the lock.
    #
    # Readers: may proceed as soon as the latest writer has resolved. A new
    # reader adds itself to the set of latest readers.
    #
    # Writers: may proceed once the latest writer and all of the latest readers
    # have resolved. A new writer takes over as the latest writer.

    def __init__(self) -> None:
        # Latest readers queued
        self.key_to_current_readers: dict[str, set[defer.Deferred]] = {}

        # Latest writer queued
        self.key_to_current_writer: dict[str, defer.Deferred] = {}

    def read(self, key: str) -> AsyncContextManager:
        """Return an async context manager holding a read lock on `key`."""

        @asynccontextmanager
        async def _ctx_manager() -> AsyncIterator[None]:
            released: "defer.Deferred[None]" = defer.Deferred()

            readers = self.key_to_current_readers.setdefault(key, set())
            pending_writer = self.key_to_current_writer.get(key)

            readers.add(released)

            try:
                # Only the latest writer needs waiting for: other readers can
                # safely be ignored... as they're readers.
                #
                # A `CancelledError` may be raised here if the `Deferred` wrapping
                # us is cancelled. The writer's `Deferred` is not ours, so it must
                # not be cancelled; hence `stop_cancellation`.
                if pending_writer is not None:
                    await make_deferred_yieldable(stop_cancellation(pending_writer))
                yield
            finally:
                with PreserveLoggingContext():
                    released.callback(None)
                self.key_to_current_readers.get(key, set()).discard(released)

        return _ctx_manager()

    def write(self, key: str) -> AsyncContextManager:
        """Return an async context manager holding the write lock on `key`."""

        @asynccontextmanager
        async def _ctx_manager() -> AsyncIterator[None]:
            released: "defer.Deferred[None]" = defer.Deferred()

            readers = self.key_to_current_readers.get(key, set())
            pending_writer = self.key_to_current_writer.get(key)

            # Everything that currently holds, or is queued for, the lock.
            blockers = [*readers]
            if pending_writer is not None:
                blockers.append(pending_writer)

            # `released` waits for the current readers to finish, so the set of
            # current readers can be emptied.
            readers.clear()
            self.key_to_current_writer[key] = released

            all_blockers_done = defer.gatherResults(blockers)
            try:
                # Wait for the current readers and the latest writer to finish.
                #
                # If the `Deferred` wrapping us is cancelled, a `CancelledError`
                # may be raised immediately after the wait. The lock must only be
                # released once it has been acquired, which is why this uses
                # `delay_cancellation` and not `stop_cancellation`.
                await make_deferred_yieldable(delay_cancellation(all_blockers_done))
                yield
            finally:
                # Release the lock.
                with PreserveLoggingContext():
                    released.callback(None)
                # `self.key_to_current_writer[key]` may already be gone if another
                # writer was waiting for us and ran to completion inside the
                # `released.callback()` call above.
                if self.key_to_current_writer.get(key) == released:
                    del self.key_to_current_writer[key]

        return _ctx_manager()
def timeout_deferred(
    *,
    deferred: "defer.Deferred[_T]",
    timeout: float,
    cancel_on_shutdown: bool = True,
    clock: Clock,
) -> "defer.Deferred[_T]":
    """The in built twisted `Deferred.addTimeout` fails to time out deferreds
    that have a canceller that throws exceptions. This method creates a new
    deferred that wraps and times out the given deferred, correctly handling
    the case where the given deferred's canceller throws.

    (See https://twistedmatrix.com/trac/ticket/9534)

    NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred.

    NOTE: the TimeoutError raised by the resultant deferred is
    twisted.internet.defer.TimeoutError, which is *different* to the built-in
    TimeoutError, as well as various other TimeoutErrors you might have imported.

    Args:
        deferred: The Deferred to potentially timeout.
        timeout: Timeout in seconds
        cancel_on_shutdown: Whether this call should be tracked for cleanup during
            shutdown. In general, all calls should be tracked. There may be a use case
            not to track calls with a `timeout` of 0 (or similarly short) since tracking
            them may result in rapid insertions and removals of tracked calls
            unnecessarily. But unless a specific instance of tracking proves to be an
            issue, we can just track all delayed calls.
        clock: The `Clock` instance used to track delayed calls.

    Returns:
        A new Deferred, which will errback with defer.TimeoutError on timeout.
    """
    new_d: "defer.Deferred[_T]" = defer.Deferred()

    # A one-element list so that the nested functions below can flip the flag.
    timed_out = [False]

    def time_it_out() -> None:
        # Runs when the timeout expires: cancel the wrapped deferred.
        timed_out[0] = True

        try:
            with PreserveLoggingContext():
                deferred.cancel()
        except Exception:  # if we throw any exception it'll break time outs
            logger.exception("Canceller failed during timeout")

        # the cancel() call should have set off a chain of errbacks which
        # will have errbacked new_d, but in case it hasn't, errback it now.
        if not new_d.called:
            with PreserveLoggingContext():
                new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))

    # Schedule the timeout. Whether the delayed call is tracked for cancellation
    # on shutdown is decided by the caller via `cancel_on_shutdown`.
    delayed_call = clock.call_later(
        timeout, time_it_out, call_later_cancel_on_shutdown=cancel_on_shutdown
    )

    def convert_cancelled(value: Failure) -> Failure:
        # if the original deferred was cancelled, and our timeout has fired, then
        # the reason it was cancelled was due to our timeout. Turn the CancelledError
        # into a TimeoutError.
        if timed_out[0] and value.check(CancelledError):
            raise defer.TimeoutError("Timed out after %gs" % (timeout,))
        return value

    deferred.addErrback(convert_cancelled)

    def cancel_timeout(result: _T) -> _T:
        # the wrapped deferred has resolved (either way), so stop the pending
        # timeout call if it has not run yet.
        if delayed_call.active():
            delayed_call.cancel()
        return result

    deferred.addBoth(cancel_timeout)

    # Relay the outcome of the wrapped deferred to `new_d`, unless `new_d` has
    # already been resolved (by `time_it_out`, or by its consumer).
    def success_cb(val: _T) -> None:
        if not new_d.called:
            with PreserveLoggingContext():
                new_d.callback(val)

    def failure_cb(val: Failure) -> None:
        if not new_d.called:
            with PreserveLoggingContext():
                new_d.errback(val)

    deferred.addCallbacks(success_cb, failure_cb)

    return new_d
@attr.s(auto_attribs=True, frozen=True, slots=True)
class DoneAwaitable(Awaitable[R]):
    """A trivial awaitable whose result is the value it was created with."""

    value: R

    def __await__(self) -> Generator[Any, None, R]:
        # Yield once (handing `None` to whatever is driving us), then finish
        # with the stored value.
        yield
        return self.value
def maybe_awaitable(value: Awaitable[R] | R) -> Awaitable[R]:
    """Return `value` as-is if it can be awaited, else wrap it in a `DoneAwaitable`."""
    if not inspect.isawaitable(value):
        # mypy does not work out that `value` is not an Awaitable on this branch,
        # even though inspect.isawaitable returns a TypeGuard, so spell it out.
        assert not isinstance(value, Awaitable)
        return DoneAwaitable(value)

    return value
def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
    """Shield a `Deferred` from cancellation by wrapping it in a new `Deferred`.

    Args:
        deferred: The `Deferred` to protect against cancellation. Must not follow the
            Synapse logcontext rules.

    Returns:
        A new `Deferred` which receives the result of the original `Deferred`.

        Cancelling the new `Deferred` is not propagated to the original; the new
        `Deferred` simply fails with a `CancelledError`.

        The new `Deferred` will not follow the Synapse logcontext rules and should be
        wrapped with `make_deferred_yieldable`.
    """
    shield: "defer.Deferred[T]" = defer.Deferred()
    # The result flows from `deferred` into `shield`.
    deferred.chainDeferred(shield)
    return shield
@overload
def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]": ...


@overload
def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]": ...


@overload
def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: ...


def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
    """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.

    Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
    resolve with a `CancelledError` until the original awaitable resolves.

    Args:
        awaitable: The coroutine or `Deferred` to protect against cancellation. May
            optionally follow the Synapse logcontext rules.

    Returns:
        A new `Deferred`, which will contain the result of the original coroutine or
        `Deferred`. The new `Deferred` will not propagate cancellation through to the
        original coroutine or `Deferred`.

        When cancelled, the new `Deferred` will wait until the original coroutine or
        `Deferred` resolves before failing with a `CancelledError`.

        The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
        follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
        wrapped with `make_deferred_yieldable`.

        If `awaitable` is neither a `Deferred` nor a coroutine, it is returned
        unchanged.
    """

    # First, convert the awaitable into a `Deferred`.
    if isinstance(awaitable, defer.Deferred):
        deferred = awaitable
    elif asyncio.iscoroutine(awaitable):
        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
        # type-checking, but we'd need Twisted >= 21.2.
        deferred = defer.ensureDeferred(awaitable)
    else:
        # We have no idea what to do with this awaitable.
        # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
        # `make_awaitable`, and let the caller `await` it normally.
        return awaitable

    def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
        # Canceller for `new_deferred`.
        #
        # before the new deferred is cancelled, we `pause` it to stop the cancellation
        # propagating. we then `unpause` it once the wrapped deferred completes, to
        # propagate the exception.
        new_deferred.pause()
        with PreserveLoggingContext():
            new_deferred.errback(Failure(CancelledError()))

        deferred.addBoth(lambda _: new_deferred.unpause())

    new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
    # Feed the result of the wrapped deferred into the new one.
    deferred.chainDeferred(new_deferred)
    return new_deferred
class AwakenableSleeper:
    """Allows explicitly waking up deferreds related to an entity that are
    currently sleeping.
    """

    def __init__(self, clock: Clock) -> None:
        # Maps a name to the deferreds of the `sleep` calls currently waiting on it.
        self._streams: dict[str, set[defer.Deferred[None]]] = {}
        self._clock = clock

    def wake(self, name: str) -> None:
        """Wake everything related to `name` that is currently sleeping.

        Waking is best effort: if firing one sleeper raises, the error is logged
        and the remaining sleepers are still woken.
        """
        # Removing the whole set first means each sleeper is fired at most once.
        stream_set = self._streams.pop(name, set())
        for deferred in stream_set:
            try:
                with PreserveLoggingContext():
                    deferred.callback(None)
            except Exception:
                # Don't let one broken sleeper stop the others from being woken,
                # but leave a trace rather than failing silently.
                logger.exception(
                    "Failed to wake sleeper %r for %r, ignoring...", deferred, name
                )

    async def sleep(self, name: str, delay_ms: int) -> None:
        """Sleep for the given number of milliseconds, or return if the given
        `name` is explicitly woken up.

        Args:
            name: The name that `wake` can be called with to end the sleep early.
            delay_ms: The maximum time to sleep for, in milliseconds.
        """

        # Create a deferred that will get called if `wake` is called with
        # the same `name`.
        stream_set = self._streams.setdefault(name, set())
        notify_deferred: "defer.Deferred[None]" = defer.Deferred()
        stream_set.add(notify_deferred)

        try:
            # Wait for either the delay or for `wake` to be called. The timeout
            # is applied to a `stop_cancellation` wrapper, so timing out does not
            # cancel `notify_deferred` itself.
            await make_deferred_yieldable(
                timeout_deferred(
                    deferred=stop_cancellation(notify_deferred),
                    timeout=delay_ms / 1000,
                    clock=self._clock,
                )
            )
        except defer.TimeoutError:
            # The delay elapsed without a `wake`: that is a normal return.
            pass
        finally:
            # Clean up the state. The set currently registered for `name` may be
            # missing, or a different set from the one we joined, if `wake` ran.
            curr_stream_set = self._streams.get(name)
            if curr_stream_set is not None:
                curr_stream_set.discard(notify_deferred)
                if not curr_stream_set:
                    self._streams.pop(name)
class DeferredEvent:
    """Like threading.Event but for async code"""

    def __init__(self, clock: Clock) -> None:
        self._clock = clock
        # Resolved while the event is set; replaced by a fresh one on `clear`.
        self._deferred: "defer.Deferred[None]" = defer.Deferred()

    def set(self) -> None:
        """Set the event. Does nothing if it is already set."""
        if self._deferred.called:
            return

        with PreserveLoggingContext():
            self._deferred.callback(None)

    def clear(self) -> None:
        """Reset the event. Does nothing if it is not currently set."""
        if not self._deferred.called:
            return

        self._deferred = defer.Deferred()

    def is_set(self) -> bool:
        """Whether the event is currently set."""
        return self._deferred.called

    async def wait(self, timeout_seconds: float) -> bool:
        """Wait for the event to be set, for at most `timeout_seconds`.

        Returns:
            True if the event was already set; otherwise whether it is set by
            the time the wait ends.
        """
        if self.is_set():
            return True

        try:
            # The timeout is applied to a `stop_cancellation` wrapper around the
            # event's deferred.
            shielded = stop_cancellation(self._deferred)
            await make_deferred_yieldable(
                timeout_deferred(
                    deferred=shielded,
                    timeout=timeout_seconds,
                    clock=self._clock,
                )
            )
        except defer.TimeoutError:
            pass

        return self.is_set()
|