123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571 |
- ##############################################################################
- #
- # Copyright (c) 2001, 2002 Zope Foundation and Contributors.
- # All Rights Reserved.
- #
- # This software is subject to the provisions of the Zope Public License,
- # Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
- # THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
- # WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
- # WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
- # FOR A PARTICULAR PURPOSE.
- #
- ##############################################################################
- from collections import deque
- import socket
- import sys
- import threading
- import time
- from .buffers import ReadOnlyFileBasedBuffer
- from .utilities import build_http_date, logger, queue_logger
# Request header keys that are copied into the WSGI environ under their own
# name; every other request header key receives an "HTTP_" prefix when the
# environ is built.
rename_headers = {name: name for name in ("CONTENT_LENGTH", "CONTENT_TYPE")}

# Lower-cased header names that a WSGI application is not allowed to pass to
# start_response(); supplying one of them raises an AssertionError there.
hop_by_hop = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailers",
        "transfer-encoding",
        "upgrade",
    }
)
class ThreadedTaskDispatcher:
    """Run queued tasks on a resizable pool of daemon worker threads.

    Tasks are appended to a FIFO queue by ``add_task`` and picked up by
    workers running ``handler_thread``.  One lock guards the queue, the set
    of worker numbers and the two counters below; ``queue_cv`` and
    ``thread_exit_cv`` are condition variables built on that same lock.
    """

    # Number of workers that were asked to exit but have not exited yet.
    stop_count = 0
    # Number of workers that are not currently parked waiting for a task.
    active_count = 0
    logger = logger
    queue_logger = queue_logger

    def __init__(self):
        shared_lock = threading.Lock()
        self.lock = shared_lock
        # Signalled when a task is queued or when workers must re-check
        # whether they have been asked to stop.
        self.queue_cv = threading.Condition(shared_lock)
        # Signalled by a worker just before it leaves its loop.
        self.thread_exit_cv = threading.Condition(shared_lock)
        self.queue = deque()
        self.threads = set()

    def start_new_thread(self, target, thread_no):
        """Start a daemon thread named ``waitress-<thread_no>``.

        ``target`` is invoked in the new thread with ``thread_no`` as its
        only argument.
        """
        worker = threading.Thread(
            target=target,
            name=f"waitress-{thread_no}",
            args=(thread_no,),
            daemon=True,
        )
        worker.start()

    def handler_thread(self, thread_no):
        """Worker loop: service queued tasks until asked to stop.

        Any exception raised by ``task.service()`` (``BaseException``
        included) is logged and the worker carries on with the next task.
        """
        while True:
            with self.lock:
                while self.stop_count == 0 and not self.queue:
                    # Count ourselves as idle only for the duration of the
                    # wait; once woken we are active again.
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                if self.stop_count > 0:
                    # Consume one pending stop request, deregister and let
                    # a waiter on thread_exit_cv know that we are leaving.
                    self.active_count -= 1
                    self.stop_count -= 1
                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    return

                task = self.queue.popleft()

            # The task runs outside the lock so other workers can proceed.
            try:
                task.service()
            except BaseException:
                self.logger.exception("Exception when servicing %r", task)

    def set_thread_count(self, count):
        """Start or retire workers so that ``count`` of them keep running.

        Retiring is asynchronous: ``stop_count`` is raised and every waiting
        worker is woken so that it can notice the request.
        """
        with self.lock:
            pool = self.threads
            # Workers already scheduled to stop do not count as running.
            running = len(pool) - self.stop_count
            candidate = 0

            while running < count:
                # Use the lowest worker number that is not taken yet.
                while candidate in pool:
                    candidate += 1
                pool.add(candidate)
                running += 1
                self.start_new_thread(self.handler_thread, candidate)
                self.active_count += 1
                candidate += 1

            if running > count:
                self.stop_count += running - count
                self.queue_cv.notify_all()

    def add_task(self, task):
        """Queue ``task`` and wake one waiting worker.

        A warning is sent to ``queue_logger`` when more tasks are queued
        than there are idle workers to take them.
        """
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
            backlog = len(self.queue)
            idle_workers = len(self.threads) - self.stop_count - self.active_count
            overflow = backlog - idle_workers
            if overflow > 0:
                self.queue_logger.warning("Task queue depth is %d", overflow)

    def shutdown(self, cancel_pending=True, timeout=5):
        """Stop every worker and optionally cancel the tasks still queued.

        Waits up to ``timeout`` seconds for the workers to exit and logs a
        warning if some are still registered after that.  Returns ``True``
        when ``cancel_pending`` is true (after calling ``cancel()`` on each
        queued task), ``False`` otherwise.
        """
        self.set_thread_count(0)
        deadline = time.time() + timeout
        workers = self.threads

        with self.lock:
            # Exiting workers remove themselves from the set and notify
            # thread_exit_cv; poll in short slices until the deadline.
            while workers:
                if time.time() >= deadline:
                    self.logger.warning("%d thread(s) still running", len(workers))
                    break
                self.thread_exit_cv.wait(0.1)

            if not cancel_pending:
                return False

            pending = self.queue
            if pending:
                self.logger.warning("Canceling %d pending task(s)", len(pending))
            while pending:
                pending.popleft().cancel()
            self.queue_cv.notify_all()
            return True
