simple_pub_sub.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from asyncio import Future, Queue, ensure_future, sleep
  2. from inspect import isawaitable
  3. from typing import Any, AsyncIterator, Callable, Optional, Set
  4. try:
  5. from asyncio import get_running_loop
  6. except ImportError:
  7. from asyncio import get_event_loop as get_running_loop # Python < 3.7
  8. __all__ = ["SimplePubSub", "SimplePubSubIterator"]
  9. class SimplePubSub:
  10. """A very simple publish-subscript system.
  11. Creates an AsyncIterator from an EventEmitter.
  12. Useful for mocking a PubSub system for tests.
  13. """
  14. subscribers: Set[Callable]
  15. def __init__(self) -> None:
  16. self.subscribers = set()
  17. def emit(self, event: Any) -> bool:
  18. """Emit an event."""
  19. for subscriber in self.subscribers:
  20. result = subscriber(event)
  21. if isawaitable(result):
  22. ensure_future(result)
  23. return bool(self.subscribers)
  24. def get_subscriber(
  25. self, transform: Optional[Callable] = None
  26. ) -> "SimplePubSubIterator":
  27. return SimplePubSubIterator(self, transform)
  28. class SimplePubSubIterator(AsyncIterator):
  29. def __init__(self, pubsub: SimplePubSub, transform: Optional[Callable]) -> None:
  30. self.pubsub = pubsub
  31. self.transform = transform
  32. self.pull_queue: Queue[Future] = Queue()
  33. self.push_queue: Queue[Any] = Queue()
  34. self.listening = True
  35. pubsub.subscribers.add(self.push_value)
  36. def __aiter__(self) -> "SimplePubSubIterator":
  37. return self
  38. async def __anext__(self) -> Any:
  39. if not self.listening:
  40. raise StopAsyncIteration
  41. await sleep(0)
  42. if not self.push_queue.empty():
  43. return await self.push_queue.get()
  44. future = get_running_loop().create_future()
  45. await self.pull_queue.put(future)
  46. return future
  47. async def aclose(self) -> None:
  48. if self.listening:
  49. await self.empty_queue()
  50. async def empty_queue(self) -> None:
  51. self.listening = False
  52. self.pubsub.subscribers.remove(self.push_value)
  53. while not self.pull_queue.empty():
  54. future = await self.pull_queue.get()
  55. future.cancel()
  56. while not self.push_queue.empty():
  57. await self.push_queue.get()
  58. async def push_value(self, event: Any) -> None:
  59. value = event if self.transform is None else self.transform(event)
  60. if self.pull_queue.empty():
  61. await self.push_queue.put(value)
  62. else:
  63. (await self.pull_queue.get()).set_result(value)