Remove forking support

* Third-party plugins have to be fork-safe
* Not supported on Windows
This commit is contained in:
Unrud 2020-02-19 09:50:02 +01:00
parent 3b99d64935
commit 698980d7be
2 changed files with 19 additions and 87 deletions

View File

@ -24,21 +24,15 @@ Log messages are sent to the first available target of:
- systemd-journald - systemd-journald
- stderr - stderr
The logger is thread-safe and fork-safe.
""" """
import contextlib import contextlib
import io import io
import logging import logging
import multiprocessing
import os import os
import sys import sys
import tempfile
import threading import threading
from radicale import pathutils
try: try:
import systemd.journal import systemd.journal
except ImportError: except ImportError:
@ -64,14 +58,10 @@ class IdentLogRecordFactory:
def __init__(self, upstream_factory): def __init__(self, upstream_factory):
self.upstream_factory = upstream_factory self.upstream_factory = upstream_factory
self.main_pid = os.getpid()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
record = self.upstream_factory(*args, **kwargs) record = self.upstream_factory(*args, **kwargs)
pid = os.getpid() ident = "%x" % os.getpid()
ident = "%x" % self.main_pid
if pid != self.main_pid:
ident += "%+x" % (pid - self.main_pid)
main_thread = threading.main_thread() main_thread = threading.main_thread()
current_thread = threading.current_thread() current_thread = threading.current_thread()
if current_thread.name and main_thread != current_thread: if current_thread.name and main_thread != current_thread:
@ -80,27 +70,6 @@ class IdentLogRecordFactory:
return record return record
class RwLockWrapper():
def __init__(self):
self._file = tempfile.NamedTemporaryFile()
self._lock = pathutils.RwLock(self._file.name)
self._cm = None
def acquire(self, blocking=True):
assert self._cm is None
if not blocking:
raise NotImplementedError
cm = self._lock.acquire("w")
cm.__enter__()
self._cm = cm
def release(self):
assert self._cm is not None
self._cm.__exit__(None, None, None)
self._cm = None
class ThreadStreamsHandler(logging.Handler): class ThreadStreamsHandler(logging.Handler):
terminator = "\n" terminator = "\n"
@ -111,13 +80,6 @@ class ThreadStreamsHandler(logging.Handler):
self.fallback_stream = fallback_stream self.fallback_stream = fallback_stream
self.fallback_handler = fallback_handler self.fallback_handler = fallback_handler
def createLock(self):
try:
self.lock = multiprocessing.Lock()
except Exception:
# HACK: Workaround for Android
self.lock = RwLockWrapper()
def setFormatter(self, fmt): def setFormatter(self, fmt):
super().setFormatter(fmt) super().setFormatter(fmt)
self.fallback_handler.setFormatter(fmt) self.fallback_handler.setFormatter(fmt)

View File

