Use forking for internal server when available

This commit is contained in:
Unrud 2018-08-18 12:56:41 +02:00
parent ddd99a5329
commit 30a9ecc06b
2 changed files with 60 additions and 27 deletions

View File

@ -25,6 +25,7 @@ http://docs.python.org/library/logging.config.html
import contextlib import contextlib
import io import io
import logging import logging
import multiprocessing
import os import os
import sys import sys
import threading import threading
@ -35,7 +36,7 @@ except ImportError:
journal = None journal = None
LOGGER_NAME = "radicale" LOGGER_NAME = "radicale"
LOGGER_FORMAT = "[%(processName)s/%(threadName)s] %(levelname)s: %(message)s" LOGGER_FORMAT = "[%(ident)s] %(levelname)s: %(message)s"
root_logger = logging.getLogger() root_logger = logging.getLogger()
logger = logging.getLogger(LOGGER_NAME) logger = logging.getLogger(LOGGER_NAME)
@ -50,6 +51,27 @@ class RemoveTracebackFilter(logging.Filter):
removeTracebackFilter = RemoveTracebackFilter() removeTracebackFilter = RemoveTracebackFilter()
class IdentLogRecordFactory:
"""LogRecordFactory that adds ``ident`` attribute."""
def __init__(self, upstream_factory):
self.upstream_factory = upstream_factory
self.main_pid = os.getpid()
self.main_thread_name = threading.current_thread().name
def __call__(self, *args, **kwargs):
record = self.upstream_factory(*args, **kwargs)
pid = os.getpid()
thread_name = threading.current_thread().name
ident = "%x" % self.main_pid
if pid != self.main_pid:
ident += "%+x" % (pid - self.main_pid)
if thread_name != self.main_thread_name:
ident += "/%s" % thread_name
record.ident = ident
return record
class ThreadStreamsHandler(logging.Handler): class ThreadStreamsHandler(logging.Handler):
terminator = "\n" terminator = "\n"
@ -60,6 +82,9 @@ class ThreadStreamsHandler(logging.Handler):
self.fallback_stream = fallback_stream self.fallback_stream = fallback_stream
self.fallback_handler = fallback_handler self.fallback_handler = fallback_handler
def createLock(self):
self.lock = multiprocessing.Lock()
def setFormatter(self, form): def setFormatter(self, form):
super().setFormatter(form) super().setFormatter(form)
self.fallback_handler.setFormatter(form) self.fallback_handler.setFormatter(form)
@ -116,6 +141,8 @@ def setup():
handler = ThreadStreamsHandler(sys.stderr, get_default_handler()) handler = ThreadStreamsHandler(sys.stderr, get_default_handler())
logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler]) logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler])
register_stream = handler.register_stream register_stream = handler.register_stream
log_record_factory = IdentLogRecordFactory(logging.getLogRecordFactory())
logging.setLogRecordFactory(log_record_factory)
set_level(logging.DEBUG) set_level(logging.DEBUG)

View File

