Refactor: Remove class attributes and subclassing

This commit is contained in:
Unrud 2020-02-19 09:49:44 +01:00
parent a872b633fb
commit 36483670d4

View File

@ -87,32 +87,43 @@ class ParallelHTTPServer(ParallelizationMixIn,
# wait for child processes/threads
_block_on_close = True
# These class attributes must be set before creating instance
client_timeout = None
max_connections = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, configuration, address_family,
server_address_or_socket, RequestHandlerClass):
self.configuration = configuration
self.address_family = address_family
if isinstance(server_address_or_socket, socket.socket):
override_socket = server_address_or_socket
server_address = override_socket.getsockname()
else:
override_socket = None
server_address = server_address_or_socket
super().__init__(server_address, RequestHandlerClass,
bind_and_activate=False)
if USE_FORKING:
sema_class = multiprocessing.BoundedSemaphore
else:
sema_class = threading.BoundedSemaphore
if self.max_connections:
self.connections_guard = sema_class(self.max_connections)
max_connections = self.configuration.get("server", "max_connections")
if max_connections:
self.connections_guard = sema_class(max_connections)
else:
# use dummy context manager
self.connections_guard = contextlib.ExitStack()
def server_bind(self):
if isinstance(self.server_address, socket.socket):
# Socket activation
self.socket = self.server_address
self.server_address = self.socket.getsockname()
if override_socket:
self.socket = override_socket
host, port = self.server_address[:2]
self.server_name = socket.getfqdn(host)
self.server_port = port
self.setup_environ()
return
try:
self.server_bind()
self.server_activate()
except BaseException:
self.server_close()
raise
def server_bind(self):
try:
super().server_bind()
except socket.gaierror as e:
@ -129,8 +140,9 @@ class ParallelHTTPServer(ParallelizationMixIn,
def get_request(self):
# Set timeout for client
socket_, address = super().get_request()
if self.client_timeout:
socket_.settimeout(self.client_timeout)
timeout = self.configuration.get("server", "timeout")
if timeout:
socket_.settimeout(timeout)
return socket_, address
def process_request(self, request, client_address):
@ -159,18 +171,26 @@ class ParallelHTTPServer(ParallelizationMixIn,
class ParallelHTTPSServer(ParallelHTTPServer):
# These class attributes must be set before creating instance
certificate = None
key = None
certificate_authority = None
def server_bind(self):
super().server_bind()
# Wrap the TCP socket in an SSL socket
certfile = self.configuration.get("server", "certificate")
keyfile = self.configuration.get("server", "key")
cafile = self.configuration.get("server", "certificate_authority")
# Test if the files can be read
for name, filename in [("certificate", certfile), ("key", keyfile),
("certificate_authority", cafile)]:
if name == "certificate_authority" and not filename:
continue
try:
open(filename, "r").close()
except OSError as e:
raise RuntimeError("Failed to read SSL %s %r: %s" %
(name, filename, e)) from e
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=self.certificate, keyfile=self.key)
if self.certificate_authority:
context.load_verify_locations(cafile=self.certificate_authority)
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
if cafile:
context.load_verify_locations(cafile=cafile)
context.verify_mode = ssl.CERT_REQUIRED
self.socket = context.wrap_socket(
self.socket, server_side=True, do_handshake_on_connect=False)
@ -249,57 +269,36 @@ def serve(configuration, shutdown_socket=None):
configuration.update({"internal": {"internal_server": "True"}}, "server",
internal=True)
# Create collection servers
servers = {}
if configuration.get("server", "ssl"):
server_class = ParallelHTTPSServer
else:
server_class = ParallelHTTPServer
class ServerCopy(server_class):
"""Copy, avoids overriding the original class attributes."""
ServerCopy.client_timeout = configuration.get("server", "timeout")
ServerCopy.max_connections = configuration.get("server", "max_connections")
if configuration.get("server", "ssl"):
ServerCopy.certificate = configuration.get("server", "certificate")
ServerCopy.key = configuration.get("server", "key")
ServerCopy.certificate_authority = configuration.get(
"server", "certificate_authority")
# Test if the SSL files can be read
for name in ["certificate", "key"] + (
["certificate_authority"]
if ServerCopy.certificate_authority else []):
filename = getattr(ServerCopy, name)
try:
open(filename, "r").close()
except OSError as e:
raise RuntimeError("Failed to read SSL %s %r: %s" %
(name, filename, e)) from e
# Create server sockets
server_addresses_or_sockets = [] # [((host, port) or socket, family)]
if systemd:
listen_fds = systemd.daemon.listen_fds()
else:
listen_fds = []
server_addresses = []
if listen_fds:
logger.info("Using socket activation")
ServerCopy.address_family = socket.AF_UNIX
for fd in listen_fds:
server_addresses.append(socket.fromfd(
fd, ServerCopy.address_family, ServerCopy.socket_type))
server_addresses_or_sockets.append((socket.fromfd(
fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX))
else:
for address, port in configuration.get("server", "hosts"):
server_addresses.append((address, port))
server_addresses_or_sockets.append(
((address, port), socket.AF_INET))
if configuration.get("server", "ssl"):
server_class = ParallelHTTPSServer
else:
server_class = ParallelHTTPServer
application = Application(configuration)
for server_address in server_addresses:
servers = {}
for server_address_or_socket, family in server_addresses_or_sockets:
try:
server = ServerCopy(server_address, RequestHandler)
server = server_class(configuration, family,
server_address_or_socket, RequestHandler)
server.set_app(application)
except OSError as e:
raise RuntimeError(
"Failed to start server %r: %s" % (server_address, e)) from e
"Failed to start server %r: %s" % (
server_address_or_socket, e)) from e
servers[server.socket] = server
logger.info("Listening to %r on port %d%s",
server.server_name, server.server_port, " using SSL"