From 66fabbead9027495fb526b482606c295d84818eb Mon Sep 17 00:00:00 2001 From: Unrud Date: Wed, 19 Feb 2020 09:50:19 +0100 Subject: [PATCH] Use socket pairs to communicate with client threads --- radicale/config.py | 7 ++ radicale/server.py | 224 +++++++++++++++------------------- radicale/tests/test_server.py | 18 ++- 3 files changed, 121 insertions(+), 128 deletions(-) diff --git a/radicale/config.py b/radicale/config.py index b01c3f4..58d0bf6 100644 --- a/radicale/config.py +++ b/radicale/config.py @@ -371,6 +371,13 @@ class Configuration: break return fconfig[section][option] + def get_source(self, section, option): + """Get the source that provides ``option`` in ``section``.""" + for config, source, _ in reversed(self._configs): + if option in config.get(section, {}): + return source + raise KeyError(section, option) + def sections(self): """List all sections.""" return self._values.keys() diff --git a/radicale/server.py b/radicale/server.py index 190dceb..c2cccfa 100644 --- a/radicale/server.py +++ b/radicale/server.py @@ -22,103 +22,73 @@ Built-in WSGI server. """ -import contextlib import os import select import socket import socketserver import ssl import sys -import threading import wsgiref.simple_server from urllib.parse import unquote -from radicale import Application +from radicale import Application, config from radicale.log import logger -try: - import systemd.daemon -except ImportError: - systemd = None - -HAS_IPV6 = socket.has_ipv6 -if hasattr(socket, "EAI_NONAME"): - EAI_NONAME = socket.EAI_NONAME -else: - HAS_IPV6 = False if hasattr(socket, "EAI_ADDRFAMILY"): - EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY -elif os.name == "nt": - EAI_ADDRFAMILY = None -else: - HAS_IPV6 = False + COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY +elif os.name == "nt" and hasattr(socket, "EAI_NONAME"): + # Windows doesn't have a special error code for this + COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME if hasattr(socket, "IPPROTO_IPV6"): - IPPROTO_IPV6 = socket.IPPROTO_IPV6 + COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6 elif os.name == "nt": - IPPROTO_IPV6 = 41 -else: - HAS_IPV6 = False -if hasattr(socket, "IPV6_V6ONLY"): - IPV6_V6ONLY = socket.IPV6_V6ONLY -elif os.name == "nt": - IPV6_V6ONLY = 27 -else: - HAS_IPV6 = False + # Workaround: https://bugs.python.org/issue29515 + COMPAT_IPPROTO_IPV6 = 41 + + +def format_address(address): + return "[%s]:%d" % address[:2] class ParallelHTTPServer(socketserver.ThreadingMixIn, wsgiref.simple_server.WSGIServer): - # Python 3.6: Wait for child threads (Default in Python >= 3.7) - _block_on_close = True + # We wait for child threads ourself + block_on_close = False - def __init__(self, configuration, address_family, - server_address_or_socket, RequestHandlerClass): + def __init__(self, configuration, family, address, RequestHandlerClass): self.configuration = configuration - self.address_family = address_family - if isinstance(server_address_or_socket, socket.socket): - self._override_socket = server_address_or_socket - server_address = server_address_or_socket.getsockname() - else: - self._override_socket = None - server_address = server_address_or_socket - super().__init__(server_address, RequestHandlerClass) - max_connections = self.configuration.get("server", "max_connections") - if max_connections: - self.connections_guard = threading.BoundedSemaphore( - max_connections) - else: - # use dummy context manager - self.connections_guard = contextlib.ExitStack() + self.address_family = family + super().__init__(address, RequestHandlerClass) + self.client_sockets = set() def server_bind(self): - if self._override_socket is not None: - self.socket = self._override_socket - host, port = self.server_address[:2] - self.server_name = socket.getfqdn(host) - self.server_port = port - self.setup_environ() - return if self.address_family == socket.AF_INET6: # Only allow IPv6 connections to the IPv6 socket - self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) + self.socket.setsockopt(COMPAT_IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) super().server_bind() def get_request(self): # Set timeout for client - socket_, address = super().get_request() + request, client_address = super().get_request() timeout = self.configuration.get("server", "timeout") if timeout: - socket_.settimeout(timeout) - return socket_, address + request.settimeout(timeout) + client_socket, client_socket_out = socket.socketpair() + self.client_sockets.add(client_socket_out) + return request, (*client_address, client_socket) def finish_request_locked(self, request, client_address): return super().finish_request(request, client_address) def finish_request(self, request, client_address): """Don't overwrite this! (Modified by tests.)""" - with self.connections_guard: + *client_address, client_socket = client_address + client_address = tuple(client_address) + try: return self.finish_request_locked(request, client_address) + finally: + client_socket.close() def handle_error(self, request, client_address): if issubclass(sys.exc_info()[0], socket.timeout): @@ -139,13 +109,18 @@ class ParallelHTTPSServer(ParallelHTTPServer): # Test if the files can be read for name, filename in [("certificate", certfile), ("key", keyfile), ("certificate_authority", cafile)]: + type_name = config.DEFAULT_CONFIG_SCHEMA["server"][name][ + "type"].__name__ + source = self.configuration.get_source("server", name) if name == "certificate_authority" and not filename: continue try: open(filename, "r").close() except OSError as e: - raise RuntimeError("Failed to read SSL %s %r: %s" % - (name, filename, e)) from e + raise RuntimeError( + "Invalid %s value for option %r in section %r in %s: %r " + "(%s)" % (type_name, name, "server", source, filename, + e)) from e context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context.load_cert_chain(certfile=certfile, keyfile=keyfile) if cafile: @@ -185,7 +160,7 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): """HTTP requests handler.""" def log_request(self, code="-", size="-"): - """Disable request logging.""" + pass # Disable request logging. def log_error(self, format_, *args): logger.error("An error occurred during request: %s", format_ % args) @@ -220,7 +195,7 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): handler.run(self.server.get_app()) -def serve(configuration, shutdown_socket=None): +def serve(configuration, shutdown_socket): """Serve radicale from configuration.""" logger.info("Starting Radicale") # Copy configuration before modifying @@ -228,76 +203,77 @@ def serve(configuration, shutdown_socket=None): configuration.update({"internal": {"internal_server": "True"}}, "server", internal=True) - # Create server sockets - server_addresses_or_sockets = [] # [((host, port) or socket, family)] - if systemd: - listen_fds = systemd.daemon.listen_fds() - else: - listen_fds = [] - if listen_fds: - logger.info("Using socket activation") - for fd in listen_fds: - server_addresses_or_sockets.append((socket.fromfd( - fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX)) - else: - for address, port in configuration.get("server", "hosts"): - server_addresses_or_sockets.append( - ((address, port), socket.AF_INET)) - if configuration.get("server", "ssl"): - server_class = ParallelHTTPSServer - else: - server_class = ParallelHTTPServer + use_ssl = configuration.get("server", "ssl") + server_class = ParallelHTTPSServer if use_ssl else ParallelHTTPServer application = Application(configuration) servers = {} - for server_address_or_socket, family in server_addresses_or_sockets: - # If family is AF_INET, try to bind sockets for AF_INET and AF_INET6 - bind_successful = False - for family in [family, socket.AF_INET6]: - try: - server = server_class(configuration, family, - server_address_or_socket, RequestHandler) - except OSError as e: - if ((family == socket.AF_INET and HAS_IPV6 or - bind_successful) and isinstance(e, socket.gaierror) and - e.errno in (EAI_NONAME, EAI_ADDRFAMILY)): - # Allow one of AF_INET and AF_INET6 to fail, when - # the address or host don't support the address family. - continue - raise RuntimeError( - "Failed to start server %r: %s" % ( - server_address_or_socket, e)) from e - bind_successful = True - server.set_app(application) - servers[server.socket] = server - logger.info("Listening to %r on port %d%s (%s)", - server.server_name, server.server_port, " using SSL" - if configuration.get("server", "ssl") else "", - family.name) + try: + for address in configuration.get("server", "hosts"): + # Try to bind sockets for IPv4 and IPv6 + possible_families = (socket.AF_INET, socket.AF_INET6) + bind_ok = False + for i, family in enumerate(possible_families): + try: + server = server_class(configuration, family, address, + RequestHandler) + except OSError as e: + if ((bind_ok or i < len(possible_families) - 1) and + isinstance(e, socket.gaierror) and + e.errno in (socket.EAI_NONAME, + COMPAT_EAI_ADDRFAMILY)): + # Ignore unsupported families, only one must work + continue + raise RuntimeError( + "Failed to start server %r: %s" % ( + format_address(address), e)) from e + servers[server.socket] = server + bind_ok = True + server.set_app(application) + logger.info("Listening on %r%s", + format_address(server.server_address), + " with SSL" if use_ssl else "") + assert servers, "no servers started" - # Main loop: wait for requests on any of the servers or program shutdown - sockets = list(servers.keys()) - # Use socket pair to get notified of program shutdown - if shutdown_socket: - sockets.append(shutdown_socket) - select_timeout = None - if os.name == "nt": - # Fallback to busy waiting. (select.select blocks SIGINT on Windows.) - select_timeout = 1.0 - logger.info("Radicale server ready") - - with contextlib.ExitStack() as exit_stack: - for _, server in servers.items(): - exit_stack.callback(server.server_close) + # Mainloop + select_timeout = None + if os.name == "nt": + # Fallback to busy waiting. (select(...) blocks SIGINT on Windows.) + select_timeout = 1.0 + max_connections = configuration.get("server", "max_connections") + logger.info("Radicale server ready") while True: - rlist, _, xlist = select.select( - sockets, [], sockets, select_timeout) + rlist = xlist = [] + # Wait for finished clients + for server in servers.values(): + rlist.extend(server.client_sockets) + # Accept new connections if max_connections is not reached + if max_connections <= 0 or len(rlist) < max_connections: + rlist.extend(servers) + # Use socket to get notified of program shutdown + rlist.append(shutdown_socket) + rlist, _, xlist = select.select(rlist, [], xlist, select_timeout) if xlist: raise RuntimeError("unhandled socket error") + rlist = set(rlist) if shutdown_socket in rlist: logger.info("Stopping Radicale") break + for server in servers.values(): + finished_sockets = server.client_sockets.intersection(rlist) + for s in finished_sockets: + s.close() + server.client_sockets.remove(s) + rlist.remove(s) + if finished_sockets: + server.service_actions() if rlist: - server = servers.get(rlist[0]) + server = servers.get(rlist.pop()) if server: server.handle_request() - server.service_actions() + finally: + # Wait for clients to finish and close servers + for server in servers.values(): + for s in server.client_sockets: + s.recv(1) + s.close() + server.server_close() diff --git a/radicale/tests/test_server.py b/radicale/tests/test_server.py index 317c8a1..f9b2c61 100644 --- a/radicale/tests/test_server.py +++ b/radicale/tests/test_server.py @@ -119,15 +119,25 @@ class TestBaseServerRequests(BaseTest): self.thread.start() self.get("/", check=302) - @pytest.mark.skipif(not server.HAS_IPV6, reason="IPv6 not supported") + def test_bind_fail(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + with pytest.raises(socket.gaierror) as exc_info: + sock.bind(("::1", 0)) + assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: + with pytest.raises(socket.gaierror) as exc_info: + sock.bind(("127.0.0.1", 0)) + assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY + def test_ipv6(self): with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: - sock.setsockopt(server.IPPROTO_IPV6, server.IPV6_V6ONLY, 1) try: # Find available port sock.bind(("::1", 0)) - except OSError: - pytest.skip("IPv6 not supported") + except socket.gaierror as e: + if e.errno == server.COMPAT_EAI_ADDRFAMILY: + pytest.skip("IPv6 not supported") + raise self.sockname = sock.getsockname()[:2] self.configuration.update({ "server": {"hosts": "[%s]:%d" % self.sockname}}, "test")