task.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. ##############################################################################
  2. #
  3. # Copyright (c) 2001, 2002 Zope Foundation and Contributors.
  4. # All Rights Reserved.
  5. #
  6. # This software is subject to the provisions of the Zope Public License,
  7. # Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
  8. # THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
  9. # WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  10. # WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
  11. # FOR A PARTICULAR PURPOSE.
  12. #
  13. ##############################################################################
  14. from collections import deque
  15. import socket
  16. import sys
  17. import threading
  18. import time
  19. from .buffers import ReadOnlyFileBasedBuffer
  20. from .utilities import build_http_date, logger, queue_logger
  21. rename_headers = { # or keep them without the HTTP_ prefix added
  22. "CONTENT_LENGTH": "CONTENT_LENGTH",
  23. "CONTENT_TYPE": "CONTENT_TYPE",
  24. }
  25. hop_by_hop = frozenset(
  26. (
  27. "connection",
  28. "keep-alive",
  29. "proxy-authenticate",
  30. "proxy-authorization",
  31. "te",
  32. "trailers",
  33. "transfer-encoding",
  34. "upgrade",
  35. )
  36. )
  37. class ThreadedTaskDispatcher:
  38. """A Task Dispatcher that creates a thread for each task."""
  39. stop_count = 0 # Number of threads that will stop soon.
  40. active_count = 0 # Number of currently active threads
  41. logger = logger
  42. queue_logger = queue_logger
  43. def __init__(self):
  44. self.threads = set()
  45. self.queue = deque()
  46. self.lock = threading.Lock()
  47. self.queue_cv = threading.Condition(self.lock)
  48. self.thread_exit_cv = threading.Condition(self.lock)
  49. def start_new_thread(self, target, thread_no):
  50. t = threading.Thread(
  51. target=target, name=f"waitress-{thread_no}", args=(thread_no,)
  52. )
  53. t.daemon = True
  54. t.start()
  55. def handler_thread(self, thread_no):
  56. while True:
  57. with self.lock:
  58. while not self.queue and self.stop_count == 0:
  59. # Mark ourselves as idle before waiting to be
  60. # woken up, then we will once again be active
  61. self.active_count -= 1
  62. self.queue_cv.wait()
  63. self.active_count += 1
  64. if self.stop_count > 0:
  65. self.active_count -= 1
  66. self.stop_count -= 1
  67. self.threads.discard(thread_no)
  68. self.thread_exit_cv.notify()
  69. break
  70. task = self.queue.popleft()
  71. try:
  72. task.service()
  73. except BaseException:
  74. self.logger.exception("Exception when servicing %r", task)
  75. def set_thread_count(self, count):
  76. with self.lock:
  77. threads = self.threads
  78. thread_no = 0
  79. running = len(threads) - self.stop_count
  80. while running < count:
  81. # Start threads.
  82. while thread_no in threads:
  83. thread_no = thread_no + 1
  84. threads.add(thread_no)
  85. running += 1
  86. self.start_new_thread(self.handler_thread, thread_no)
  87. self.active_count += 1
  88. thread_no = thread_no + 1
  89. if running > count:
  90. # Stop threads.
  91. self.stop_count += running - count
  92. self.queue_cv.notify_all()
  93. def add_task(self, task):
  94. with self.lock:
  95. self.queue.append(task)
  96. self.queue_cv.notify()
  97. queue_size = len(self.queue)
  98. idle_threads = len(self.threads) - self.stop_count - self.active_count
  99. if queue_size > idle_threads:
  100. self.queue_logger.warning(
  101. "Task queue depth is %d", queue_size - idle_threads
  102. )
  103. def shutdown(self, cancel_pending=True, timeout=5):
  104. self.set_thread_count(0)
  105. # Ensure the threads shut down.
  106. threads = self.threads
  107. expiration = time.time() + timeout
  108. with self.lock:
  109. while threads:
  110. if time.time() >= expiration:
  111. self.logger.warning("%d thread(s) still running", len(threads))
  112. break
  113. self.thread_exit_cv.wait(0.1)
  114. if cancel_pending:
  115. # Cancel remaining tasks.
  116. queue = self.queue
  117. if len(queue) > 0:
  118. self.logger.warning("Canceling %d pending task(s)", len(queue))
  119. while queue:
  120. task = queue.popleft()
  121. task.cancel()
  122. self.queue_cv.notify_all()
  123. return True
  124. return False
  125. class Task:
  126. close_on_finish = False
  127. status = "200 OK"
  128. wrote_header = False
  129. start_time = 0
  130. content_length = None
  131. content_bytes_written = 0
  132. logged_write_excess = False
  133. logged_write_no_body = False
  134. complete = False
  135. chunked_response = False
  136. logger = logger
  137. def __init__(self, channel, request):
  138. self.channel = channel
  139. self.request = request
  140. self.response_headers = []
  141. version = request.version
  142. if version not in ("1.0", "1.1"):
  143. # fall back to a version we support.
  144. version = "1.0"
  145. self.version = version
  146. def service(self):
  147. try:
  148. self.start()
  149. self.execute()
  150. self.finish()
  151. except OSError:
  152. self.close_on_finish = True
  153. if self.channel.adj.log_socket_errors:
  154. raise
  155. @property
  156. def has_body(self):
  157. return not (
  158. self.status.startswith("1")
  159. or self.status.startswith("204")
  160. or self.status.startswith("304")
  161. )
  162. def build_response_header(self):
  163. version = self.version
  164. # Figure out whether the connection should be closed.
  165. connection = self.request.headers.get("CONNECTION", "").lower()
  166. response_headers = []
  167. content_length_header = None
  168. date_header = None
  169. server_header = None
  170. connection_close_header = None
  171. for (headername, headerval) in self.response_headers:
  172. headername = "-".join([x.capitalize() for x in headername.split("-")])
  173. if headername == "Content-Length":
  174. if self.has_body:
  175. content_length_header = headerval
  176. else:
  177. continue # pragma: no cover
  178. if headername == "Date":
  179. date_header = headerval
  180. if headername == "Server":
  181. server_header = headerval
  182. if headername == "Connection":
  183. connection_close_header = headerval.lower()
  184. # replace with properly capitalized version
  185. response_headers.append((headername, headerval))
  186. if (
  187. content_length_header is None
  188. and self.content_length is not None
  189. and self.has_body
  190. ):
  191. content_length_header = str(self.content_length)
  192. response_headers.append(("Content-Length", content_length_header))
  193. def close_on_finish():
  194. if connection_close_header is None:
  195. response_headers.append(("Connection", "close"))
  196. self.close_on_finish = True
  197. if version == "1.0":
  198. if connection == "keep-alive":
  199. if not content_length_header:
  200. close_on_finish()
  201. else:
  202. response_headers.append(("Connection", "Keep-Alive"))
  203. else:
  204. close_on_finish()
  205. elif version == "1.1":
  206. if connection == "close":
  207. close_on_finish()
  208. if not content_length_header:
  209. # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
  210. # for any response with a status code of 1xx, 204 or 304.
  211. if self.has_body:
  212. response_headers.append(("Transfer-Encoding", "chunked"))
  213. self.chunked_response = True
  214. if not self.close_on_finish:
  215. close_on_finish()
  216. # under HTTP 1.1 keep-alive is default, no need to set the header
  217. else:
  218. raise AssertionError("neither HTTP/1.0 or HTTP/1.1")
  219. # Set the Server and Date field, if not yet specified. This is needed
  220. # if the server is used as a proxy.
  221. ident = self.channel.server.adj.ident
  222. if not server_header:
  223. if ident:
  224. response_headers.append(("Server", ident))
  225. else:
  226. response_headers.append(("Via", ident or "waitress"))
  227. if not date_header:
  228. response_headers.append(("Date", build_http_date(self.start_time)))
  229. self.response_headers = response_headers
  230. first_line = f"HTTP/{self.version} {self.status}"
  231. # NB: sorting headers needs to preserve same-named-header order
  232. # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
  233. # rely on stable sort to keep relative position of same-named headers
  234. next_lines = [
  235. "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0])
  236. ]
  237. lines = [first_line] + next_lines
  238. res = "%s\r\n\r\n" % "\r\n".join(lines)
  239. return res.encode("latin-1")
  240. def remove_content_length_header(self):
  241. response_headers = []
  242. for header_name, header_value in self.response_headers:
  243. if header_name.lower() == "content-length":
  244. continue # pragma: nocover
  245. response_headers.append((header_name, header_value))
  246. self.response_headers = response_headers
  247. def start(self):
  248. self.start_time = time.time()
  249. def finish(self):
  250. if not self.wrote_header:
  251. self.write(b"")
  252. if self.chunked_response:
  253. # not self.write, it will chunk it!
  254. self.channel.write_soon(b"0\r\n\r\n")
  255. def write(self, data):
  256. if not self.complete:
  257. raise RuntimeError("start_response was not called before body written")
  258. channel = self.channel
  259. if not self.wrote_header:
  260. rh = self.build_response_header()
  261. channel.write_soon(rh)
  262. self.wrote_header = True
  263. if data and self.has_body:
  264. towrite = data
  265. cl = self.content_length
  266. if self.chunked_response:
  267. # use chunked encoding response
  268. towrite = hex(len(data))[2:].upper().encode("latin-1") + b"\r\n"
  269. towrite += data + b"\r\n"
  270. elif cl is not None:
  271. towrite = data[: cl - self.content_bytes_written]
  272. self.content_bytes_written += len(towrite)
  273. if towrite != data and not self.logged_write_excess:
  274. self.logger.warning(
  275. "application-written content exceeded the number of "
  276. "bytes specified by Content-Length header (%s)" % cl
  277. )
  278. self.logged_write_excess = True
  279. if towrite:
  280. channel.write_soon(towrite)
  281. elif data:
  282. # Cheat, and tell the application we have written all of the bytes,
  283. # even though the response shouldn't have a body and we are
  284. # ignoring it entirely.
  285. self.content_bytes_written += len(data)
  286. if not self.logged_write_no_body:
  287. self.logger.warning(
  288. "application-written content was ignored due to HTTP "
  289. "response that may not contain a message-body: (%s)" % self.status
  290. )
  291. self.logged_write_no_body = True
  292. class ErrorTask(Task):
  293. """An error task produces an error response"""
  294. complete = True
  295. def execute(self):
  296. e = self.request.error
  297. status, headers, body = e.to_response()
  298. self.status = status
  299. self.response_headers.extend(headers)
  300. # We need to explicitly tell the remote client we are closing the
  301. # connection, because self.close_on_finish is set, and we are going to
  302. # slam the door in the clients face.
  303. self.response_headers.append(("Connection", "close"))
  304. self.close_on_finish = True
  305. self.content_length = len(body)
  306. self.write(body)
  307. class WSGITask(Task):
  308. """A WSGI task produces a response from a WSGI application."""
  309. environ = None
  310. def execute(self):
  311. environ = self.get_environment()
  312. def start_response(status, headers, exc_info=None):
  313. if self.complete and not exc_info:
  314. raise AssertionError(
  315. "start_response called a second time without providing exc_info."
  316. )
  317. if exc_info:
  318. try:
  319. if self.wrote_header:
  320. # higher levels will catch and handle raised exception:
  321. # 1. "service" method in task.py
  322. # 2. "service" method in channel.py
  323. # 3. "handler_thread" method in task.py
  324. raise exc_info[1]
  325. else:
  326. # As per WSGI spec existing headers must be cleared
  327. self.response_headers = []
  328. finally:
  329. exc_info = None
  330. self.complete = True
  331. if not status.__class__ is str:
  332. raise AssertionError("status %s is not a string" % status)
  333. if "\n" in status or "\r" in status:
  334. raise ValueError(
  335. "carriage return/line feed character present in status"
  336. )
  337. self.status = status
  338. # Prepare the headers for output
  339. for k, v in headers:
  340. if not k.__class__ is str:
  341. raise AssertionError(
  342. f"Header name {k!r} is not a string in {(k, v)!r}"
  343. )
  344. if not v.__class__ is str:
  345. raise AssertionError(
  346. f"Header value {v!r} is not a string in {(k, v)!r}"
  347. )
  348. if "\n" in v or "\r" in v:
  349. raise ValueError(
  350. "carriage return/line feed character present in header value"
  351. )
  352. if "\n" in k or "\r" in k:
  353. raise ValueError(
  354. "carriage return/line feed character present in header name"
  355. )
  356. kl = k.lower()
  357. if kl == "content-length":
  358. self.content_length = int(v)
  359. elif kl in hop_by_hop:
  360. raise AssertionError(
  361. '%s is a "hop-by-hop" header; it cannot be used by '
  362. "a WSGI application (see PEP 3333)" % k
  363. )
  364. self.response_headers.extend(headers)
  365. # Return a method used to write the response data.
  366. return self.write
  367. # Call the application to handle the request and write a response
  368. app_iter = self.channel.server.application(environ, start_response)
  369. can_close_app_iter = True
  370. try:
  371. if app_iter.__class__ is ReadOnlyFileBasedBuffer:
  372. cl = self.content_length
  373. size = app_iter.prepare(cl)
  374. if size:
  375. if cl != size:
  376. if cl is not None:
  377. self.remove_content_length_header()
  378. self.content_length = size
  379. self.write(b"") # generate headers
  380. # if the write_soon below succeeds then the channel will
  381. # take over closing the underlying file via the channel's
  382. # _flush_some or handle_close so we intentionally avoid
  383. # calling close in the finally block
  384. self.channel.write_soon(app_iter)
  385. can_close_app_iter = False
  386. return
  387. first_chunk_len = None
  388. for chunk in app_iter:
  389. if first_chunk_len is None:
  390. first_chunk_len = len(chunk)
  391. # Set a Content-Length header if one is not supplied.
  392. # start_response may not have been called until first
  393. # iteration as per PEP, so we must reinterrogate
  394. # self.content_length here
  395. if self.content_length is None:
  396. app_iter_len = None
  397. if hasattr(app_iter, "__len__"):
  398. app_iter_len = len(app_iter)
  399. if app_iter_len == 1:
  400. self.content_length = first_chunk_len
  401. # transmit headers only after first iteration of the iterable
  402. # that returns a non-empty bytestring (PEP 3333)
  403. if chunk:
  404. self.write(chunk)
  405. cl = self.content_length
  406. if cl is not None:
  407. if self.content_bytes_written != cl:
  408. # close the connection so the client isn't sitting around
  409. # waiting for more data when there are too few bytes
  410. # to service content-length
  411. self.close_on_finish = True
  412. if self.request.command != "HEAD":
  413. self.logger.warning(
  414. "application returned too few bytes (%s) "
  415. "for specified Content-Length (%s) via app_iter"
  416. % (self.content_bytes_written, cl),
  417. )
  418. finally:
  419. if can_close_app_iter and hasattr(app_iter, "close"):
  420. app_iter.close()
  421. def get_environment(self):
  422. """Returns a WSGI environment."""
  423. environ = self.environ
  424. if environ is not None:
  425. # Return the cached copy.
  426. return environ
  427. request = self.request
  428. path = request.path
  429. channel = self.channel
  430. server = channel.server
  431. url_prefix = server.adj.url_prefix
  432. if path.startswith("/"):
  433. # strip extra slashes at the beginning of a path that starts
  434. # with any number of slashes
  435. path = "/" + path.lstrip("/")
  436. if url_prefix:
  437. # NB: url_prefix is guaranteed by the configuration machinery to
  438. # be either the empty string or a string that starts with a single
  439. # slash and ends without any slashes
  440. if path == url_prefix:
  441. # if the path is the same as the url prefix, the SCRIPT_NAME
  442. # should be the url_prefix and PATH_INFO should be empty
  443. path = ""
  444. else:
  445. # if the path starts with the url prefix plus a slash,
  446. # the SCRIPT_NAME should be the url_prefix and PATH_INFO should
  447. # the value of path from the slash until its end
  448. url_prefix_with_trailing_slash = url_prefix + "/"
  449. if path.startswith(url_prefix_with_trailing_slash):
  450. path = path[len(url_prefix) :]
  451. environ = {
  452. "REMOTE_ADDR": channel.addr[0],
  453. # Nah, we aren't actually going to look up the reverse DNS for
  454. # REMOTE_ADDR, but we will happily set this environment variable
  455. # for the WSGI application. Spec says we can just set this to
  456. # REMOTE_ADDR, so we do.
  457. "REMOTE_HOST": channel.addr[0],
  458. # try and set the REMOTE_PORT to something useful, but maybe None
  459. "REMOTE_PORT": str(channel.addr[1]),
  460. "REQUEST_METHOD": request.command.upper(),
  461. "SERVER_PORT": str(server.effective_port),
  462. "SERVER_NAME": server.server_name,
  463. "SERVER_SOFTWARE": server.adj.ident,
  464. "SERVER_PROTOCOL": "HTTP/%s" % self.version,
  465. "SCRIPT_NAME": url_prefix,
  466. "PATH_INFO": path,
  467. "REQUEST_URI": request.request_uri,
  468. "QUERY_STRING": request.query,
  469. "wsgi.url_scheme": request.url_scheme,
  470. # the following environment variables are required by the WSGI spec
  471. "wsgi.version": (1, 0),
  472. # apps should use the logging module
  473. "wsgi.errors": sys.stderr,
  474. "wsgi.multithread": True,
  475. "wsgi.multiprocess": False,
  476. "wsgi.run_once": False,
  477. "wsgi.input": request.get_body_stream(),
  478. "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
  479. "wsgi.input_terminated": True, # wsgi.input is EOF terminated
  480. }
  481. for key, value in dict(request.headers).items():
  482. value = value.strip()
  483. mykey = rename_headers.get(key, None)
  484. if mykey is None:
  485. mykey = "HTTP_" + key
  486. if mykey not in environ:
  487. environ[mykey] = value
  488. # Insert a callable into the environment that allows the application to
  489. # check if the client disconnected. Only works with
  490. # channel_request_lookahead larger than 0.
  491. environ["waitress.client_disconnected"] = self.channel.check_client_disconnected
  492. # cache the environ for this request
  493. self.environ = environ
  494. return environ