diff --git a/radicale/log.py b/radicale/log.py index 0304131..e3df39d 100644 --- a/radicale/log.py +++ b/radicale/log.py @@ -25,6 +25,7 @@ http://docs.python.org/library/logging.config.html import contextlib import io import logging +import multiprocessing import os import sys import threading @@ -35,7 +36,7 @@ except ImportError: journal = None LOGGER_NAME = "radicale" -LOGGER_FORMAT = "[%(processName)s/%(threadName)s] %(levelname)s: %(message)s" +LOGGER_FORMAT = "[%(ident)s] %(levelname)s: %(message)s" root_logger = logging.getLogger() logger = logging.getLogger(LOGGER_NAME) @@ -50,6 +51,27 @@ class RemoveTracebackFilter(logging.Filter): removeTracebackFilter = RemoveTracebackFilter() +class IdentLogRecordFactory: + """LogRecordFactory that adds ``ident`` attribute.""" + + def __init__(self, upstream_factory): + self.upstream_factory = upstream_factory + self.main_pid = os.getpid() + self.main_thread_name = threading.current_thread().name + + def __call__(self, *args, **kwargs): + record = self.upstream_factory(*args, **kwargs) + pid = os.getpid() + thread_name = threading.current_thread().name + ident = "%x" % self.main_pid + if pid != self.main_pid: + ident += "%+x" % (pid - self.main_pid) + if thread_name != self.main_thread_name: + ident += "/%s" % thread_name + record.ident = ident + return record + + class ThreadStreamsHandler(logging.Handler): terminator = "\n" @@ -60,6 +82,9 @@ class ThreadStreamsHandler(logging.Handler): self.fallback_stream = fallback_stream self.fallback_handler = fallback_handler + def createLock(self): + self.lock = multiprocessing.Lock() + def setFormatter(self, form): super().setFormatter(form) self.fallback_handler.setFormatter(form) @@ -116,6 +141,8 @@ def setup(): handler = ThreadStreamsHandler(sys.stderr, get_default_handler()) logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler]) register_stream = handler.register_stream + log_record_factory = IdentLogRecordFactory(logging.getLogRecordFactory()) + logging.setLogRecordFactory(log_record_factory) set_level(logging.DEBUG) diff --git a/radicale/server.py b/radicale/server.py index 3ecd90d..700b545 100644 --- a/radicale/server.py +++ b/radicale/server.py @@ -22,6 +22,7 @@ Radicale WSGI server. """ import contextlib +import multiprocessing import os import select import signal @@ -29,16 +30,20 @@ import socket import socketserver import ssl import sys -import threading import wsgiref.simple_server from urllib.parse import unquote from radicale import Application from radicale.log import logger +if hasattr(socketserver, "ForkingMixIn"): + ParallelizationMixIn = socketserver.ForkingMixIn +else: + ParallelizationMixIn = socketserver.ThreadingMixIn -class HTTPServer(wsgiref.simple_server.WSGIServer): - """HTTP server.""" + +class ParallelHTTPServer(ParallelizationMixIn, + wsgiref.simple_server.WSGIServer): # These class attributes must be set before creating instance client_timeout = None @@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer): self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) if self.max_connections: - self.connections_guard = threading.BoundedSemaphore( + self.connections_guard = multiprocessing.BoundedSemaphore( self.max_connections) else: # use dummy context manager @@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer): def get_request(self): # Set timeout for client - _socket, address = super().get_request() + socket_, address = super().get_request() if self.client_timeout: - _socket.settimeout(self.client_timeout) - return _socket, address + socket_.settimeout(self.client_timeout) + return socket_, address + + def finish_request(self, request, client_address): + with self.connections_guard: + return super().finish_request(request, client_address) def handle_error(self, request, client_address): if issubclass(sys.exc_info()[0], socket.timeout): @@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer): sys.exc_info()[1], exc_info=True) -class HTTPSServer(HTTPServer): - """HTTPS server.""" +class ParallelHTTPSServer(ParallelHTTPServer): # These class attributes must be set before creating instance certificate = None @@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer): ciphers = None certificate_authority = None - def __init__(self, address, handler): + def __init__(self, address, handler, bind_and_activate=True): """Create server by wrapping HTTP socket in an SSL socket.""" - super().__init__(address, handler, bind_and_activate=False) + + # Do not bind and activate, as we change the socket + super().__init__(address, handler, False) self.socket = ssl.wrap_socket( self.socket, self.key, self.certificate, server_side=True, @@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer): ssl_version=self.protocol, ciphers=self.ciphers, do_handshake_on_connect=False) - self.server_bind() - self.server_activate() + if bind_and_activate: + try: + self.server_bind() + self.server_activate() + except BaseException: + self.server_close() + raise - -class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer): - def process_request_thread(self, request, client_address): - with self.connections_guard: - return super().process_request_thread(request, client_address) - - -class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): - def process_request_thread(self, request, client_address): + def finish_request(self, request, client_address): try: try: request.do_handshake() @@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): finally: self.shutdown_request(request) return - with self.connections_guard: - return super().process_request_thread(request, client_address) + return super().finish_request(request, client_address) class ServerHandler(wsgiref.simple_server.ServerHandler): @@ -197,7 +203,7 @@ def serve(configuration): # Create collection servers servers = {} if configuration.getboolean("server", "ssl"): - server_class = ThreadedHTTPSServer + server_class = ParallelHTTPSServer server_class.certificate = configuration.get("server", "certificate") server_class.key = configuration.get("server", "key") server_class.certificate_authority = configuration.get( @@ -216,7 +222,7 @@ def serve(configuration): raise RuntimeError("Failed to read SSL %s %r: %s" % (name, filename, e)) from e else: - server_class = ThreadedHTTPServer + server_class = ParallelHTTPServer server_class.client_timeout = configuration.getint("server", "timeout") server_class.max_connections = configuration.getint( "server", "max_connections")