streams.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
"""Stream-related things."""

__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
           'open_connection', 'start_server',
           'IncompleteReadError',
           ]

import socket

if hasattr(socket, 'AF_UNIX'):
    # The UNIX-domain-socket helpers are only exported on platforms
    # that actually support AF_UNIX.
    __all__.extend(['open_unix_connection', 'start_unix_server'])

from . import coroutines
from . import events
from . import futures
from . import protocols
from .coroutines import coroutine
from .log import logger

# Default StreamReader buffer limit: 64 KiB.  Doubles as the maximum
# line length accepted by readline() (see StreamReader._limit).
_DEFAULT_LIMIT = 2**16
  16. class IncompleteReadError(EOFError):
  17. """
  18. Incomplete read error. Attributes:
  19. - partial: read bytes string before the end of stream was reached
  20. - expected: total number of expected bytes
  21. """
  22. def __init__(self, partial, expected):
  23. EOFError.__init__(self, "%s bytes read on a total of %s expected bytes"
  24. % (len(partial), expected))
  25. self.partial = partial
  26. self.expected = expected
  27. @coroutine
  28. def open_connection(host=None, port=None, *,
  29. loop=None, limit=_DEFAULT_LIMIT, **kwds):
  30. """A wrapper for create_connection() returning a (reader, writer) pair.
  31. The reader returned is a StreamReader instance; the writer is a
  32. StreamWriter instance.
  33. The arguments are all the usual arguments to create_connection()
  34. except protocol_factory; most common are positional host and port,
  35. with various optional keyword arguments following.
  36. Additional optional keyword arguments are loop (to set the event loop
  37. instance to use) and limit (to set the buffer limit passed to the
  38. StreamReader).
  39. (If you want to customize the StreamReader and/or
  40. StreamReaderProtocol classes, just copy the code -- there's
  41. really nothing special here except some convenience.)
  42. """
  43. if loop is None:
  44. loop = events.get_event_loop()
  45. reader = StreamReader(limit=limit, loop=loop)
  46. protocol = StreamReaderProtocol(reader, loop=loop)
  47. transport, _ = yield from loop.create_connection(
  48. lambda: protocol, host, port, **kwds)
  49. writer = StreamWriter(transport, protocol, reader, loop)
  50. return reader, writer
  51. @coroutine
  52. def start_server(client_connected_cb, host=None, port=None, *,
  53. loop=None, limit=_DEFAULT_LIMIT, **kwds):
  54. """Start a socket server, call back for each client connected.
  55. The first parameter, `client_connected_cb`, takes two parameters:
  56. client_reader, client_writer. client_reader is a StreamReader
  57. object, while client_writer is a StreamWriter object. This
  58. parameter can either be a plain callback function or a coroutine;
  59. if it is a coroutine, it will be automatically converted into a
  60. Task.
  61. The rest of the arguments are all the usual arguments to
  62. loop.create_server() except protocol_factory; most common are
  63. positional host and port, with various optional keyword arguments
  64. following. The return value is the same as loop.create_server().
  65. Additional optional keyword arguments are loop (to set the event loop
  66. instance to use) and limit (to set the buffer limit passed to the
  67. StreamReader).
  68. The return value is the same as loop.create_server(), i.e. a
  69. Server object which can be used to stop the service.
  70. """
  71. if loop is None:
  72. loop = events.get_event_loop()
  73. def factory():
  74. reader = StreamReader(limit=limit, loop=loop)
  75. protocol = StreamReaderProtocol(reader, client_connected_cb,
  76. loop=loop)
  77. return protocol
  78. return (yield from loop.create_server(factory, host, port, **kwds))
  79. if hasattr(socket, 'AF_UNIX'):
  80. # UNIX Domain Sockets are supported on this platform
  81. @coroutine
  82. def open_unix_connection(path=None, *,
  83. loop=None, limit=_DEFAULT_LIMIT, **kwds):
  84. """Similar to `open_connection` but works with UNIX Domain Sockets."""
  85. if loop is None:
  86. loop = events.get_event_loop()
  87. reader = StreamReader(limit=limit, loop=loop)
  88. protocol = StreamReaderProtocol(reader, loop=loop)
  89. transport, _ = yield from loop.create_unix_connection(
  90. lambda: protocol, path, **kwds)
  91. writer = StreamWriter(transport, protocol, reader, loop)
  92. return reader, writer
  93. @coroutine
  94. def start_unix_server(client_connected_cb, path=None, *,
  95. loop=None, limit=_DEFAULT_LIMIT, **kwds):
  96. """Similar to `start_server` but works with UNIX Domain Sockets."""
  97. if loop is None:
  98. loop = events.get_event_loop()
  99. def factory():
  100. reader = StreamReader(limit=limit, loop=loop)
  101. protocol = StreamReaderProtocol(reader, client_connected_cb,
  102. loop=loop)
  103. return protocol
  104. return (yield from loop.create_unix_server(factory, path, **kwds))