@ -20,12 +20,9 @@
""" """
Built-in WSGI server. Built-in WSGI server.
Uses forking on POSIX to overcome Python's GIL.
""" """
import contextlib import contextlib
import multiprocessing
import os import os
import select import select
import socket import socket
@ -44,18 +41,6 @@ try:
except ImportError: except ImportError:
systemd = None systemd = None
USE_FORKING = hasattr(os, "fork")
try:
multiprocessing.BoundedSemaphore()
except Exception:
# HACK: Workaround for Android
USE_FORKING = False
if USE_FORKING:
ParallelizationMixIn = socketserver.ForkingMixIn
else:
ParallelizationMixIn = socketserver.ThreadingMixIn
HAS_IPV6 = socket.has_ipv6 HAS_IPV6 = socket.has_ipv6
if hasattr(socket, "EAI_NONAME"): if hasattr(socket, "EAI_NONAME"):
EAI_NONAME = socket.EAI_NONAME EAI_NONAME = socket.EAI_NONAME
@ -81,10 +66,10 @@ else:
HAS_IPV6 = False HAS_IPV6 = False
class ParallelHTTPServer(ParallelizationMixIn, class ParallelHTTPServer(socketserver.ThreadingMixIn,
wsgiref.simple_server.WSGIServer): wsgiref.simple_server.WSGIServer):
# Python 3.6: Wait for child processes/threads (Default in Python >= 3.7) # Python 3.6: Wait for child threads (Default in Python >= 3.7)
_block_on_close = True _block_on_close = True
def __init__(self, configuration, address_family, def __init__(self, configuration, address_family,
@ -92,39 +77,32 @@ class ParallelHTTPServer(ParallelizationMixIn,
self.configuration = configuration self.configuration = configuration
self.address_family = address_family self.address_family = address_family
if isinstance(server_address_or_socket, socket.socket): if isinstance(server_address_or_socket, socket.socket):
override_socket = server_address_or_socket self._override_socket = server_address_or_socket
server_address = override_socket.getsockname() server_address = server_address_or_socket.getsockname()
else: else:
override_socket = None self._override_socket = None
server_address = server_address_or_socket server_address = server_address_or_socket
super().__init__(server_address, RequestHandlerClass, super().__init__(server_address, RequestHandlerClass)
bind_and_activate=False)
if USE_FORKING:
sema_class = multiprocessing.BoundedSemaphore
else:
sema_class = threading.BoundedSemaphore
max_connections = self.configuration.get("server", "max_connections") max_connections = self.configuration.get("server", "max_connections")
if max_connections: if max_connections:
self.connections_guard = sema_class(max_connections) self.connections_guard = threading.BoundedSemaphore(
max_connections)
else: else:
# use dummy context manager # use dummy context manager
self.connections_guard = contextlib.ExitStack() self.connections_guard = contextlib.ExitStack()
if override_socket:
self.socket = override_socket def server_bind(self):
if self._override_socket is not None:
self.socket = self._override_socket
host, port = self.server_address[:2] host, port = self.server_address[:2]
self.server_name = socket.getfqdn(host) self.server_name = socket.getfqdn(host)
self.server_port = port self.server_port = port
self.setup_environ() self.setup_environ()
return return
try: if self.address_family == socket.AF_INET6:
if self.address_family == socket.AF_INET6: # Only allow IPv6 connections to the IPv6 socket
# Only allow IPv6 connections to the IPv6 socket self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)
self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) super().server_bind()
self.server_bind()
self.server_activate()
except BaseException:
self.server_close()
raise
def get_request(self): def get_request(self):
# Set timeout for client # Set timeout for client
@ -134,14 +112,6 @@ class ParallelHTTPServer(ParallelizationMixIn,
socket_.settimeout(timeout) socket_.settimeout(timeout)
return socket_, address return socket_, address
def process_request(self, request, client_address):
try:
return super().process_request(request, client_address)
finally:
# Modify OpenSSL's RNG state, in case process forked
# See https://docs.python.org/3.7/library/ssl.html#multi-processing
ssl.RAND_add(os.urandom(8), 0.0)
def finish_request_locked(self, request, client_address): def finish_request_locked(self, request, client_address):
return super().finish_request(request, client_address) return super().finish_request(request, client_address)
@ -280,13 +250,12 @@ def serve(configuration, shutdown_socket=None):
application = Application(configuration) application = Application(configuration)
servers = {} servers = {}
for server_address_or_socket, family in server_addresses_or_sockets: for server_address_or_socket, family in server_addresses_or_sockets:
# If familiy is AF_INET, try to bind sockets for AF_INET and AF_INET6 # If family is AF_INET, try to bind sockets for AF_INET and AF_INET6
bind_successful = False bind_successful = False
for family in [family, socket.AF_INET6]: for family in [family, socket.AF_INET6]:
try: try:
server = server_class(configuration, family, server = server_class(configuration, family,
server_address_or_socket, RequestHandler) server_address_or_socket, RequestHandler)
server.set_app(application)
except OSError as e: except OSError as e:
if ((family == socket.AF_INET and HAS_IPV6 or if ((family == socket.AF_INET and HAS_IPV6 or
bind_successful) and isinstance(e, socket.gaierror) and bind_successful) and isinstance(e, socket.gaierror) and
@ -298,6 +267,7 @@ def serve(configuration, shutdown_socket=None):
"Failed to start server %r: %s" % ( "Failed to start server %r: %s" % (
server_address_or_socket, e)) from e server_address_or_socket, e)) from e
bind_successful = True bind_successful = True
server.set_app(application)
servers[server.socket] = server servers[server.socket] = server
logger.info("Listening to %r on port %d%s (%s)", logger.info("Listening to %r on port %d%s (%s)",
server.server_name, server.server_port, " using SSL" server.server_name, server.server_port, " using SSL"