@ -22,6 +22,7 @@ Radicale WSGI server.
""" """
import contextlib import contextlib
import multiprocessing
import os import os
import select import select
import signal import signal
@ -29,16 +30,20 @@ import socket
import socketserver import socketserver
import ssl import ssl
import sys import sys
import threading
import wsgiref.simple_server import wsgiref.simple_server
from urllib.parse import unquote from urllib.parse import unquote
from radicale import Application from radicale import Application
from radicale.log import logger from radicale.log import logger
if hasattr(socketserver, "ForkingMixIn"):
ParallelizationMixIn = socketserver.ForkingMixIn
else:
ParallelizationMixIn = socketserver.ThreadingMixIn
class HTTPServer(wsgiref.simple_server.WSGIServer):
"""HTTP server.""" class ParallelHTTPServer(ParallelizationMixIn,
wsgiref.simple_server.WSGIServer):
# These class attributes must be set before creating instance # These class attributes must be set before creating instance
client_timeout = None client_timeout = None
@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
if self.max_connections: if self.max_connections:
self.connections_guard = threading.BoundedSemaphore( self.connections_guard = multiprocessing.BoundedSemaphore(
self.max_connections) self.max_connections)
else: else:
# use dummy context manager # use dummy context manager
@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
def get_request(self): def get_request(self):
# Set timeout for client # Set timeout for client
_socket, address = super().get_request() socket_, address = super().get_request()
if self.client_timeout: if self.client_timeout:
_socket.settimeout(self.client_timeout) socket_.settimeout(self.client_timeout)
return _socket, address return socket_, address
def finish_request(self, request, client_address):
with self.connections_guard:
return super().finish_request(request, client_address)
def handle_error(self, request, client_address): def handle_error(self, request, client_address):
if issubclass(sys.exc_info()[0], socket.timeout): if issubclass(sys.exc_info()[0], socket.timeout):
@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
sys.exc_info()[1], exc_info=True) sys.exc_info()[1], exc_info=True)
class HTTPSServer(HTTPServer): class ParallelHTTPSServer(ParallelHTTPServer):
"""HTTPS server."""
# These class attributes must be set before creating instance # These class attributes must be set before creating instance
certificate = None certificate = None
@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer):
ciphers = None ciphers = None
certificate_authority = None certificate_authority = None
def __init__(self, address, handler): def __init__(self, address, handler, bind_and_activate=True):
"""Create server by wrapping HTTP socket in an SSL socket.""" """Create server by wrapping HTTP socket in an SSL socket."""
super().__init__(address, handler, bind_and_activate=False)
# Do not bind and activate, as we change the socket
super().__init__(address, handler, False)
self.socket = ssl.wrap_socket( self.socket = ssl.wrap_socket(
self.socket, self.key, self.certificate, server_side=True, self.socket, self.key, self.certificate, server_side=True,
@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer):
ssl_version=self.protocol, ciphers=self.ciphers, ssl_version=self.protocol, ciphers=self.ciphers,
do_handshake_on_connect=False) do_handshake_on_connect=False)
self.server_bind() if bind_and_activate:
self.server_activate() try:
self.server_bind()
self.server_activate()
except BaseException:
self.server_close()
raise
def finish_request(self, request, client_address):
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
def process_request_thread(self, request, client_address):
with self.connections_guard:
return super().process_request_thread(request, client_address)
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
def process_request_thread(self, request, client_address):
try: try:
try: try:
request.do_handshake() request.do_handshake()
@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
finally: finally:
self.shutdown_request(request) self.shutdown_request(request)
return return
with self.connections_guard: return super().finish_request(request, client_address)
return super().process_request_thread(request, client_address)
class ServerHandler(wsgiref.simple_server.ServerHandler): class ServerHandler(wsgiref.simple_server.ServerHandler):
@ -197,7 +203,7 @@ def serve(configuration):
# Create collection servers # Create collection servers
servers = {} servers = {}
if configuration.getboolean("server", "ssl"): if configuration.getboolean("server", "ssl"):
server_class = ThreadedHTTPSServer server_class = ParallelHTTPSServer
server_class.certificate = configuration.get("server", "certificate") server_class.certificate = configuration.get("server", "certificate")
server_class.key = configuration.get("server", "key") server_class.key = configuration.get("server", "key")
server_class.certificate_authority = configuration.get( server_class.certificate_authority = configuration.get(
@ -216,7 +222,7 @@ def serve(configuration):
raise RuntimeError("Failed to read SSL %s %r: %s" % raise RuntimeError("Failed to read SSL %s %r: %s" %
(name, filename, e)) from e (name, filename, e)) from e
else: else:
server_class = ThreadedHTTPServer server_class = ParallelHTTPServer
server_class.client_timeout = configuration.getint("server", "timeout") server_class.client_timeout = configuration.getint("server", "timeout")
server_class.max_connections = configuration.getint( server_class.max_connections = configuration.getint(
"server", "max_connections") "server", "max_connections")