Remove forking support

* Third-party plugins have to be fork-safe
* Not supported on Windows
This commit is contained in:
Unrud 2020-02-19 09:50:02 +01:00
parent 3b99d64935
commit 698980d7be
2 changed files with 19 additions and 87 deletions

View File

@ -24,21 +24,15 @@ Log messages are sent to the first available target of:
- systemd-journald
- stderr
The logger is thread-safe and fork-safe.
"""
import contextlib
import io
import logging
import multiprocessing
import os
import sys
import tempfile
import threading
from radicale import pathutils
try:
import systemd.journal
except ImportError:
@ -64,14 +58,10 @@ class IdentLogRecordFactory:
def __init__(self, upstream_factory):
self.upstream_factory = upstream_factory
self.main_pid = os.getpid()
def __call__(self, *args, **kwargs):
record = self.upstream_factory(*args, **kwargs)
pid = os.getpid()
ident = "%x" % self.main_pid
if pid != self.main_pid:
ident += "%+x" % (pid - self.main_pid)
ident = "%x" % os.getpid()
main_thread = threading.main_thread()
current_thread = threading.current_thread()
if current_thread.name and main_thread != current_thread:
@ -80,27 +70,6 @@ class IdentLogRecordFactory:
return record
class RwLockWrapper():
def __init__(self):
self._file = tempfile.NamedTemporaryFile()
self._lock = pathutils.RwLock(self._file.name)
self._cm = None
def acquire(self, blocking=True):
assert self._cm is None
if not blocking:
raise NotImplementedError
cm = self._lock.acquire("w")
cm.__enter__()
self._cm = cm
def release(self):
assert self._cm is not None
self._cm.__exit__(None, None, None)
self._cm = None
class ThreadStreamsHandler(logging.Handler):
terminator = "\n"
@ -111,13 +80,6 @@ class ThreadStreamsHandler(logging.Handler):
self.fallback_stream = fallback_stream
self.fallback_handler = fallback_handler
def createLock(self):
try:
self.lock = multiprocessing.Lock()
except Exception:
# HACK: Workaround for Android
self.lock = RwLockWrapper()
def setFormatter(self, fmt):
super().setFormatter(fmt)
self.fallback_handler.setFormatter(fmt)

View File

@ -20,12 +20,9 @@
"""
Built-in WSGI server.
Uses forking on POSIX to overcome Python's GIL.
"""
import contextlib
import multiprocessing
import os
import select
import socket
@ -44,18 +41,6 @@ try:
except ImportError:
systemd = None
USE_FORKING = hasattr(os, "fork")
try:
multiprocessing.BoundedSemaphore()
except Exception:
# HACK: Workaround for Android
USE_FORKING = False
if USE_FORKING:
ParallelizationMixIn = socketserver.ForkingMixIn
else:
ParallelizationMixIn = socketserver.ThreadingMixIn
HAS_IPV6 = socket.has_ipv6
if hasattr(socket, "EAI_NONAME"):
EAI_NONAME = socket.EAI_NONAME
@ -81,10 +66,10 @@ else:
HAS_IPV6 = False
class ParallelHTTPServer(ParallelizationMixIn,
class ParallelHTTPServer(socketserver.ThreadingMixIn,
wsgiref.simple_server.WSGIServer):
# Python 3.6: Wait for child processes/threads (Default in Python >= 3.7)
# Python 3.6: Wait for child threads (Default in Python >= 3.7)
_block_on_close = True
def __init__(self, configuration, address_family,
@ -92,39 +77,32 @@ class ParallelHTTPServer(ParallelizationMixIn,
self.configuration = configuration
self.address_family = address_family
if isinstance(server_address_or_socket, socket.socket):
override_socket = server_address_or_socket
server_address = override_socket.getsockname()
self._override_socket = server_address_or_socket
server_address = server_address_or_socket.getsockname()
else:
override_socket = None
self._override_socket = None
server_address = server_address_or_socket
super().__init__(server_address, RequestHandlerClass,
bind_and_activate=False)
if USE_FORKING:
sema_class = multiprocessing.BoundedSemaphore
else:
sema_class = threading.BoundedSemaphore
super().__init__(server_address, RequestHandlerClass)
max_connections = self.configuration.get("server", "max_connections")
if max_connections:
self.connections_guard = sema_class(max_connections)
self.connections_guard = threading.BoundedSemaphore(
max_connections)
else:
# use dummy context manager
self.connections_guard = contextlib.ExitStack()
if override_socket:
self.socket = override_socket
def server_bind(self):
if self._override_socket is not None:
self.socket = self._override_socket
host, port = self.server_address[:2]
self.server_name = socket.getfqdn(host)
self.server_port = port
self.setup_environ()
return
try:
if self.address_family == socket.AF_INET6:
# Only allow IPv6 connections to the IPv6 socket
self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)
self.server_bind()
self.server_activate()
except BaseException:
self.server_close()
raise
if self.address_family == socket.AF_INET6:
# Only allow IPv6 connections to the IPv6 socket
self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)
super().server_bind()
def get_request(self):
# Set timeout for client
@ -134,14 +112,6 @@ class ParallelHTTPServer(ParallelizationMixIn,
socket_.settimeout(timeout)
return socket_, address
def process_request(self, request, client_address):
try:
return super().process_request(request, client_address)
finally:
# Modify OpenSSL's RNG state, in case process forked
# See https://docs.python.org/3.7/library/ssl.html#multi-processing
ssl.RAND_add(os.urandom(8), 0.0)
def finish_request_locked(self, request, client_address):
return super().finish_request(request, client_address)
@ -280,13 +250,12 @@ def serve(configuration, shutdown_socket=None):
application = Application(configuration)
servers = {}
for server_address_or_socket, family in server_addresses_or_sockets:
# If familiy is AF_INET, try to bind sockets for AF_INET and AF_INET6
# If family is AF_INET, try to bind sockets for AF_INET and AF_INET6
bind_successful = False
for family in [family, socket.AF_INET6]:
try:
server = server_class(configuration, family,
server_address_or_socket, RequestHandler)
server.set_app(application)
except OSError as e:
if ((family == socket.AF_INET and HAS_IPV6 or
bind_successful) and isinstance(e, socket.gaierror) and
@ -298,6 +267,7 @@ def serve(configuration, shutdown_socket=None):
"Failed to start server %r: %s" % (
server_address_or_socket, e)) from e
bind_successful = True
server.set_app(application)
servers[server.socket] = server
logger.info("Listening to %r on port %d%s (%s)",
server.server_name, server.server_port, " using SSL"