Merge pull request #428 from Unrud/patch-22

Add timeout to connections, limit size of request body and limit number of parallel connections
This commit is contained in:
Guillaume Ayoub
2016-07-14 02:06:24 +02:00
committed by GitHub
4 changed files with 65 additions and 18 deletions

View File

@@ -29,9 +29,11 @@ should have been included in this package.
import os
import pprint
import base64
import contextlib
import socket
import socketserver
import ssl
import threading
import wsgiref.simple_server
import re
import zlib
@@ -54,6 +56,11 @@ WELL_KNOWN_RE = re.compile(r"/\.well-known/(carddav|caldav)/?$")
class HTTPServer(wsgiref.simple_server.WSGIServer):
"""HTTP server."""
# These class attributes must be set before creating instance
client_timeout = None
max_connections = None
def __init__(self, address, handler, bind_and_activate=True):
"""Create server."""
ipv6 = ":" in address[0]
@@ -72,6 +79,20 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
self.server_bind()
self.server_activate()
if self.max_connections:
self.connections_guard = threading.BoundedSemaphore(
self.max_connections)
else:
# use dummy context manager
self.connections_guard = contextlib.suppress()
def get_request(self):
# Set timeout for client
_socket, address = super().get_request()
if self.client_timeout:
_socket.settimeout(self.client_timeout)
return _socket, address
class HTTPSServer(HTTPServer):
"""HTTPS server."""
@@ -95,11 +116,15 @@ class HTTPSServer(HTTPServer):
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
pass
def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
pass
def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
@@ -218,6 +243,15 @@ class Application:
def __call__(self, environ, start_response):
"""Manage a request."""
def response(status, headers={}, answer=None):
# Start response
status = "%i %s" % (status,
client.responses.get(status, "Unknown"))
self.logger.debug("Answer status: %s" % status)
start_response(status, list(headers.items()))
# Return response content
return [answer] if answer else []
self.logger.info("%s request at %s received" % (
environ["REQUEST_METHOD"], environ["PATH_INFO"]))
headers = pprint.pformat(self.headers_log(environ))
@@ -234,9 +268,7 @@ class Application:
# Request path not starting with base_prefix, not allowed
self.logger.debug(
"Path not starting with prefix: %s", environ["PATH_INFO"])
status, headers, _ = NOT_ALLOWED
start_response(status, list(headers.items()))
return []
return response(*NOT_ALLOWED)
# Sanitize request URI
environ["PATH_INFO"] = storage.sanitize_path(
@@ -275,10 +307,7 @@ class Application:
status = client.SEE_OTHER
self.logger.info("/.well-known/ redirection to: %s" % redirect)
headers = {"Location": redirect}
status = "%i %s" % (
status, client.responses.get(status, "Unknown"))
start_response(status, list(headers.items()))
return []
return response(status, headers)
is_authenticated = self.is_authenticated(user, password)
is_valid_user = is_authenticated or not user
@@ -286,8 +315,17 @@ class Application:
# Get content
content_length = int(environ.get("CONTENT_LENGTH") or 0)
if content_length:
content = self.decode(
environ["wsgi.input"].read(content_length), environ)
max_content_length = self.configuration.getint(
"server", "max_content_length")
if max_content_length and content_length > max_content_length:
self.logger.debug(
"Request body too large: %d", content_length)
return response(client.REQUEST_ENTITY_TOO_LARGE)
try:
content = self.decode(
environ["wsgi.input"].read(content_length), environ)
except socket.timeout:
return response(client.REQUEST_TIMEOUT)
self.logger.debug("Request content:\n%s" % content)
else:
content = None
@@ -345,13 +383,7 @@ class Application:
for key in self.configuration.options("headers"):
headers[key] = self.configuration.get("headers", key)
# Start response
status = "%i %s" % (status, client.responses.get(status, "Unknown"))
self.logger.debug("Answer status: %s" % status)
start_response(status, list(headers.items()))
# Return response content
return [answer] if answer else []
return response(status, headers, answer)
# All these functions must have the same parameters, some are useless
# pylint: disable=W0612,W0613,R0201