class FlowControlMixin(protocols.Protocol):
    """Reusable flow control logic for StreamWriter.drain().

    This implements the protocol methods pause_writing(),
    resume_reading() and connection_lost().  If the subclass overrides
    these it must call the super methods.

    StreamWriter.drain() must wait for _drain_helper() coroutine.
    """

    def __init__(self, loop=None):
        # Fall back to the current event loop when none is supplied.
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._paused = False           # transport asked us to stop writing
        self._drain_waiter = None      # Future completed when writing may resume
        self._connection_lost = False  # set once connection_lost() has run

    def pause_writing(self):
        # Called by the transport when its write buffer crosses the
        # high-water mark; must not be called twice in a row.
        assert not self._paused
        self._paused = True
        if self._loop.get_debug():
            logger.debug("%r pauses writing", self)

    def resume_writing(self):
        # Called by the transport when the buffer drains below the
        # low-water mark: unblock any coroutine waiting in drain().
        assert self._paused
        self._paused = False
        if self._loop.get_debug():
            logger.debug("%r resumes writing", self)

        waiter = self._drain_waiter
        if waiter is not None:
            # Clear the slot before completing so a re-entrant drain()
            # sees a consistent state.
            self._drain_waiter = None
            if not waiter.done():
                waiter.set_result(None)

    def connection_lost(self, exc):
        self._connection_lost = True
        # Wake up the writer if currently paused.
        if not self._paused:
            return
        waiter = self._drain_waiter
        if waiter is None:
            return
        self._drain_waiter = None
        if waiter.done():
            return
        # Propagate the connection error into drain(), or complete it
        # cleanly when the connection closed without an exception.
        if exc is None:
            waiter.set_result(None)
        else:
            waiter.set_exception(exc)

    @coroutine
    def _drain_helper(self):
        # Raise immediately if the connection is already gone...
        if self._connection_lost:
            raise ConnectionResetError('Connection lost')
        # ...and return at once when writing is not paused.
        if not self._paused:
            return
        # Only one drain() may be pending at a time (any previous
        # waiter must have been consumed or cancelled).
        waiter = self._drain_waiter
        assert waiter is None or waiter.cancelled()
        waiter = futures.Future(loop=self._loop)
        self._drain_waiter = waiter
        yield from waiter
  161. class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
  162. """Helper class to adapt between Protocol and StreamReader.
  163. (This is a helper class instead of making StreamReader itself a
  164. Protocol subclass, because the StreamReader has other potential
  165. uses, and to prevent the user of the StreamReader to accidentally
  166. call inappropriate methods of the protocol.)
  167. """
  168. def __init__(self, stream_reader, client_connected_cb=None, loop=None):
  169. super().__init__(loop=loop)
  170. self._stream_reader = stream_reader
  171. self._stream_writer = None
  172. self._client_connected_cb = client_connected_cb
  173. def connection_made(self, transport):
  174. self._stream_reader.set_transport(transport)
  175. if self._client_connected_cb is not None:
  176. self._stream_writer = StreamWriter(transport, self,
  177. self._stream_reader,
  178. self._loop)
  179. res = self._client_connected_cb(self._stream_reader,
  180. self._stream_writer)
  181. if coroutines.iscoroutine(res):
  182. self._loop.create_task(res)
  183. def connection_lost(self, exc):
  184. if exc is None:
  185. self._stream_reader.feed_eof()
  186. else:
  187. self._stream_reader.set_exception(exc)
  188. super().connection_lost(exc)
  189. def data_received(self, data):
  190. self._stream_reader.feed_data(data)
  191. def eof_received(self):
  192. self._stream_reader.feed_eof()
  193. class StreamWriter:
  194. """Wraps a Transport.
  195. This exposes write(), writelines(), [can_]write_eof(),
  196. get_extra_info() and close(). It adds drain() which returns an
  197. optional Future on which you can wait for flow control. It also
  198. adds a transport property which references the Transport
  199. directly.
  200. """
  201. def __init__(self, transport, protocol, reader, loop):
  202. self._transport = transport
  203. self._protocol = protocol
  204. # drain() expects that the reader has a exception() method
  205. assert reader is None or isinstance(reader, StreamReader)
  206. self._reader = reader
  207. self._loop = loop
  208. def __repr__(self):
  209. info = [self.__class__.__name__, 'transport=%r' % self._transport]
  210. if self._reader is not None:
  211. info.append('reader=%r' % self._reader)
  212. return '<%s>' % ' '.join(info)
  213. @property
  214. def transport(self):
  215. return self._transport
  216. def write(self, data):
  217. self._transport.write(data)
  218. def writelines(self, data):
  219. self._transport.writelines(data)
  220. def write_eof(self):
  221. return self._transport.write_eof()
  222. def can_write_eof(self):
  223. return self._transport.can_write_eof()
  224. def close(self):
  225. return self._transport.close()
  226. def get_extra_info(self, name, default=None):
  227. return self._transport.get_extra_info(name, default)
  228. @coroutine
  229. def drain(self):
  230. """Flush the write buffer.
  231. The intended use is to write
  232. w.write(data)
  233. yield from w.drain()
  234. """
  235. if self._reader is not None:
  236. exc = self._reader.exception()
  237. if exc is not None:
  238. raise exc
  239. yield from self._protocol._drain_helper()