class Task:
    """Base class for the work of answering one request on a channel.

    ``service`` drives ``start``, ``execute`` and ``finish`` in that order;
    ``execute`` is not defined here and has to be supplied by a subclass.
    Response bytes are handed to the channel through ``channel.write_soon``.
    """

    # Set to True when the response carries (or implies) "Connection: close"
    # or when an OSError interrupted the task.
    close_on_finish = False
    status = "200 OK"
    # True once the header block has been handed to the channel.
    wrote_header = False
    start_time = 0
    # Body size announced for the response, None when it is unknown.
    content_length = None
    content_bytes_written = 0
    # The two flags below make sure each warning is logged only once.
    logged_write_excess = False
    logged_write_no_body = False
    # write() refuses to run until this is True.
    complete = False
    chunked_response = False
    logger = logger

    def __init__(self, channel, request):
        self.channel = channel
        self.request = request
        self.response_headers = []
        # Answer as HTTP/1.0 whenever the request version is not one of the
        # two versions handled by build_response_header().
        requested = request.version
        self.version = requested if requested in ("1.0", "1.1") else "1.0"

    def service(self):
        """Run the task from start to finish.

        An ``OSError`` marks the connection for closing and is re-raised
        only when ``channel.adj.log_socket_errors`` is true.
        """
        try:
            self.start()
            self.execute()
            self.finish()
        except OSError:
            self.close_on_finish = True
            if self.channel.adj.log_socket_errors:
                raise

    @property
    def has_body(self):
        """False for 1xx, 204 and 304 statuses, True for everything else."""
        return not self.status.startswith(("1", "204", "304"))

    def build_response_header(self):
        """Return the status line and headers as latin-1 encoded bytes.

        Header names are normalised to ``Title-Case``; Content-Length,
        Connection, Transfer-Encoding, Server/Via and Date are added where
        needed.  As side effects ``self.response_headers`` is replaced by
        the final list, and ``self.close_on_finish`` and
        ``self.chunked_response`` may be switched on.

        Raises ``AssertionError`` if ``self.version`` is neither "1.0" nor
        "1.1".
        """
        version = self.version
        # The Connection header sent by the client decides about keep-alive.
        connection = self.request.headers.get("CONNECTION", "").lower()

        response_headers = []
        content_length_header = None
        date_header = None
        server_header = None
        connection_close_header = None

        for raw_name, headerval in self.response_headers:
            headername = "-".join(part.capitalize() for part in raw_name.split("-"))

            if headername == "Content-Length":
                if not self.has_body:
                    # Drop Content-Length on responses without a body.
                    continue  # pragma: no cover
                content_length_header = headerval
            elif headername == "Date":
                date_header = headerval
            elif headername == "Server":
                server_header = headerval
            elif headername == "Connection":
                connection_close_header = headerval.lower()

            # Keep the header under its normalised name.
            response_headers.append((headername, headerval))

        if (
            content_length_header is None
            and self.content_length is not None
            and self.has_body
        ):
            content_length_header = str(self.content_length)
            response_headers.append(("Content-Length", content_length_header))

        def request_close():
            # Only add the header if the response did not bring its own
            # Connection header.
            if connection_close_header is None:
                response_headers.append(("Connection", "close"))
            self.close_on_finish = True

        if version == "1.0":
            # HTTP/1.0 keep-alive needs both the client's opt-in and a
            # known content length.
            if connection == "keep-alive" and content_length_header:
                response_headers.append(("Connection", "Keep-Alive"))
            else:
                request_close()
        elif version == "1.1":
            if connection == "close":
                request_close()

            if not content_length_header:
                # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
                # for any response with a status code of 1xx, 204 or 304.
                if self.has_body:
                    response_headers.append(("Transfer-Encoding", "chunked"))
                    self.chunked_response = True

                if not self.close_on_finish:
                    request_close()
            # Keep-alive is the HTTP/1.1 default, so no header is needed.
        else:
            raise AssertionError("neither HTTP/1.0 or HTTP/1.1")

        # Announce ourselves via Server when the response has none; when it
        # already has one (e.g. the server is used as a proxy) add Via.
        ident = self.channel.server.adj.ident
        if server_header:
            response_headers.append(("Via", ident or "waitress"))
        elif ident:
            response_headers.append(("Server", ident))

        if not date_header:
            response_headers.append(("Date", build_http_date(self.start_time)))

        self.response_headers = response_headers

        # Sort by name only: the sort is stable, so headers that share a
        # name keep their relative order (RFC 2616 section 4.2).
        lines = [f"HTTP/{self.version} {self.status}"]
        lines.extend(
            "%s: %s" % pair for pair in sorted(response_headers, key=lambda x: x[0])
        )
        return ("\r\n".join(lines) + "\r\n\r\n").encode("latin-1")

    def remove_content_length_header(self):
        """Drop every Content-Length entry from ``self.response_headers``."""
        self.response_headers = [
            (name, value)
            for name, value in self.response_headers
            if name.lower() != "content-length"
        ]

    def start(self):
        """Record the wall-clock time at which the task started."""
        self.start_time = time.time()

    def finish(self):
        """Make sure the headers went out and close a chunked body."""
        if not self.wrote_header:
            self.write(b"")
        if self.chunked_response:
            # Go straight to the channel: write() would wrap the
            # terminating chunk in yet another chunk.
            self.channel.write_soon(b"0\r\n\r\n")

    def write(self, data):
        """Send ``data`` as body bytes, emitting the headers first if needed.

        Data is chunk-encoded for chunked responses and truncated to the
        announced Content-Length otherwise.  For statuses without a body the
        data is discarded but still counted as written.

        Raises ``RuntimeError`` when ``self.complete`` is not set.
        """
        if not self.complete:
            raise RuntimeError("start_response was not called before body written")

        channel = self.channel
        if not self.wrote_header:
            channel.write_soon(self.build_response_header())
            self.wrote_header = True

        if not data:
            return

        if not self.has_body:
            # Pretend the bytes were written even though they are ignored,
            # because this response must not carry a body.
            self.content_bytes_written += len(data)
            if not self.logged_write_no_body:
                self.logger.warning(
                    "application-written content was ignored due to HTTP "
                    "response that may not contain a message-body: (%s)" % self.status
                )
                self.logged_write_no_body = True
            return

        towrite = data
        cl = self.content_length
        if self.chunked_response:
            # One chunk: upper-case hex size line, payload, trailing CRLF.
            size_line = format(len(data), "X").encode("latin-1")
            towrite = size_line + b"\r\n"
            towrite += data + b"\r\n"
        elif cl is not None:
            # Never send more than what Content-Length promised.
            towrite = data[: cl - self.content_bytes_written]
            self.content_bytes_written += len(towrite)
            if towrite != data and not self.logged_write_excess:
                self.logger.warning(
                    "application-written content exceeded the number of "
                    "bytes specified by Content-Length header (%s)" % cl
                )
                self.logged_write_excess = True

        if towrite:
            channel.write_soon(towrite)
