Use socket pairs to communicate with client threads
This commit is contained in:
parent
698980d7be
commit
66fabbead9
@ -371,6 +371,13 @@ class Configuration:
|
|||||||
break
|
break
|
||||||
return fconfig[section][option]
|
return fconfig[section][option]
|
||||||
|
|
||||||
|
def get_source(self, section, option):
|
||||||
|
"""Get the source that provides ``option`` in ``section``."""
|
||||||
|
for config, source, _ in reversed(self._configs):
|
||||||
|
if option in config.get(section, {}):
|
||||||
|
return source
|
||||||
|
raise KeyError(section, option)
|
||||||
|
|
||||||
def sections(self):
|
def sections(self):
|
||||||
"""List all sections."""
|
"""List all sections."""
|
||||||
return self._values.keys()
|
return self._values.keys()
|
||||||
|
@ -22,103 +22,73 @@ Built-in WSGI server.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import os
|
import os
|
||||||
import select
|
import select
|
||||||
import socket
|
import socket
|
||||||
import socketserver
|
import socketserver
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import wsgiref.simple_server
|
import wsgiref.simple_server
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from radicale import Application
|
from radicale import Application, config
|
||||||
from radicale.log import logger
|
from radicale.log import logger
|
||||||
|
|
||||||
try:
|
|
||||||
import systemd.daemon
|
|
||||||
except ImportError:
|
|
||||||
systemd = None
|
|
||||||
|
|
||||||
HAS_IPV6 = socket.has_ipv6
|
|
||||||
if hasattr(socket, "EAI_NONAME"):
|
|
||||||
EAI_NONAME = socket.EAI_NONAME
|
|
||||||
else:
|
|
||||||
HAS_IPV6 = False
|
|
||||||
if hasattr(socket, "EAI_ADDRFAMILY"):
|
if hasattr(socket, "EAI_ADDRFAMILY"):
|
||||||
EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY
|
COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY
|
||||||
elif os.name == "nt":
|
elif os.name == "nt" and hasattr(socket, "EAI_NONAME"):
|
||||||
EAI_ADDRFAMILY = None
|
# Windows doesn't have a special error code for this
|
||||||
else:
|
COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME
|
||||||
HAS_IPV6 = False
|
|
||||||
if hasattr(socket, "IPPROTO_IPV6"):
|
if hasattr(socket, "IPPROTO_IPV6"):
|
||||||
IPPROTO_IPV6 = socket.IPPROTO_IPV6
|
COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6
|
||||||
elif os.name == "nt":
|
elif os.name == "nt":
|
||||||
IPPROTO_IPV6 = 41
|
# Workaround: https://bugs.python.org/issue29515
|
||||||
else:
|
COMPAT_IPPROTO_IPV6 = 41
|
||||||
HAS_IPV6 = False
|
|
||||||
if hasattr(socket, "IPV6_V6ONLY"):
|
|
||||||
IPV6_V6ONLY = socket.IPV6_V6ONLY
|
def format_address(address):
|
||||||
elif os.name == "nt":
|
return "[%s]:%d" % address[:2]
|
||||||
IPV6_V6ONLY = 27
|
|
||||||
else:
|
|
||||||
HAS_IPV6 = False
|
|
||||||
|
|
||||||
|
|
||||||
class ParallelHTTPServer(socketserver.ThreadingMixIn,
|
class ParallelHTTPServer(socketserver.ThreadingMixIn,
|
||||||
wsgiref.simple_server.WSGIServer):
|
wsgiref.simple_server.WSGIServer):
|
||||||
|
|
||||||
# Python 3.6: Wait for child threads (Default in Python >= 3.7)
|
# We wait for child threads ourself
|
||||||
_block_on_close = True
|
block_on_close = False
|
||||||
|
|
||||||
def __init__(self, configuration, address_family,
|
def __init__(self, configuration, family, address, RequestHandlerClass):
|
||||||
server_address_or_socket, RequestHandlerClass):
|
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
self.address_family = address_family
|
self.address_family = family
|
||||||
if isinstance(server_address_or_socket, socket.socket):
|
super().__init__(address, RequestHandlerClass)
|
||||||
self._override_socket = server_address_or_socket
|
self.client_sockets = set()
|
||||||
server_address = server_address_or_socket.getsockname()
|
|
||||||
else:
|
|
||||||
self._override_socket = None
|
|
||||||
server_address = server_address_or_socket
|
|
||||||
super().__init__(server_address, RequestHandlerClass)
|
|
||||||
max_connections = self.configuration.get("server", "max_connections")
|
|
||||||
if max_connections:
|
|
||||||
self.connections_guard = threading.BoundedSemaphore(
|
|
||||||
max_connections)
|
|
||||||
else:
|
|
||||||
# use dummy context manager
|
|
||||||
self.connections_guard = contextlib.ExitStack()
|
|
||||||
|
|
||||||
def server_bind(self):
|
def server_bind(self):
|
||||||
if self._override_socket is not None:
|
|
||||||
self.socket = self._override_socket
|
|
||||||
host, port = self.server_address[:2]
|
|
||||||
self.server_name = socket.getfqdn(host)
|
|
||||||
self.server_port = port
|
|
||||||
self.setup_environ()
|
|
||||||
return
|
|
||||||
if self.address_family == socket.AF_INET6:
|
if self.address_family == socket.AF_INET6:
|
||||||
# Only allow IPv6 connections to the IPv6 socket
|
# Only allow IPv6 connections to the IPv6 socket
|
||||||
self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1)
|
self.socket.setsockopt(COMPAT_IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
|
||||||
super().server_bind()
|
super().server_bind()
|
||||||
|
|
||||||
def get_request(self):
|
def get_request(self):
|
||||||
# Set timeout for client
|
# Set timeout for client
|
||||||
socket_, address = super().get_request()
|
request, client_address = super().get_request()
|
||||||
timeout = self.configuration.get("server", "timeout")
|
timeout = self.configuration.get("server", "timeout")
|
||||||
if timeout:
|
if timeout:
|
||||||
socket_.settimeout(timeout)
|
request.settimeout(timeout)
|
||||||
return socket_, address
|
client_socket, client_socket_out = socket.socketpair()
|
||||||
|
self.client_sockets.add(client_socket_out)
|
||||||
|
return request, (*client_address, client_socket)
|
||||||
|
|
||||||
def finish_request_locked(self, request, client_address):
|
def finish_request_locked(self, request, client_address):
|
||||||
return super().finish_request(request, client_address)
|
return super().finish_request(request, client_address)
|
||||||
|
|
||||||
def finish_request(self, request, client_address):
|
def finish_request(self, request, client_address):
|
||||||
"""Don't overwrite this! (Modified by tests.)"""
|
"""Don't overwrite this! (Modified by tests.)"""
|
||||||
with self.connections_guard:
|
*client_address, client_socket = client_address
|
||||||
|
client_address = tuple(client_address)
|
||||||
|
try:
|
||||||
return self.finish_request_locked(request, client_address)
|
return self.finish_request_locked(request, client_address)
|
||||||
|
finally:
|
||||||
|
client_socket.close()
|
||||||
|
|
||||||
def handle_error(self, request, client_address):
|
def handle_error(self, request, client_address):
|
||||||
if issubclass(sys.exc_info()[0], socket.timeout):
|
if issubclass(sys.exc_info()[0], socket.timeout):
|
||||||
@ -139,13 +109,18 @@ class ParallelHTTPSServer(ParallelHTTPServer):
|
|||||||
# Test if the files can be read
|
# Test if the files can be read
|
||||||
for name, filename in [("certificate", certfile), ("key", keyfile),
|
for name, filename in [("certificate", certfile), ("key", keyfile),
|
||||||
("certificate_authority", cafile)]:
|
("certificate_authority", cafile)]:
|
||||||
|
type_name = config.DEFAULT_CONFIG_SCHEMA["server"][name][
|
||||||
|
"type"].__name__
|
||||||
|
source = self.configuration.get_source("server", name)
|
||||||
if name == "certificate_authority" and not filename:
|
if name == "certificate_authority" and not filename:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
open(filename, "r").close()
|
open(filename, "r").close()
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
raise RuntimeError(
|
||||||
(name, filename, e)) from e
|
"Invalid %s value for option %r in section %r in %s: %r "
|
||||||
|
"(%s)" % (type_name, name, "server", source, filename,
|
||||||
|
e)) from e
|
||||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
||||||
if cafile:
|
if cafile:
|
||||||
@ -185,7 +160,7 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
|||||||
"""HTTP requests handler."""
|
"""HTTP requests handler."""
|
||||||
|
|
||||||
def log_request(self, code="-", size="-"):
|
def log_request(self, code="-", size="-"):
|
||||||
"""Disable request logging."""
|
pass # Disable request logging.
|
||||||
|
|
||||||
def log_error(self, format_, *args):
|
def log_error(self, format_, *args):
|
||||||
logger.error("An error occurred during request: %s", format_ % args)
|
logger.error("An error occurred during request: %s", format_ % args)
|
||||||
@ -220,7 +195,7 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
|||||||
handler.run(self.server.get_app())
|
handler.run(self.server.get_app())
|
||||||
|
|
||||||
|
|
||||||
def serve(configuration, shutdown_socket=None):
|
def serve(configuration, shutdown_socket):
|
||||||
"""Serve radicale from configuration."""
|
"""Serve radicale from configuration."""
|
||||||
logger.info("Starting Radicale")
|
logger.info("Starting Radicale")
|
||||||
# Copy configuration before modifying
|
# Copy configuration before modifying
|
||||||
@ -228,76 +203,77 @@ def serve(configuration, shutdown_socket=None):
|
|||||||
configuration.update({"internal": {"internal_server": "True"}}, "server",
|
configuration.update({"internal": {"internal_server": "True"}}, "server",
|
||||||
internal=True)
|
internal=True)
|
||||||
|
|
||||||
# Create server sockets
|
use_ssl = configuration.get("server", "ssl")
|
||||||
server_addresses_or_sockets = [] # [((host, port) or socket, family)]
|
server_class = ParallelHTTPSServer if use_ssl else ParallelHTTPServer
|
||||||
if systemd:
|
|
||||||
listen_fds = systemd.daemon.listen_fds()
|
|
||||||
else:
|
|
||||||
listen_fds = []
|
|
||||||
if listen_fds:
|
|
||||||
logger.info("Using socket activation")
|
|
||||||
for fd in listen_fds:
|
|
||||||
server_addresses_or_sockets.append((socket.fromfd(
|
|
||||||
fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX))
|
|
||||||
else:
|
|
||||||
for address, port in configuration.get("server", "hosts"):
|
|
||||||
server_addresses_or_sockets.append(
|
|
||||||
((address, port), socket.AF_INET))
|
|
||||||
if configuration.get("server", "ssl"):
|
|
||||||
server_class = ParallelHTTPSServer
|
|
||||||
else:
|
|
||||||
server_class = ParallelHTTPServer
|
|
||||||
application = Application(configuration)
|
application = Application(configuration)
|
||||||
servers = {}
|
servers = {}
|
||||||
for server_address_or_socket, family in server_addresses_or_sockets:
|
try:
|
||||||
# If family is AF_INET, try to bind sockets for AF_INET and AF_INET6
|
for address in configuration.get("server", "hosts"):
|
||||||
bind_successful = False
|
# Try to bind sockets for IPv4 and IPv6
|
||||||
for family in [family, socket.AF_INET6]:
|
possible_families = (socket.AF_INET, socket.AF_INET6)
|
||||||
try:
|
bind_ok = False
|
||||||
server = server_class(configuration, family,
|
for i, family in enumerate(possible_families):
|
||||||
server_address_or_socket, RequestHandler)
|
try:
|
||||||
except OSError as e:
|
server = server_class(configuration, family, address,
|
||||||
if ((family == socket.AF_INET and HAS_IPV6 or
|
RequestHandler)
|
||||||
bind_successful) and isinstance(e, socket.gaierror) and
|
except OSError as e:
|
||||||
e.errno in (EAI_NONAME, EAI_ADDRFAMILY)):
|
if ((bind_ok or i < len(possible_families) - 1) and
|
||||||
# Allow one of AF_INET and AF_INET6 to fail, when
|
isinstance(e, socket.gaierror) and
|
||||||
# the address or host don't support the address family.
|
e.errno in (socket.EAI_NONAME,
|
||||||
continue
|
COMPAT_EAI_ADDRFAMILY)):
|
||||||
raise RuntimeError(
|
# Ignore unsupported families, only one must work
|
||||||
"Failed to start server %r: %s" % (
|
continue
|
||||||
server_address_or_socket, e)) from e
|
raise RuntimeError(
|
||||||
bind_successful = True
|
"Failed to start server %r: %s" % (
|
||||||
server.set_app(application)
|
format_address(address), e)) from e
|
||||||
servers[server.socket] = server
|
servers[server.socket] = server
|
||||||
logger.info("Listening to %r on port %d%s (%s)",
|
bind_ok = True
|
||||||
server.server_name, server.server_port, " using SSL"
|
server.set_app(application)
|
||||||
if configuration.get("server", "ssl") else "",
|
logger.info("Listening on %r%s",
|
||||||
family.name)
|
format_address(server.server_address),
|
||||||
|
" with SSL" if use_ssl else "")
|
||||||
|
assert servers, "no servers started"
|
||||||
|
|
||||||
# Main loop: wait for requests on any of the servers or program shutdown
|
# Mainloop
|
||||||
sockets = list(servers.keys())
|
select_timeout = None
|
||||||
# Use socket pair to get notified of program shutdown
|
if os.name == "nt":
|
||||||
if shutdown_socket:
|
# Fallback to busy waiting. (select(...) blocks SIGINT on Windows.)
|
||||||
sockets.append(shutdown_socket)
|
select_timeout = 1.0
|
||||||
select_timeout = None
|
max_connections = configuration.get("server", "max_connections")
|
||||||
if os.name == "nt":
|
logger.info("Radicale server ready")
|
||||||
# Fallback to busy waiting. (select.select blocks SIGINT on Windows.)
|
|
||||||
select_timeout = 1.0
|
|
||||||
logger.info("Radicale server ready")
|
|
||||||
|
|
||||||
with contextlib.ExitStack() as exit_stack:
|
|
||||||
for _, server in servers.items():
|
|
||||||
exit_stack.callback(server.server_close)
|
|
||||||
while True:
|
while True:
|
||||||
rlist, _, xlist = select.select(
|
rlist = xlist = []
|
||||||
sockets, [], sockets, select_timeout)
|
# Wait for finished clients
|
||||||
|
for server in servers.values():
|
||||||
|
rlist.extend(server.client_sockets)
|
||||||
|
# Accept new connections if max_connections is not reached
|
||||||
|
if max_connections <= 0 or len(rlist) < max_connections:
|
||||||
|
rlist.extend(servers)
|
||||||
|
# Use socket to get notified of program shutdown
|
||||||
|
rlist.append(shutdown_socket)
|
||||||
|
rlist, _, xlist = select.select(rlist, [], xlist, select_timeout)
|
||||||
if xlist:
|
if xlist:
|
||||||
raise RuntimeError("unhandled socket error")
|
raise RuntimeError("unhandled socket error")
|
||||||
|
rlist = set(rlist)
|
||||||
if shutdown_socket in rlist:
|
if shutdown_socket in rlist:
|
||||||
logger.info("Stopping Radicale")
|
logger.info("Stopping Radicale")
|
||||||
break
|
break
|
||||||
|
for server in servers.values():
|
||||||
|
finished_sockets = server.client_sockets.intersection(rlist)
|
||||||
|
for s in finished_sockets:
|
||||||
|
s.close()
|
||||||
|
server.client_sockets.remove(s)
|
||||||
|
rlist.remove(s)
|
||||||
|
if finished_sockets:
|
||||||
|
server.service_actions()
|
||||||
if rlist:
|
if rlist:
|
||||||
server = servers.get(rlist[0])
|
server = servers.get(rlist.pop())
|
||||||
if server:
|
if server:
|
||||||
server.handle_request()
|
server.handle_request()
|
||||||
server.service_actions()
|
finally:
|
||||||
|
# Wait for clients to finish and close servers
|
||||||
|
for server in servers.values():
|
||||||
|
for s in server.client_sockets:
|
||||||
|
s.recv(1)
|
||||||
|
s.close()
|
||||||
|
server.server_close()
|
||||||
|
@ -119,15 +119,25 @@ class TestBaseServerRequests(BaseTest):
|
|||||||
self.thread.start()
|
self.thread.start()
|
||||||
self.get("/", check=302)
|
self.get("/", check=302)
|
||||||
|
|
||||||
@pytest.mark.skipif(not server.HAS_IPV6, reason="IPv6 not supported")
|
def test_bind_fail(self):
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||||
|
with pytest.raises(socket.gaierror) as exc_info:
|
||||||
|
sock.bind(("::1", 0))
|
||||||
|
assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY
|
||||||
|
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
|
||||||
|
with pytest.raises(socket.gaierror) as exc_info:
|
||||||
|
sock.bind(("127.0.0.1", 0))
|
||||||
|
assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY
|
||||||
|
|
||||||
def test_ipv6(self):
|
def test_ipv6(self):
|
||||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
|
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
|
||||||
sock.setsockopt(server.IPPROTO_IPV6, server.IPV6_V6ONLY, 1)
|
|
||||||
try:
|
try:
|
||||||
# Find available port
|
# Find available port
|
||||||
sock.bind(("::1", 0))
|
sock.bind(("::1", 0))
|
||||||
except OSError:
|
except socket.gaierror as e:
|
||||||
pytest.skip("IPv6 not supported")
|
if e.errno == server.COMPAT_EAI_ADDRFAMILY:
|
||||||
|
pytest.skip("IPv6 not supported")
|
||||||
|
raise
|
||||||
self.sockname = sock.getsockname()[:2]
|
self.sockname = sock.getsockname()[:2]
|
||||||
self.configuration.update({
|
self.configuration.update({
|
||||||
"server": {"hosts": "[%s]:%d" % self.sockname}}, "test")
|
"server": {"hosts": "[%s]:%d" % self.sockname}}, "test")
|
||||||
|
Loading…
Reference in New Issue
Block a user