diff --git a/radicale/server.py b/radicale/server.py index 64a9a42..d45749f 100644 --- a/radicale/server.py +++ b/radicale/server.py @@ -87,32 +87,43 @@ class ParallelHTTPServer(ParallelizationMixIn, # wait for child processes/threads _block_on_close = True - # These class attributes must be set before creating instance - client_timeout = None - max_connections = None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, configuration, address_family, + server_address_or_socket, RequestHandlerClass): + self.configuration = configuration + self.address_family = address_family + if isinstance(server_address_or_socket, socket.socket): + override_socket = server_address_or_socket + server_address = override_socket.getsockname() + else: + override_socket = None + server_address = server_address_or_socket + super().__init__(server_address, RequestHandlerClass, + bind_and_activate=False) if USE_FORKING: sema_class = multiprocessing.BoundedSemaphore else: sema_class = threading.BoundedSemaphore - if self.max_connections: - self.connections_guard = sema_class(self.max_connections) + max_connections = self.configuration.get("server", "max_connections") + if max_connections: + self.connections_guard = sema_class(max_connections) else: # use dummy context manager self.connections_guard = contextlib.ExitStack() - - def server_bind(self): - if isinstance(self.server_address, socket.socket): - # Socket activation - self.socket = self.server_address - self.server_address = self.socket.getsockname() + if override_socket: + self.socket = override_socket host, port = self.server_address[:2] self.server_name = socket.getfqdn(host) self.server_port = port self.setup_environ() return + try: + self.server_bind() + self.server_activate() + except BaseException: + self.server_close() + raise + + def server_bind(self): try: super().server_bind() except socket.gaierror as e: @@ -129,8 +140,9 @@ class ParallelHTTPServer(ParallelizationMixIn, def get_request(self): # Set timeout for client socket_, address = super().get_request() - if self.client_timeout: - socket_.settimeout(self.client_timeout) + timeout = self.configuration.get("server", "timeout") + if timeout: + socket_.settimeout(timeout) return socket_, address def process_request(self, request, client_address): @@ -159,18 +171,26 @@ class ParallelHTTPServer(ParallelizationMixIn, class ParallelHTTPSServer(ParallelHTTPServer): - # These class attributes must be set before creating instance - certificate = None - key = None - certificate_authority = None - def server_bind(self): super().server_bind() # Wrap the TCP socket in an SSL socket + certfile = self.configuration.get("server", "certificate") + keyfile = self.configuration.get("server", "key") + cafile = self.configuration.get("server", "certificate_authority") + # Test if the files can be read + for name, filename in [("certificate", certfile), ("key", keyfile), + ("certificate_authority", cafile)]: + if name == "certificate_authority" and not filename: + continue + try: + open(filename, "r").close() + except OSError as e: + raise RuntimeError("Failed to read SSL %s %r: %s" % + (name, filename, e)) from e context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - context.load_cert_chain(certfile=self.certificate, keyfile=self.key) - if self.certificate_authority: - context.load_verify_locations(cafile=self.certificate_authority) + context.load_cert_chain(certfile=certfile, keyfile=keyfile) + if cafile: + context.load_verify_locations(cafile=cafile) context.verify_mode = ssl.CERT_REQUIRED self.socket = context.wrap_socket( self.socket, server_side=True, do_handshake_on_connect=False) @@ -249,57 +269,36 @@ def serve(configuration, shutdown_socket=None): configuration.update({"internal": {"internal_server": "True"}}, "server", internal=True) - # Create collection servers - servers = {} - if configuration.get("server", "ssl"): - server_class = ParallelHTTPSServer - else: - server_class = ParallelHTTPServer - - class ServerCopy(server_class): - """Copy, avoids overriding the original class attributes.""" - ServerCopy.client_timeout = configuration.get("server", "timeout") - ServerCopy.max_connections = configuration.get("server", "max_connections") - if configuration.get("server", "ssl"): - ServerCopy.certificate = configuration.get("server", "certificate") - ServerCopy.key = configuration.get("server", "key") - ServerCopy.certificate_authority = configuration.get( - "server", "certificate_authority") - # Test if the SSL files can be read - for name in ["certificate", "key"] + ( - ["certificate_authority"] - if ServerCopy.certificate_authority else []): - filename = getattr(ServerCopy, name) - try: - open(filename, "r").close() - except OSError as e: - raise RuntimeError("Failed to read SSL %s %r: %s" % - (name, filename, e)) from e - + # Create server sockets + server_addresses_or_sockets = [] # [((host, port) or socket, family)] if systemd: listen_fds = systemd.daemon.listen_fds() else: listen_fds = [] - - server_addresses = [] if listen_fds: logger.info("Using socket activation") - ServerCopy.address_family = socket.AF_UNIX for fd in listen_fds: - server_addresses.append(socket.fromfd( - fd, ServerCopy.address_family, ServerCopy.socket_type)) + server_addresses_or_sockets.append((socket.fromfd( + fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX)) else: for address, port in configuration.get("server", "hosts"): - server_addresses.append((address, port)) - + server_addresses_or_sockets.append( + ((address, port), socket.AF_INET)) + if configuration.get("server", "ssl"): + server_class = ParallelHTTPSServer + else: + server_class = ParallelHTTPServer application = Application(configuration) - for server_address in server_addresses: + servers = {} + for server_address_or_socket, family in server_addresses_or_sockets: try: - server = ServerCopy(server_address, RequestHandler) + server = server_class(configuration, family, + server_address_or_socket, RequestHandler) server.set_app(application) except OSError as e: raise RuntimeError( - "Failed to start server %r: %s" % (server_address, e)) from e + "Failed to start server %r: %s" % ( + server_address_or_socket, e)) from e servers[server.socket] = server logger.info("Listening to %r on port %d%s", server.server_name, server.server_port, " using SSL"