class StreamReader:
    """Buffered byte stream fed by a protocol via feed_data()/feed_eof().

    Coroutines consume the buffer with read(), readline() and
    readexactly().  At most one read coroutine may be waiting for data
    at any time (enforced by _wait_for_data()).  When a transport is
    attached, reading is paused once the buffer exceeds twice the limit
    and resumed when it drains back under the limit.
    """

    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
        # The line length limit is a security feature;
        # it also doubles as half the buffer limit.
        self._limit = limit
        if loop is None:
            self._loop = events.get_event_loop()
        else:
            self._loop = loop
        self._buffer = bytearray()
        self._eof = False    # Whether we're done.
        self._waiter = None  # A future used by _wait_for_data()
        self._exception = None
        self._transport = None
        self._paused = False

    def exception(self):
        """Return the exception recorded by set_exception(), or None."""
        return self._exception

    def set_exception(self, exc):
        """Record *exc* and deliver it to any coroutine waiting for data."""
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_exception(exc)

    def _wakeup_waiter(self):
        """Wakeup read() or readline() function waiting for data or EOF."""
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            if not waiter.cancelled():
                waiter.set_result(None)

    def set_transport(self, transport):
        # May only be attached once per reader.
        assert self._transport is None, 'Transport already set'
        self._transport = transport

    def _maybe_resume_transport(self):
        # Resume reading once the buffer has drained back to the limit.
        if self._paused and len(self._buffer) <= self._limit:
            self._paused = False
            self._transport.resume_reading()

    def feed_eof(self):
        """Mark the end of the stream and wake any pending reader."""
        self._eof = True
        self._wakeup_waiter()

    def at_eof(self):
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    def feed_data(self, data):
        """Append *data* to the buffer and wake any pending reader.

        Pauses the transport (flow control) when the buffer grows
        beyond twice the limit.
        """
        assert not self._eof, 'feed_data after feed_eof'

        if not data:
            return

        self._buffer.extend(data)
        self._wakeup_waiter()

        if (self._transport is not None and
                not self._paused and
                len(self._buffer) > 2*self._limit):
            try:
                self._transport.pause_reading()
            except NotImplementedError:
                # The transport can't be paused.
                # We'll just have to buffer all data.
                # Forget the transport so we don't keep trying.
                self._transport = None
            else:
                self._paused = True

    def _wait_for_data(self, func_name):
        """Wait until feed_data() or feed_eof() is called."""
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine.  Running two read coroutines at the same time
        # would have an unexpected behaviour.  It would not possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError('%s() called while another coroutine is '
                               'already waiting for incoming data' % func_name)

        self._waiter = futures.Future(loop=self._loop)
        try:
            yield from self._waiter
        finally:
            self._waiter = None

    @coroutine
    def readline(self):
        """Read one line (up to and including the newline byte).

        Raises ValueError if a line exceeds the limit; at EOF, returns
        whatever partial line was buffered (possibly b'').
        """
        if self._exception is not None:
            raise self._exception

        line = bytearray()
        not_enough = True

        while not_enough:
            # Consume from the buffer until a newline is found or the
            # buffer is exhausted.
            while self._buffer and not_enough:
                ichar = self._buffer.find(b'\n')
                if ichar < 0:
                    line.extend(self._buffer)
                    self._buffer.clear()
                else:
                    ichar += 1
                    line.extend(self._buffer[:ichar])
                    del self._buffer[:ichar]
                    not_enough = False

                if len(line) > self._limit:
                    self._maybe_resume_transport()
                    raise ValueError('Line is too long')

            if self._eof:
                break

            if not_enough:
                yield from self._wait_for_data('readline')

        self._maybe_resume_transport()
        return bytes(line)

    @coroutine
    def read(self, n=-1):
        """Read up to *n* bytes.

        n == 0 returns b'' immediately; n < 0 reads until EOF; otherwise
        returns at most n bytes (waiting for at least one chunk unless
        at EOF).
        """
        if self._exception is not None:
            raise self._exception

        if not n:
            return b''

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.read(self._limit) until EOF.
            blocks = []
            while True:
                block = yield from self.read(self._limit)
                if not block:
                    break
                blocks.append(block)
            return b''.join(blocks)
        else:
            if not self._buffer and not self._eof:
                yield from self._wait_for_data('read')

        if n < 0 or len(self._buffer) <= n:
            # Hand over the whole buffer.
            data = bytes(self._buffer)
            self._buffer.clear()
        else:
            # n > 0 and len(self._buffer) > n
            data = bytes(self._buffer[:n])
            del self._buffer[:n]

        self._maybe_resume_transport()
        return data

    @coroutine
    def readexactly(self, n):
        """Read exactly *n* bytes.

        Raises IncompleteReadError (carrying the partial data) if the
        stream ends before n bytes were read.
        """
        if self._exception is not None:
            raise self._exception

        # There used to be "optimized" code here.  It created its own
        # Future and waited until self._buffer had at least the n
        # bytes, then called read(n).  Unfortunately, this could pause
        # the transport if the argument was larger than the pause
        # limit (which is twice self._limit).  So now we just read()
        # into a local buffer.

        blocks = []
        while n > 0:
            block = yield from self.read(n)
            if not block:
                partial = b''.join(blocks)
                raise IncompleteReadError(partial, len(partial) + n)
            blocks.append(block)
            n -= len(block)

        return b''.join(blocks)