class ErrorTask(Task):
    """Task that answers a request with the error attached to it."""

    # No start_response() call is involved here, so write() may be used
    # right away.
    complete = True

    def execute(self):
        """Write the response described by ``self.request.error``.

        ``error.to_response()`` supplies the status line, extra headers and
        body.  The connection is always marked for closing.
        """
        status, extra_headers, payload = self.request.error.to_response()
        self.status = status
        self.response_headers.extend(extra_headers)
        # The connection is closed after this response, so say so
        # explicitly instead of silently hanging up on the client.
        self.response_headers.append(("Connection", "close"))
        self.close_on_finish = True
        self.content_length = len(payload)
        self.write(payload)
class WSGITask(Task):
    """Task that obtains its response by calling the WSGI application."""

    # Cached WSGI environ; built on the first get_environment() call.
    environ = None

    def execute(self):
        """Call the WSGI application and stream what it returns.

        A ``start_response`` callable is handed to the application; it
        validates the status and headers and returns ``self.write``.  The
        returned iterable is written chunk by chunk, except for a
        ``ReadOnlyFileBasedBuffer`` with content, which is handed to the
        channel as a whole.  ``close()`` is called on the iterable when it
        has one, unless the channel took it over.
        """
        environ = self.get_environment()

        def has_line_break(text):
            return "\n" in text or "\r" in text

        def start_response(status, headers, exc_info=None):
            if self.complete and not exc_info:
                raise AssertionError(
                    "start_response called a second time without providing exc_info."
                )

            if exc_info:
                try:
                    if self.wrote_header:
                        # Headers are already out, nothing can be replaced
                        # any more: let the exception propagate to the
                        # callers of execute().
                        raise exc_info[1]
                    # As per the WSGI spec, headers set so far are cleared.
                    self.response_headers = []
                finally:
                    # Do not hold on to the exception information.
                    exc_info = None

            self.complete = True

            if status.__class__ is not str:
                raise AssertionError("status %s is not a string" % status)
            if has_line_break(status):
                raise ValueError(
                    "carriage return/line feed character present in status"
                )
            self.status = status

            # Validate every header before any of them is accepted.
            for name, value in headers:
                if name.__class__ is not str:
                    raise AssertionError(
                        f"Header name {name!r} is not a string in {(name, value)!r}"
                    )
                if value.__class__ is not str:
                    raise AssertionError(
                        f"Header value {value!r} is not a string in {(name, value)!r}"
                    )
                if has_line_break(value):
                    raise ValueError(
                        "carriage return/line feed character present in header value"
                    )
                if has_line_break(name):
                    raise ValueError(
                        "carriage return/line feed character present in header name"
                    )

                lowered = name.lower()
                if lowered == "content-length":
                    self.content_length = int(value)
                elif lowered in hop_by_hop:
                    raise AssertionError(
                        '%s is a "hop-by-hop" header; it cannot be used by '
                        "a WSGI application (see PEP 3333)" % name
                    )

            self.response_headers.extend(headers)

            # The application may use the returned callable to write body
            # data directly.
            return self.write

        app_iter = self.channel.server.application(environ, start_response)

        close_iter_here = True
        try:
            if app_iter.__class__ is ReadOnlyFileBasedBuffer:
                cl = self.content_length
                size = app_iter.prepare(cl)
                if size:
                    if cl != size:
                        # Announce the size reported by prepare() instead
                        # of whatever the application declared.
                        if cl is not None:
                            self.remove_content_length_header()
                        self.content_length = size
                    self.write(b"")  # generate headers
                    # If write_soon succeeds the channel takes over closing
                    # the underlying file (via its _flush_some or
                    # handle_close), so it must not be closed in the
                    # finally block below.
                    self.channel.write_soon(app_iter)
                    close_iter_here = False
                    return

            awaiting_first_chunk = True
            for chunk in app_iter:
                if awaiting_first_chunk:
                    awaiting_first_chunk = False
                    first_chunk_len = len(chunk)
                    # start_response may not have been called before the
                    # first iteration (PEP 3333), so content_length is
                    # inspected only now.  A one-element iterable lets us
                    # derive Content-Length from its single chunk.
                    if (
                        self.content_length is None
                        and hasattr(app_iter, "__len__")
                        and len(app_iter) == 1
                    ):
                        self.content_length = first_chunk_len

                # Headers go out with the first non-empty chunk (PEP 3333).
                if chunk:
                    self.write(chunk)

            cl = self.content_length
            if cl is not None and self.content_bytes_written != cl:
                # Fewer bytes than promised: close the connection so the
                # client does not keep waiting for the remainder.
                self.close_on_finish = True
                if self.request.command != "HEAD":
                    self.logger.warning(
                        "application returned too few bytes (%s) "
                        "for specified Content-Length (%s) via app_iter"
                        % (self.content_bytes_written, cl),
                    )
        finally:
            if close_iter_here and hasattr(app_iter, "close"):
                app_iter.close()

    def get_environment(self):
        """Return the WSGI environ for this request, building it once.

        The result is cached on ``self.environ`` and returned as-is on
        later calls.
        """
        if self.environ is not None:
            # Return the cached copy.
            return self.environ

        request = self.request
        channel = self.channel
        server = channel.server
        url_prefix = server.adj.url_prefix
        path = request.path

        if path.startswith("/"):
            # Collapse any run of leading slashes into a single one.
            path = "/" + path.lstrip("/")

        if url_prefix:
            # NB: url_prefix is guaranteed by the configuration machinery to
            # be either the empty string or a string that starts with a
            # single slash and ends without any slashes.
            if path == url_prefix:
                # The path is exactly the prefix: SCRIPT_NAME is the prefix
                # and PATH_INFO is empty.
                path = ""
            elif path.startswith(url_prefix + "/"):
                # The path lives below the prefix: PATH_INFO is what follows
                # the prefix, starting at the slash.
                path = path[len(url_prefix) :]

        remote_addr = channel.addr[0]
        environ = {
            "REMOTE_ADDR": remote_addr,
            # No reverse DNS lookup is performed; REMOTE_HOST simply repeats
            # REMOTE_ADDR.
            "REMOTE_HOST": remote_addr,
            # Stringified as-is, whatever the second element of addr is.
            "REMOTE_PORT": str(channel.addr[1]),
            "REQUEST_METHOD": request.command.upper(),
            "SERVER_PORT": str(server.effective_port),
            "SERVER_NAME": server.server_name,
            "SERVER_SOFTWARE": server.adj.ident,
            "SERVER_PROTOCOL": "HTTP/%s" % self.version,
            "SCRIPT_NAME": url_prefix,
            "PATH_INFO": path,
            "REQUEST_URI": request.request_uri,
            "QUERY_STRING": request.query,
            "wsgi.url_scheme": request.url_scheme,
            # the following environment variables are required by the WSGI spec
            "wsgi.version": (1, 0),
            # apps should use the logging module
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "wsgi.input": request.get_body_stream(),
            "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
            "wsgi.input_terminated": True,  # wsgi.input is EOF terminated
        }

        for key, value in dict(request.headers).items():
            # Keys listed in rename_headers keep their name, all others get
            # the HTTP_ prefix; keys already present in environ win.
            environ_key = rename_headers.get(key, None)
            if environ_key is None:
                environ_key = "HTTP_" + key
            environ.setdefault(environ_key, value.strip())

        # Insert a callable into the environment that allows the application
        # to check if the client disconnected. Only works with
        # channel_request_lookahead larger than 0.
        environ["waitress.client_disconnected"] = self.channel.check_client_disconnected

        # Cache the environ for this request.
        self.environ = environ
        return environ
|