Refactor: Remove class attributes and subclassing

This commit is contained in:
Unrud 2020-02-19 09:49:44 +01:00
parent a872b633fb
commit 36483670d4

View File

@ -87,32 +87,43 @@ class ParallelHTTPServer(ParallelizationMixIn,
# wait for child processes/threads # wait for child processes/threads
_block_on_close = True _block_on_close = True
# These class attributes must be set before creating instance def __init__(self, configuration, address_family,
client_timeout = None server_address_or_socket, RequestHandlerClass):
max_connections = None self.configuration = configuration
self.address_family = address_family
def __init__(self, *args, **kwargs): if isinstance(server_address_or_socket, socket.socket):
super().__init__(*args, **kwargs) override_socket = server_address_or_socket
server_address = override_socket.getsockname()
else:
override_socket = None
server_address = server_address_or_socket
super().__init__(server_address, RequestHandlerClass,
bind_and_activate=False)
if USE_FORKING: if USE_FORKING:
sema_class = multiprocessing.BoundedSemaphore sema_class = multiprocessing.BoundedSemaphore
else: else:
sema_class = threading.BoundedSemaphore sema_class = threading.BoundedSemaphore
if self.max_connections: max_connections = self.configuration.get("server", "max_connections")
self.connections_guard = sema_class(self.max_connections) if max_connections:
self.connections_guard = sema_class(max_connections)
else: else:
# use dummy context manager # use dummy context manager
self.connections_guard = contextlib.ExitStack() self.connections_guard = contextlib.ExitStack()
if override_socket:
def server_bind(self): self.socket = override_socket
if isinstance(self.server_address, socket.socket):
# Socket activation
self.socket = self.server_address
self.server_address = self.socket.getsockname()
host, port = self.server_address[:2] host, port = self.server_address[:2]
self.server_name = socket.getfqdn(host) self.server_name = socket.getfqdn(host)
self.server_port = port self.server_port = port
self.setup_environ() self.setup_environ()
return return
try:
self.server_bind()
self.server_activate()
except BaseException:
self.server_close()
raise
def server_bind(self):
try: try:
super().server_bind() super().server_bind()
except socket.gaierror as e: except socket.gaierror as e:
@ -129,8 +140,9 @@ class ParallelHTTPServer(ParallelizationMixIn,
def get_request(self): def get_request(self):
# Set timeout for client # Set timeout for client
socket_, address = super().get_request() socket_, address = super().get_request()
if self.client_timeout: timeout = self.configuration.get("server", "timeout")
socket_.settimeout(self.client_timeout) if timeout:
socket_.settimeout(timeout)
return socket_, address return socket_, address
def process_request(self, request, client_address): def process_request(self, request, client_address):
@ -159,18 +171,26 @@ class ParallelHTTPServer(ParallelizationMixIn,
class ParallelHTTPSServer(ParallelHTTPServer): class ParallelHTTPSServer(ParallelHTTPServer):
# These class attributes must be set before creating instance
certificate = None
key = None
certificate_authority = None
def server_bind(self): def server_bind(self):
super().server_bind() super().server_bind()
# Wrap the TCP socket in an SSL socket # Wrap the TCP socket in an SSL socket
certfile = self.configuration.get("server", "certificate")
keyfile = self.configuration.get("server", "key")
cafile = self.configuration.get("server", "certificate_authority")
# Test if the files can be read
for name, filename in [("certificate", certfile), ("key", keyfile),
("certificate_authority", cafile)]:
if name == "certificate_authority" and not filename:
continue
try:
open(filename, "r").close()
except OSError as e:
raise RuntimeError("Failed to read SSL %s %r: %s" %
(name, filename, e)) from e
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=self.certificate, keyfile=self.key) context.load_cert_chain(certfile=certfile, keyfile=keyfile)
if self.certificate_authority: if cafile:
context.load_verify_locations(cafile=self.certificate_authority) context.load_verify_locations(cafile=cafile)
context.verify_mode = ssl.CERT_REQUIRED context.verify_mode = ssl.CERT_REQUIRED
self.socket = context.wrap_socket( self.socket = context.wrap_socket(
self.socket, server_side=True, do_handshake_on_connect=False) self.socket, server_side=True, do_handshake_on_connect=False)
@ -249,57 +269,36 @@ def serve(configuration, shutdown_socket=None):
configuration.update({"internal": {"internal_server": "True"}}, "server", configuration.update({"internal": {"internal_server": "True"}}, "server",
internal=True) internal=True)
# Create collection servers # Create server sockets
servers = {} server_addresses_or_sockets = [] # [((host, port) or socket, family)]
if configuration.get("server", "ssl"):
server_class = ParallelHTTPSServer
else:
server_class = ParallelHTTPServer
class ServerCopy(server_class):
"""Copy, avoids overriding the original class attributes."""
ServerCopy.client_timeout = configuration.get("server", "timeout")
ServerCopy.max_connections = configuration.get("server", "max_connections")
if configuration.get("server", "ssl"):
ServerCopy.certificate = configuration.get("server", "certificate")
ServerCopy.key = configuration.get("server", "key")
ServerCopy.certificate_authority = configuration.get(
"server", "certificate_authority")
# Test if the SSL files can be read
for name in ["certificate", "key"] + (
["certificate_authority"]
if ServerCopy.certificate_authority else []):
filename = getattr(ServerCopy, name)
try:
open(filename, "r").close()
except OSError as e:
raise RuntimeError("Failed to read SSL %s %r: %s" %
(name, filename, e)) from e
if systemd: if systemd:
listen_fds = systemd.daemon.listen_fds() listen_fds = systemd.daemon.listen_fds()
else: else:
listen_fds = [] listen_fds = []
server_addresses = []
if listen_fds: if listen_fds:
logger.info("Using socket activation") logger.info("Using socket activation")
ServerCopy.address_family = socket.AF_UNIX
for fd in listen_fds: for fd in listen_fds:
server_addresses.append(socket.fromfd( server_addresses_or_sockets.append((socket.fromfd(
fd, ServerCopy.address_family, ServerCopy.socket_type)) fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX))
else: else:
for address, port in configuration.get("server", "hosts"): for address, port in configuration.get("server", "hosts"):
server_addresses.append((address, port)) server_addresses_or_sockets.append(
((address, port), socket.AF_INET))
if configuration.get("server", "ssl"):
server_class = ParallelHTTPSServer
else:
server_class = ParallelHTTPServer
application = Application(configuration) application = Application(configuration)
for server_address in server_addresses: servers = {}
for server_address_or_socket, family in server_addresses_or_sockets:
try: try:
server = ServerCopy(server_address, RequestHandler) server = server_class(configuration, family,
server_address_or_socket, RequestHandler)
server.set_app(application) server.set_app(application)
except OSError as e: except OSError as e:
raise RuntimeError( raise RuntimeError(
"Failed to start server %r: %s" % (server_address, e)) from e "Failed to start server %r: %s" % (
server_address_or_socket, e)) from e
servers[server.socket] = server servers[server.socket] = server
logger.info("Listening to %r on port %d%s", logger.info("Listening to %r on port %d%s",
server.server_name, server.server_port, " using SSL" server.server_name, server.server_port, " using SSL"