Use socket pairs to communicate with client threads

This commit is contained in:
Unrud 2020-02-19 09:50:19 +01:00
parent 698980d7be
commit 66fabbead9
3 changed files with 121 additions and 128 deletions

View File

@ -371,6 +371,13 @@ class Configuration:
break break
return fconfig[section][option] return fconfig[section][option]
def get_source(self, section, option):
"""Get the source that provides ``option`` in ``section``."""
for config, source, _ in reversed(self._configs):
if option in config.get(section, {}):
return source
raise KeyError(section, option)
def sections(self): def sections(self):
"""List all sections.""" """List all sections."""
return self._values.keys() return self._values.keys()

View File

@ -22,103 +22,73 @@ Built-in WSGI server.
""" """
import contextlib
import os import os
import select import select
import socket import socket
import socketserver import socketserver
import ssl import ssl
import sys import sys
import threading
import wsgiref.simple_server import wsgiref.simple_server
from urllib.parse import unquote from urllib.parse import unquote
from radicale import Application from radicale import Application, config
from radicale.log import logger from radicale.log import logger
try:
import systemd.daemon
except ImportError:
systemd = None
HAS_IPV6 = socket.has_ipv6
if hasattr(socket, "EAI_NONAME"):
EAI_NONAME = socket.EAI_NONAME
else:
HAS_IPV6 = False
if hasattr(socket, "EAI_ADDRFAMILY"): if hasattr(socket, "EAI_ADDRFAMILY"):
EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY
elif os.name == "nt": elif os.name == "nt" and hasattr(socket, "EAI_NONAME"):
EAI_ADDRFAMILY = None # Windows doesn't have a special error code for this
else: COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME
HAS_IPV6 = False
if hasattr(socket, "IPPROTO_IPV6"): if hasattr(socket, "IPPROTO_IPV6"):
IPPROTO_IPV6 = socket.IPPROTO_IPV6 COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6
elif os.name == "nt": elif os.name == "nt":
IPPROTO_IPV6 = 41 # Workaround: https://bugs.python.org/issue29515
else: COMPAT_IPPROTO_IPV6 = 41
HAS_IPV6 = False
if hasattr(socket, "IPV6_V6ONLY"):
IPV6_V6ONLY = socket.IPV6_V6ONLY def format_address(address):
elif os.name == "nt": return "[%s]:%d" % address[:2]
IPV6_V6ONLY = 27
else:
HAS_IPV6 = False
class ParallelHTTPServer(socketserver.ThreadingMixIn, class ParallelHTTPServer(socketserver.ThreadingMixIn,
wsgiref.simple_server.WSGIServer): wsgiref.simple_server.WSGIServer):
# Python 3.6: Wait for child threads (Default in Python >= 3.7) # We wait for child threads ourself
_block_on_close = True block_on_close = False
def __init__(self, configuration, address_family, def __init__(self, configuration, family, address, RequestHandlerClass):
server_address_or_socket, RequestHandlerClass):
self.configuration = configuration self.configuration = configuration
self.address_family = address_family self.address_family = family
if isinstance(server_address_or_socket, socket.socket): super().__init__(address, RequestHandlerClass)
self._override_socket = server_address_or_socket self.client_sockets = set()
server_address = server_address_or_socket.getsockname()
else:
self._override_socket = None
server_address = server_address_or_socket
super().__init__(server_address, RequestHandlerClass)
max_connections = self.configuration.get("server", "max_connections")
if max_connections:
self.connections_guard = threading.BoundedSemaphore(
max_connections)
else:
# use dummy context manager
self.connections_guard = contextlib.ExitStack()
def server_bind(self): def server_bind(self):
if self._override_socket is not None:
self.socket = self._override_socket
host, port = self.server_address[:2]
self.server_name = socket.getfqdn(host)
self.server_port = port
self.setup_environ()
return
if self.address_family == socket.AF_INET6: if self.address_family == socket.AF_INET6:
# Only allow IPv6 connections to the IPv6 socket # Only allow IPv6 connections to the IPv6 socket
self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) self.socket.setsockopt(COMPAT_IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
super().server_bind() super().server_bind()
def get_request(self): def get_request(self):
# Set timeout for client # Set timeout for client
socket_, address = super().get_request() request, client_address = super().get_request()
timeout = self.configuration.get("server", "timeout") timeout = self.configuration.get("server", "timeout")
if timeout: if timeout:
socket_.settimeout(timeout) request.settimeout(timeout)
return socket_, address client_socket, client_socket_out = socket.socketpair()
self.client_sockets.add(client_socket_out)
return request, (*client_address, client_socket)
def finish_request_locked(self, request, client_address): def finish_request_locked(self, request, client_address):
return super().finish_request(request, client_address) return super().finish_request(request, client_address)
def finish_request(self, request, client_address): def finish_request(self, request, client_address):
"""Don't overwrite this! (Modified by tests.)""" """Don't overwrite this! (Modified by tests.)"""
with self.connections_guard: *client_address, client_socket = client_address
client_address = tuple(client_address)
try:
return self.finish_request_locked(request, client_address) return self.finish_request_locked(request, client_address)
finally:
client_socket.close()
def handle_error(self, request, client_address): def handle_error(self, request, client_address):
if issubclass(sys.exc_info()[0], socket.timeout): if issubclass(sys.exc_info()[0], socket.timeout):
@ -139,13 +109,18 @@ class ParallelHTTPSServer(ParallelHTTPServer):
# Test if the files can be read # Test if the files can be read
for name, filename in [("certificate", certfile), ("key", keyfile), for name, filename in [("certificate", certfile), ("key", keyfile),
("certificate_authority", cafile)]: ("certificate_authority", cafile)]:
type_name = config.DEFAULT_CONFIG_SCHEMA["server"][name][
"type"].__name__
source = self.configuration.get_source("server", name)
if name == "certificate_authority" and not filename: if name == "certificate_authority" and not filename:
continue continue
try: try:
open(filename, "r").close() open(filename, "r").close()
except OSError as e: except OSError as e:
raise RuntimeError("Failed to read SSL %s %r: %s" % raise RuntimeError(
(name, filename, e)) from e "Invalid %s value for option %r in section %r in %s: %r "
"(%s)" % (type_name, name, "server", source, filename,
e)) from e
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=certfile, keyfile=keyfile) context.load_cert_chain(certfile=certfile, keyfile=keyfile)
if cafile: if cafile:
@ -185,7 +160,7 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
"""HTTP requests handler.""" """HTTP requests handler."""
def log_request(self, code="-", size="-"): def log_request(self, code="-", size="-"):
"""Disable request logging.""" pass # Disable request logging.
def log_error(self, format_, *args): def log_error(self, format_, *args):
logger.error("An error occurred during request: %s", format_ % args) logger.error("An error occurred during request: %s", format_ % args)
@ -220,7 +195,7 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
handler.run(self.server.get_app()) handler.run(self.server.get_app())
def serve(configuration, shutdown_socket=None): def serve(configuration, shutdown_socket):
"""Serve radicale from configuration.""" """Serve radicale from configuration."""
logger.info("Starting Radicale") logger.info("Starting Radicale")
# Copy configuration before modifying # Copy configuration before modifying
@ -228,76 +203,77 @@ def serve(configuration, shutdown_socket=None):
configuration.update({"internal": {"internal_server": "True"}}, "server", configuration.update({"internal": {"internal_server": "True"}}, "server",
internal=True) internal=True)
# Create server sockets use_ssl = configuration.get("server", "ssl")
server_addresses_or_sockets = [] # [((host, port) or socket, family)] server_class = ParallelHTTPSServer if use_ssl else ParallelHTTPServer
if systemd:
listen_fds = systemd.daemon.listen_fds()
else:
listen_fds = []
if listen_fds:
logger.info("Using socket activation")
for fd in listen_fds:
server_addresses_or_sockets.append((socket.fromfd(
fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX))
else:
for address, port in configuration.get("server", "hosts"):
server_addresses_or_sockets.append(
((address, port), socket.AF_INET))
if configuration.get("server", "ssl"):
server_class = ParallelHTTPSServer
else:
server_class = ParallelHTTPServer
application = Application(configuration) application = Application(configuration)
servers = {} servers = {}
for server_address_or_socket, family in server_addresses_or_sockets:
# If family is AF_INET, try to bind sockets for AF_INET and AF_INET6
bind_successful = False
for family in [family, socket.AF_INET6]:
try: try:
server = server_class(configuration, family, for address in configuration.get("server", "hosts"):
server_address_or_socket, RequestHandler) # Try to bind sockets for IPv4 and IPv6
possible_families = (socket.AF_INET, socket.AF_INET6)
bind_ok = False
for i, family in enumerate(possible_families):
try:
server = server_class(configuration, family, address,
RequestHandler)
except OSError as e: except OSError as e:
if ((family == socket.AF_INET and HAS_IPV6 or if ((bind_ok or i < len(possible_families) - 1) and
bind_successful) and isinstance(e, socket.gaierror) and isinstance(e, socket.gaierror) and
e.errno in (EAI_NONAME, EAI_ADDRFAMILY)): e.errno in (socket.EAI_NONAME,
# Allow one of AF_INET and AF_INET6 to fail, when COMPAT_EAI_ADDRFAMILY)):
# the address or host don't support the address family. # Ignore unsupported families, only one must work
continue continue
raise RuntimeError( raise RuntimeError(
"Failed to start server %r: %s" % ( "Failed to start server %r: %s" % (
server_address_or_socket, e)) from e format_address(address), e)) from e
bind_successful = True
server.set_app(application)
servers[server.socket] = server servers[server.socket] = server
logger.info("Listening to %r on port %d%s (%s)", bind_ok = True
server.server_name, server.server_port, " using SSL" server.set_app(application)
if configuration.get("server", "ssl") else "", logger.info("Listening on %r%s",
family.name) format_address(server.server_address),
" with SSL" if use_ssl else "")
assert servers, "no servers started"
# Main loop: wait for requests on any of the servers or program shutdown # Mainloop
sockets = list(servers.keys())
# Use socket pair to get notified of program shutdown
if shutdown_socket:
sockets.append(shutdown_socket)
select_timeout = None select_timeout = None
if os.name == "nt": if os.name == "nt":
# Fallback to busy waiting. (select.select blocks SIGINT on Windows.) # Fallback to busy waiting. (select(...) blocks SIGINT on Windows.)
select_timeout = 1.0 select_timeout = 1.0
max_connections = configuration.get("server", "max_connections")
logger.info("Radicale server ready") logger.info("Radicale server ready")
with contextlib.ExitStack() as exit_stack:
for _, server in servers.items():
exit_stack.callback(server.server_close)
while True: while True:
rlist, _, xlist = select.select( rlist = xlist = []
sockets, [], sockets, select_timeout) # Wait for finished clients
for server in servers.values():
rlist.extend(server.client_sockets)
# Accept new connections if max_connections is not reached
if max_connections <= 0 or len(rlist) < max_connections:
rlist.extend(servers)
# Use socket to get notified of program shutdown
rlist.append(shutdown_socket)
rlist, _, xlist = select.select(rlist, [], xlist, select_timeout)
if xlist: if xlist:
raise RuntimeError("unhandled socket error") raise RuntimeError("unhandled socket error")
rlist = set(rlist)
if shutdown_socket in rlist: if shutdown_socket in rlist:
logger.info("Stopping Radicale") logger.info("Stopping Radicale")
break break
for server in servers.values():
finished_sockets = server.client_sockets.intersection(rlist)
for s in finished_sockets:
s.close()
server.client_sockets.remove(s)
rlist.remove(s)
if finished_sockets:
server.service_actions()
if rlist: if rlist:
server = servers.get(rlist[0]) server = servers.get(rlist.pop())
if server: if server:
server.handle_request() server.handle_request()
server.service_actions() finally:
# Wait for clients to finish and close servers
for server in servers.values():
for s in server.client_sockets:
s.recv(1)
s.close()
server.server_close()

View File

@ -119,15 +119,25 @@ class TestBaseServerRequests(BaseTest):
self.thread.start() self.thread.start()
self.get("/", check=302) self.get("/", check=302)
@pytest.mark.skipif(not server.HAS_IPV6, reason="IPv6 not supported") def test_bind_fail(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
with pytest.raises(socket.gaierror) as exc_info:
sock.bind(("::1", 0))
assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
with pytest.raises(socket.gaierror) as exc_info:
sock.bind(("127.0.0.1", 0))
assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY
def test_ipv6(self): def test_ipv6(self):
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
sock.setsockopt(server.IPPROTO_IPV6, server.IPV6_V6ONLY, 1)
try: try:
# Find available port # Find available port
sock.bind(("::1", 0)) sock.bind(("::1", 0))
except OSError: except socket.gaierror as e:
if e.errno == server.COMPAT_EAI_ADDRFAMILY:
pytest.skip("IPv6 not supported") pytest.skip("IPv6 not supported")
raise
self.sockname = sock.getsockname()[:2] self.sockname = sock.getsockname()[:2]
self.configuration.update({ self.configuration.update({
"server": {"hosts": "[%s]:%d" % self.sockname}}, "test") "server": {"hosts": "[%s]:%d" % self.sockname}}, "test")