Handle disabled IPv6 support and workaround for PyPy

This commit is contained in:
Unrud 2020-02-20 07:57:39 +01:00
parent 9603aa3496
commit 8890a4c030
2 changed files with 30 additions and 20 deletions

View File

@ -22,6 +22,7 @@ Built-in WSGI server.
""" """
import errno
import os import os
import select import select
import socket import socket
@ -36,8 +37,8 @@ from radicale.log import logger
if hasattr(socket, "EAI_ADDRFAMILY"): if hasattr(socket, "EAI_ADDRFAMILY"):
COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY
elif os.name == "nt" and hasattr(socket, "EAI_NONAME"): elif hasattr(socket, "EAI_NONAME"):
# Windows doesn't have a special error code for this # Windows and BSD don't have a special error code for this
COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME
if hasattr(socket, "IPPROTO_IPV6"): if hasattr(socket, "IPPROTO_IPV6"):
COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6 COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6
@ -213,18 +214,25 @@ def serve(configuration, shutdown_socket):
possible_families = (socket.AF_INET, socket.AF_INET6) possible_families = (socket.AF_INET, socket.AF_INET6)
bind_ok = False bind_ok = False
for i, family in enumerate(possible_families): for i, family in enumerate(possible_families):
is_last = i == len(possible_families) - 1
try: try:
server = server_class(configuration, family, address, server = server_class(configuration, family, address,
RequestHandler) RequestHandler)
except OSError as e: except OSError as e:
if ((bind_ok or i < len(possible_families) - 1) and # Ignore unsupported families (only one must work)
isinstance(e, socket.gaierror) and if ((bind_ok or not is_last) and (
e.errno in (socket.EAI_NONAME, isinstance(e, socket.gaierror) and (
COMPAT_EAI_ADDRFAMILY)): # Hostname does not exist or doesn't have
# Ignore unsupported families, only one must work # address for address family
e.errno == socket.EAI_NONAME or
# Address not for address family
e.errno == COMPAT_EAI_ADDRFAMILY) or
# Workaround for PyPy
str(e) == "address family mismatched" or
# Address family not available (e.g. IPv6 disabled)
e.errno == errno.EADDRNOTAVAIL)):
continue continue
raise RuntimeError( raise RuntimeError("Failed to start server %r: %s" % (
"Failed to start server %r: %s" % (
format_address(address), e)) from e format_address(address), e)) from e
servers[server.socket] = server servers[server.socket] = server
bind_ok = True bind_ok = True

View File

@ -19,6 +19,7 @@ Test the internal server.
""" """
import errno
import os import os
import shutil import shutil
import socket import socket
@ -116,22 +117,23 @@ class TestBaseServerRequests(BaseTest):
self.get("/", check=302) self.get("/", check=302)
def test_bind_fail(self): def test_bind_fail(self):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: for family, address in [(socket.AF_INET, "::1"),
with pytest.raises(socket.gaierror) as exc_info: (socket.AF_INET6, "127.0.0.1")]:
sock.bind(("::1", 0)) with socket.socket(family, socket.SOCK_STREAM) as sock:
assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY with pytest.raises(OSError) as exc_info:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: sock.bind((address, 0))
with pytest.raises(socket.gaierror) as exc_info: assert (isinstance(exc_info.value, socket.gaierror) and
sock.bind(("127.0.0.1", 0)) exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY or
assert exc_info.value.errno == server.COMPAT_EAI_ADDRFAMILY # Workaround for PyPy
str(exc_info.value) == "address family mismatched")
def test_ipv6(self): def test_ipv6(self):
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
try: try:
# Find available port # Find available port
sock.bind(("::1", 0)) sock.bind(("::1", 0))
except socket.gaierror as e: except OSError as e:
if e.errno == server.COMPAT_EAI_ADDRFAMILY: if e.errno == errno.EADDRNOTAVAIL:
pytest.skip("IPv6 not supported") pytest.skip("IPv6 not supported")
raise raise
self.sockname = sock.getsockname()[:2] self.sockname = sock.getsockname()[:2]