Merge pull request #428 from Unrud/patch-22
Add timeout to connections, limit size of request body and limit number of parallel connections
This commit is contained in:
commit
ef63865e31
9
config
9
config
@ -24,6 +24,15 @@
|
|||||||
# File storing the PID in daemon mode
|
# File storing the PID in daemon mode
|
||||||
#pid =
|
#pid =
|
||||||
|
|
||||||
|
# Max parallel connections
|
||||||
|
#max_connections = 20
|
||||||
|
|
||||||
|
# Max size of request body (bytes)
|
||||||
|
#max_content_length = 10000000
|
||||||
|
|
||||||
|
# Socket timeout (seconds)
|
||||||
|
#timeout = 10
|
||||||
|
|
||||||
# SSL flag, enable HTTPS protocol
|
# SSL flag, enable HTTPS protocol
|
||||||
#ssl = False
|
#ssl = False
|
||||||
|
|
||||||
|
@ -29,9 +29,11 @@ should have been included in this package.
|
|||||||
import os
|
import os
|
||||||
import pprint
|
import pprint
|
||||||
import base64
|
import base64
|
||||||
|
import contextlib
|
||||||
import socket
|
import socket
|
||||||
import socketserver
|
import socketserver
|
||||||
import ssl
|
import ssl
|
||||||
|
import threading
|
||||||
import wsgiref.simple_server
|
import wsgiref.simple_server
|
||||||
import re
|
import re
|
||||||
import zlib
|
import zlib
|
||||||
@ -54,6 +56,11 @@ WELL_KNOWN_RE = re.compile(r"/\.well-known/(carddav|caldav)/?$")
|
|||||||
|
|
||||||
class HTTPServer(wsgiref.simple_server.WSGIServer):
|
class HTTPServer(wsgiref.simple_server.WSGIServer):
|
||||||
"""HTTP server."""
|
"""HTTP server."""
|
||||||
|
|
||||||
|
# These class attributes must be set before creating instance
|
||||||
|
client_timeout = None
|
||||||
|
max_connections = None
|
||||||
|
|
||||||
def __init__(self, address, handler, bind_and_activate=True):
|
def __init__(self, address, handler, bind_and_activate=True):
|
||||||
"""Create server."""
|
"""Create server."""
|
||||||
ipv6 = ":" in address[0]
|
ipv6 = ":" in address[0]
|
||||||
@ -72,6 +79,20 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|||||||
self.server_bind()
|
self.server_bind()
|
||||||
self.server_activate()
|
self.server_activate()
|
||||||
|
|
||||||
|
if self.max_connections:
|
||||||
|
self.connections_guard = threading.BoundedSemaphore(
|
||||||
|
self.max_connections)
|
||||||
|
else:
|
||||||
|
# use dummy context manager
|
||||||
|
self.connections_guard = contextlib.suppress()
|
||||||
|
|
||||||
|
def get_request(self):
|
||||||
|
# Set timeout for client
|
||||||
|
_socket, address = super().get_request()
|
||||||
|
if self.client_timeout:
|
||||||
|
_socket.settimeout(self.client_timeout)
|
||||||
|
return _socket, address
|
||||||
|
|
||||||
|
|
||||||
class HTTPSServer(HTTPServer):
|
class HTTPSServer(HTTPServer):
|
||||||
"""HTTPS server."""
|
"""HTTPS server."""
|
||||||
@ -95,11 +116,15 @@ class HTTPSServer(HTTPServer):
|
|||||||
|
|
||||||
|
|
||||||
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
|
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
|
||||||
pass
|
def process_request_thread(self, request, client_address):
|
||||||
|
with self.connections_guard:
|
||||||
|
return super().process_request_thread(request, client_address)
|
||||||
|
|
||||||
|
|
||||||
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
||||||
pass
|
def process_request_thread(self, request, client_address):
|
||||||
|
with self.connections_guard:
|
||||||
|
return super().process_request_thread(request, client_address)
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
||||||
@ -218,6 +243,15 @@ class Application:
|
|||||||
|
|
||||||
def __call__(self, environ, start_response):
|
def __call__(self, environ, start_response):
|
||||||
"""Manage a request."""
|
"""Manage a request."""
|
||||||
|
def response(status, headers={}, answer=None):
|
||||||
|
# Start response
|
||||||
|
status = "%i %s" % (status,
|
||||||
|
client.responses.get(status, "Unknown"))
|
||||||
|
self.logger.debug("Answer status: %s" % status)
|
||||||
|
start_response(status, list(headers.items()))
|
||||||
|
# Return response content
|
||||||
|
return [answer] if answer else []
|
||||||
|
|
||||||
self.logger.info("%s request at %s received" % (
|
self.logger.info("%s request at %s received" % (
|
||||||
environ["REQUEST_METHOD"], environ["PATH_INFO"]))
|
environ["REQUEST_METHOD"], environ["PATH_INFO"]))
|
||||||
headers = pprint.pformat(self.headers_log(environ))
|
headers = pprint.pformat(self.headers_log(environ))
|
||||||
@ -234,9 +268,7 @@ class Application:
|
|||||||
# Request path not starting with base_prefix, not allowed
|
# Request path not starting with base_prefix, not allowed
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
"Path not starting with prefix: %s", environ["PATH_INFO"])
|
"Path not starting with prefix: %s", environ["PATH_INFO"])
|
||||||
status, headers, _ = NOT_ALLOWED
|
return response(*NOT_ALLOWED)
|
||||||
start_response(status, list(headers.items()))
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Sanitize request URI
|
# Sanitize request URI
|
||||||
environ["PATH_INFO"] = storage.sanitize_path(
|
environ["PATH_INFO"] = storage.sanitize_path(
|
||||||
@ -275,10 +307,7 @@ class Application:
|
|||||||
status = client.SEE_OTHER
|
status = client.SEE_OTHER
|
||||||
self.logger.info("/.well-known/ redirection to: %s" % redirect)
|
self.logger.info("/.well-known/ redirection to: %s" % redirect)
|
||||||
headers = {"Location": redirect}
|
headers = {"Location": redirect}
|
||||||
status = "%i %s" % (
|
return response(status, headers)
|
||||||
status, client.responses.get(status, "Unknown"))
|
|
||||||
start_response(status, list(headers.items()))
|
|
||||||
return []
|
|
||||||
|
|
||||||
is_authenticated = self.is_authenticated(user, password)
|
is_authenticated = self.is_authenticated(user, password)
|
||||||
is_valid_user = is_authenticated or not user
|
is_valid_user = is_authenticated or not user
|
||||||
@ -286,8 +315,17 @@ class Application:
|
|||||||
# Get content
|
# Get content
|
||||||
content_length = int(environ.get("CONTENT_LENGTH") or 0)
|
content_length = int(environ.get("CONTENT_LENGTH") or 0)
|
||||||
if content_length:
|
if content_length:
|
||||||
content = self.decode(
|
max_content_length = self.configuration.getint(
|
||||||
environ["wsgi.input"].read(content_length), environ)
|
"server", "max_content_length")
|
||||||
|
if max_content_length and content_length > max_content_length:
|
||||||
|
self.logger.debug(
|
||||||
|
"Request body too large: %d", content_length)
|
||||||
|
return response(client.REQUEST_ENTITY_TOO_LARGE)
|
||||||
|
try:
|
||||||
|
content = self.decode(
|
||||||
|
environ["wsgi.input"].read(content_length), environ)
|
||||||
|
except socket.timeout:
|
||||||
|
return response(client.REQUEST_TIMEOUT)
|
||||||
self.logger.debug("Request content:\n%s" % content)
|
self.logger.debug("Request content:\n%s" % content)
|
||||||
else:
|
else:
|
||||||
content = None
|
content = None
|
||||||
@ -345,13 +383,7 @@ class Application:
|
|||||||
for key in self.configuration.options("headers"):
|
for key in self.configuration.options("headers"):
|
||||||
headers[key] = self.configuration.get("headers", key)
|
headers[key] = self.configuration.get("headers", key)
|
||||||
|
|
||||||
# Start response
|
return response(status, headers, answer)
|
||||||
status = "%i %s" % (status, client.responses.get(status, "Unknown"))
|
|
||||||
self.logger.debug("Answer status: %s" % status)
|
|
||||||
start_response(status, list(headers.items()))
|
|
||||||
|
|
||||||
# Return response content
|
|
||||||
return [answer] if answer else []
|
|
||||||
|
|
||||||
# All these functions must have the same parameters, some are useless
|
# All these functions must have the same parameters, some are useless
|
||||||
# pylint: disable=W0612,W0613,R0201
|
# pylint: disable=W0612,W0613,R0201
|
||||||
|
@ -175,6 +175,9 @@ def serve(configuration, logger):
|
|||||||
name, filename, exception))
|
name, filename, exception))
|
||||||
else:
|
else:
|
||||||
server_class = ThreadedHTTPServer
|
server_class = ThreadedHTTPServer
|
||||||
|
server_class.client_timeout = configuration.getint("server", "timeout")
|
||||||
|
server_class.max_connections = configuration.getint("server",
|
||||||
|
"max_connections")
|
||||||
|
|
||||||
if not configuration.getboolean("server", "dns_lookup"):
|
if not configuration.getboolean("server", "dns_lookup"):
|
||||||
RequestHandler.address_string = lambda self: self.client_address[0]
|
RequestHandler.address_string = lambda self: self.client_address[0]
|
||||||
|
@ -32,6 +32,9 @@ INITIAL_CONFIG = {
|
|||||||
"hosts": "0.0.0.0:5232",
|
"hosts": "0.0.0.0:5232",
|
||||||
"daemon": "False",
|
"daemon": "False",
|
||||||
"pid": "",
|
"pid": "",
|
||||||
|
"max_connections": "20",
|
||||||
|
"max_content_length": "10000000",
|
||||||
|
"timeout": "10",
|
||||||
"ssl": "False",
|
"ssl": "False",
|
||||||
"certificate": "/etc/apache2/ssl/server.crt",
|
"certificate": "/etc/apache2/ssl/server.crt",
|
||||||
"key": "/etc/apache2/ssl/server.key",
|
"key": "/etc/apache2/ssl/server.key",
|
||||||
|
Loading…
Reference in New Issue
Block a user