Merge pull request #428 from Unrud/patch-22

Add timeout to connections, limit size of request body and limit number of parallel connections
This commit is contained in:
Guillaume Ayoub 2016-07-14 02:06:24 +02:00 committed by GitHub
commit ef63865e31
4 changed files with 65 additions and 18 deletions

9
config
View File

@ -24,6 +24,15 @@
# File storing the PID in daemon mode # File storing the PID in daemon mode
#pid = #pid =
# Max parallel connections
#max_connections = 20
# Max size of request body (bytes)
#max_content_length = 10000000
# Socket timeout (seconds)
#timeout = 10
# SSL flag, enable HTTPS protocol # SSL flag, enable HTTPS protocol
#ssl = False #ssl = False

View File

@ -29,9 +29,11 @@ should have been included in this package.
import os import os
import pprint import pprint
import base64 import base64
import contextlib
import socket import socket
import socketserver import socketserver
import ssl import ssl
import threading
import wsgiref.simple_server import wsgiref.simple_server
import re import re
import zlib import zlib
@ -54,6 +56,11 @@ WELL_KNOWN_RE = re.compile(r"/\.well-known/(carddav|caldav)/?$")
class HTTPServer(wsgiref.simple_server.WSGIServer): class HTTPServer(wsgiref.simple_server.WSGIServer):
"""HTTP server.""" """HTTP server."""
# These class attributes must be set before creating instance
client_timeout = None
max_connections = None
def __init__(self, address, handler, bind_and_activate=True): def __init__(self, address, handler, bind_and_activate=True):
"""Create server.""" """Create server."""
ipv6 = ":" in address[0] ipv6 = ":" in address[0]
@ -72,6 +79,20 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
self.server_bind() self.server_bind()
self.server_activate() self.server_activate()
if self.max_connections:
self.connections_guard = threading.BoundedSemaphore(
self.max_connections)
else:
# use dummy context manager
self.connections_guard = contextlib.suppress()
def get_request(self):
# Set timeout for client
_socket, address = super().get_request()
if self.client_timeout:
_socket.settimeout(self.client_timeout)
return _socket, address
class HTTPSServer(HTTPServer): class HTTPSServer(HTTPServer):
"""HTTPS server.""" """HTTPS server."""
@ -95,11 +116,15 @@ class HTTPSServer(HTTPServer):
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer): class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
pass def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
pass def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
@ -218,6 +243,15 @@ class Application:
def __call__(self, environ, start_response): def __call__(self, environ, start_response):
"""Manage a request.""" """Manage a request."""
def response(status, headers={}, answer=None):
# Start response
status = "%i %s" % (status,
client.responses.get(status, "Unknown"))
self.logger.debug("Answer status: %s" % status)
start_response(status, list(headers.items()))
# Return response content
return [answer] if answer else []
self.logger.info("%s request at %s received" % ( self.logger.info("%s request at %s received" % (
environ["REQUEST_METHOD"], environ["PATH_INFO"])) environ["REQUEST_METHOD"], environ["PATH_INFO"]))
headers = pprint.pformat(self.headers_log(environ)) headers = pprint.pformat(self.headers_log(environ))
@ -234,9 +268,7 @@ class Application:
# Request path not starting with base_prefix, not allowed # Request path not starting with base_prefix, not allowed
self.logger.debug( self.logger.debug(
"Path not starting with prefix: %s", environ["PATH_INFO"]) "Path not starting with prefix: %s", environ["PATH_INFO"])
status, headers, _ = NOT_ALLOWED return response(*NOT_ALLOWED)
start_response(status, list(headers.items()))
return []
# Sanitize request URI # Sanitize request URI
environ["PATH_INFO"] = storage.sanitize_path( environ["PATH_INFO"] = storage.sanitize_path(
@ -275,10 +307,7 @@ class Application:
status = client.SEE_OTHER status = client.SEE_OTHER
self.logger.info("/.well-known/ redirection to: %s" % redirect) self.logger.info("/.well-known/ redirection to: %s" % redirect)
headers = {"Location": redirect} headers = {"Location": redirect}
status = "%i %s" % ( return response(status, headers)
status, client.responses.get(status, "Unknown"))
start_response(status, list(headers.items()))
return []
is_authenticated = self.is_authenticated(user, password) is_authenticated = self.is_authenticated(user, password)
is_valid_user = is_authenticated or not user is_valid_user = is_authenticated or not user
@ -286,8 +315,17 @@ class Application:
# Get content # Get content
content_length = int(environ.get("CONTENT_LENGTH") or 0) content_length = int(environ.get("CONTENT_LENGTH") or 0)
if content_length: if content_length:
content = self.decode( max_content_length = self.configuration.getint(
environ["wsgi.input"].read(content_length), environ) "server", "max_content_length")
if max_content_length and content_length > max_content_length:
self.logger.debug(
"Request body too large: %d", content_length)
return response(client.REQUEST_ENTITY_TOO_LARGE)
try:
content = self.decode(
environ["wsgi.input"].read(content_length), environ)
except socket.timeout:
return response(client.REQUEST_TIMEOUT)
self.logger.debug("Request content:\n%s" % content) self.logger.debug("Request content:\n%s" % content)
else: else:
content = None content = None
@ -345,13 +383,7 @@ class Application:
for key in self.configuration.options("headers"): for key in self.configuration.options("headers"):
headers[key] = self.configuration.get("headers", key) headers[key] = self.configuration.get("headers", key)
# Start response return response(status, headers, answer)
status = "%i %s" % (status, client.responses.get(status, "Unknown"))
self.logger.debug("Answer status: %s" % status)
start_response(status, list(headers.items()))
# Return response content
return [answer] if answer else []
# All these functions must have the same parameters, some are useless # All these functions must have the same parameters, some are useless
# pylint: disable=W0612,W0613,R0201 # pylint: disable=W0612,W0613,R0201

View File

@ -175,6 +175,9 @@ def serve(configuration, logger):
name, filename, exception)) name, filename, exception))
else: else:
server_class = ThreadedHTTPServer server_class = ThreadedHTTPServer
server_class.client_timeout = configuration.getint("server", "timeout")
server_class.max_connections = configuration.getint("server",
"max_connections")
if not configuration.getboolean("server", "dns_lookup"): if not configuration.getboolean("server", "dns_lookup"):
RequestHandler.address_string = lambda self: self.client_address[0] RequestHandler.address_string = lambda self: self.client_address[0]

View File

@ -32,6 +32,9 @@ INITIAL_CONFIG = {
"hosts": "0.0.0.0:5232", "hosts": "0.0.0.0:5232",
"daemon": "False", "daemon": "False",
"pid": "", "pid": "",
"max_connections": "20",
"max_content_length": "10000000",
"timeout": "10",
"ssl": "False", "ssl": "False",
"certificate": "/etc/apache2/ssl/server.crt", "certificate": "/etc/apache2/ssl/server.crt",
"key": "/etc/apache2/ssl/server.key", "key": "/etc/apache2/ssl/server.key",