Use forking for internal server when available

This commit is contained in:
Unrud 2018-08-18 12:56:41 +02:00
parent ddd99a5329
commit 30a9ecc06b
2 changed files with 60 additions and 27 deletions

View File

@ -25,6 +25,7 @@ http://docs.python.org/library/logging.config.html
import contextlib
import io
import logging
import multiprocessing
import os
import sys
import threading
@ -35,7 +36,7 @@ except ImportError:
journal = None
LOGGER_NAME = "radicale"
LOGGER_FORMAT = "[%(processName)s/%(threadName)s] %(levelname)s: %(message)s"
LOGGER_FORMAT = "[%(ident)s] %(levelname)s: %(message)s"
root_logger = logging.getLogger()
logger = logging.getLogger(LOGGER_NAME)
@ -50,6 +51,27 @@ class RemoveTracebackFilter(logging.Filter):
removeTracebackFilter = RemoveTracebackFilter()
class IdentLogRecordFactory:
"""LogRecordFactory that adds ``ident`` attribute."""
def __init__(self, upstream_factory):
self.upstream_factory = upstream_factory
self.main_pid = os.getpid()
self.main_thread_name = threading.current_thread().name
def __call__(self, *args, **kwargs):
record = self.upstream_factory(*args, **kwargs)
pid = os.getpid()
thread_name = threading.current_thread().name
ident = "%x" % self.main_pid
if pid != self.main_pid:
ident += "%+x" % (pid - self.main_pid)
if thread_name != self.main_thread_name:
ident += "/%s" % thread_name
record.ident = ident
return record
class ThreadStreamsHandler(logging.Handler):
terminator = "\n"
@ -60,6 +82,9 @@ class ThreadStreamsHandler(logging.Handler):
self.fallback_stream = fallback_stream
self.fallback_handler = fallback_handler
def createLock(self):
self.lock = multiprocessing.Lock()
def setFormatter(self, form):
super().setFormatter(form)
self.fallback_handler.setFormatter(form)
@ -116,6 +141,8 @@ def setup():
handler = ThreadStreamsHandler(sys.stderr, get_default_handler())
logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler])
register_stream = handler.register_stream
log_record_factory = IdentLogRecordFactory(logging.getLogRecordFactory())
logging.setLogRecordFactory(log_record_factory)
set_level(logging.DEBUG)

View File

@ -22,6 +22,7 @@ Radicale WSGI server.
"""
import contextlib
import multiprocessing
import os
import select
import signal
@ -29,16 +30,20 @@ import socket
import socketserver
import ssl
import sys
import threading
import wsgiref.simple_server
from urllib.parse import unquote
from radicale import Application
from radicale.log import logger
if hasattr(socketserver, "ForkingMixIn"):
ParallelizationMixIn = socketserver.ForkingMixIn
else:
ParallelizationMixIn = socketserver.ThreadingMixIn
class HTTPServer(wsgiref.simple_server.WSGIServer):
"""HTTP server."""
class ParallelHTTPServer(ParallelizationMixIn,
wsgiref.simple_server.WSGIServer):
# These class attributes must be set before creating instance
client_timeout = None
@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
if self.max_connections:
self.connections_guard = threading.BoundedSemaphore(
self.connections_guard = multiprocessing.BoundedSemaphore(
self.max_connections)
else:
# use dummy context manager
@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
def get_request(self):
# Set timeout for client
_socket, address = super().get_request()
socket_, address = super().get_request()
if self.client_timeout:
_socket.settimeout(self.client_timeout)
return _socket, address
socket_.settimeout(self.client_timeout)
return socket_, address
def finish_request(self, request, client_address):
with self.connections_guard:
return super().finish_request(request, client_address)
def handle_error(self, request, client_address):
if issubclass(sys.exc_info()[0], socket.timeout):
@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
sys.exc_info()[1], exc_info=True)
class HTTPSServer(HTTPServer):
"""HTTPS server."""
class ParallelHTTPSServer(ParallelHTTPServer):
# These class attributes must be set before creating instance
certificate = None
@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer):
ciphers = None
certificate_authority = None
def __init__(self, address, handler):
def __init__(self, address, handler, bind_and_activate=True):
"""Create server by wrapping HTTP socket in an SSL socket."""
super().__init__(address, handler, bind_and_activate=False)
# Do not bind and activate, as we change the socket
super().__init__(address, handler, False)
self.socket = ssl.wrap_socket(
self.socket, self.key, self.certificate, server_side=True,
@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer):
ssl_version=self.protocol, ciphers=self.ciphers,
do_handshake_on_connect=False)
self.server_bind()
self.server_activate()
if bind_and_activate:
try:
self.server_bind()
self.server_activate()
except BaseException:
self.server_close()
raise
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
def process_request_thread(self, request, client_address):
def finish_request(self, request, client_address):
try:
try:
request.do_handshake()
@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
finally:
self.shutdown_request(request)
return
with self.connections_guard:
return super().process_request_thread(request, client_address)
return super().finish_request(request, client_address)
class ServerHandler(wsgiref.simple_server.ServerHandler):
@ -197,7 +203,7 @@ def serve(configuration):
# Create collection servers
servers = {}
if configuration.getboolean("server", "ssl"):
server_class = ThreadedHTTPSServer
server_class = ParallelHTTPSServer
server_class.certificate = configuration.get("server", "certificate")
server_class.key = configuration.get("server", "key")
server_class.certificate_authority = configuration.get(
@ -216,7 +222,7 @@ def serve(configuration):
raise RuntimeError("Failed to read SSL %s %r: %s" %
(name, filename, e)) from e
else:
server_class = ThreadedHTTPServer
server_class = ParallelHTTPServer
server_class.client_timeout = configuration.getint("server", "timeout")
server_class.max_connections = configuration.getint(
"server", "max_connections")