Limit number of parallel connections

This commit is contained in:
Unrud 2016-06-10 14:36:44 +02:00
parent e438d9fd4b
commit 83ea9da2b4
4 changed files with 22 additions and 2 deletions

3
config
View File

@ -24,6 +24,9 @@
# File storing the PID in daemon mode # File storing the PID in daemon mode
#pid = #pid =
# Max parallel connections
#max_connections = 20
# Max size of request body (bytes) # Max size of request body (bytes)
#max_content_length = 10000000 #max_content_length = 10000000

View File

@ -29,9 +29,11 @@ should have been included in this package.
import os import os
import pprint import pprint
import base64 import base64
import contextlib
import socket import socket
import socketserver import socketserver
import ssl import ssl
import threading
import wsgiref.simple_server import wsgiref.simple_server
import re import re
import zlib import zlib
@ -57,6 +59,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
# These class attributes must be set before creating instance # These class attributes must be set before creating instance
client_timeout = None client_timeout = None
max_connections = None
def __init__(self, address, handler, bind_and_activate=True): def __init__(self, address, handler, bind_and_activate=True):
"""Create server.""" """Create server."""
@ -76,6 +79,13 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
self.server_bind() self.server_bind()
self.server_activate() self.server_activate()
if self.max_connections:
self.connections_guard = threading.BoundedSemaphore(
self.max_connections)
else:
# use dummy context manager
self.connections_guard = contextlib.suppress()
def get_request(self): def get_request(self):
# Set timeout for client # Set timeout for client
_socket, address = super().get_request() _socket, address = super().get_request()
@ -106,11 +116,15 @@ class HTTPSServer(HTTPServer):
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer): class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
pass def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
pass def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):

View File

@ -171,6 +171,8 @@ def run():
else: else:
server_class = ThreadedHTTPServer server_class = ThreadedHTTPServer
server_class.client_timeout = configuration.getint("server", "timeout") server_class.client_timeout = configuration.getint("server", "timeout")
server_class.max_connections = configuration.getint("server",
"max_connections")
if not configuration.getboolean("server", "dns_lookup"): if not configuration.getboolean("server", "dns_lookup"):
RequestHandler.address_string = lambda self: self.client_address[0] RequestHandler.address_string = lambda self: self.client_address[0]

View File

@ -34,6 +34,7 @@ INITIAL_CONFIG = {
"hosts": "0.0.0.0:5232", "hosts": "0.0.0.0:5232",
"daemon": "False", "daemon": "False",
"pid": "", "pid": "",
"max_connections": "20",
"max_content_length": "10000000", "max_content_length": "10000000",
"timeout": "10", "timeout": "10",
"ssl": "False", "ssl": "False",