Refactor: Remove class attributes and subclassing
This commit is contained in:
parent
a872b633fb
commit
36483670d4
@ -87,32 +87,43 @@ class ParallelHTTPServer(ParallelizationMixIn,
|
||||
# wait for child processes/threads
|
||||
_block_on_close = True
|
||||
|
||||
# These class attributes must be set before creating instance
|
||||
client_timeout = None
|
||||
max_connections = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, configuration, address_family,
|
||||
server_address_or_socket, RequestHandlerClass):
|
||||
self.configuration = configuration
|
||||
self.address_family = address_family
|
||||
if isinstance(server_address_or_socket, socket.socket):
|
||||
override_socket = server_address_or_socket
|
||||
server_address = override_socket.getsockname()
|
||||
else:
|
||||
override_socket = None
|
||||
server_address = server_address_or_socket
|
||||
super().__init__(server_address, RequestHandlerClass,
|
||||
bind_and_activate=False)
|
||||
if USE_FORKING:
|
||||
sema_class = multiprocessing.BoundedSemaphore
|
||||
else:
|
||||
sema_class = threading.BoundedSemaphore
|
||||
if self.max_connections:
|
||||
self.connections_guard = sema_class(self.max_connections)
|
||||
max_connections = self.configuration.get("server", "max_connections")
|
||||
if max_connections:
|
||||
self.connections_guard = sema_class(max_connections)
|
||||
else:
|
||||
# use dummy context manager
|
||||
self.connections_guard = contextlib.ExitStack()
|
||||
|
||||
def server_bind(self):
|
||||
if isinstance(self.server_address, socket.socket):
|
||||
# Socket activation
|
||||
self.socket = self.server_address
|
||||
self.server_address = self.socket.getsockname()
|
||||
if override_socket:
|
||||
self.socket = override_socket
|
||||
host, port = self.server_address[:2]
|
||||
self.server_name = socket.getfqdn(host)
|
||||
self.server_port = port
|
||||
self.setup_environ()
|
||||
return
|
||||
try:
|
||||
self.server_bind()
|
||||
self.server_activate()
|
||||
except BaseException:
|
||||
self.server_close()
|
||||
raise
|
||||
|
||||
def server_bind(self):
|
||||
try:
|
||||
super().server_bind()
|
||||
except socket.gaierror as e:
|
||||
@ -129,8 +140,9 @@ class ParallelHTTPServer(ParallelizationMixIn,
|
||||
def get_request(self):
|
||||
# Set timeout for client
|
||||
socket_, address = super().get_request()
|
||||
if self.client_timeout:
|
||||
socket_.settimeout(self.client_timeout)
|
||||
timeout = self.configuration.get("server", "timeout")
|
||||
if timeout:
|
||||
socket_.settimeout(timeout)
|
||||
return socket_, address
|
||||
|
||||
def process_request(self, request, client_address):
|
||||
@ -159,18 +171,26 @@ class ParallelHTTPServer(ParallelizationMixIn,
|
||||
|
||||
class ParallelHTTPSServer(ParallelHTTPServer):
|
||||
|
||||
# These class attributes must be set before creating instance
|
||||
certificate = None
|
||||
key = None
|
||||
certificate_authority = None
|
||||
|
||||
def server_bind(self):
|
||||
super().server_bind()
|
||||
# Wrap the TCP socket in an SSL socket
|
||||
certfile = self.configuration.get("server", "certificate")
|
||||
keyfile = self.configuration.get("server", "key")
|
||||
cafile = self.configuration.get("server", "certificate_authority")
|
||||
# Test if the files can be read
|
||||
for name, filename in [("certificate", certfile), ("key", keyfile),
|
||||
("certificate_authority", cafile)]:
|
||||
if name == "certificate_authority" and not filename:
|
||||
continue
|
||||
try:
|
||||
open(filename, "r").close()
|
||||
except OSError as e:
|
||||
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
||||
(name, filename, e)) from e
|
||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
context.load_cert_chain(certfile=self.certificate, keyfile=self.key)
|
||||
if self.certificate_authority:
|
||||
context.load_verify_locations(cafile=self.certificate_authority)
|
||||
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
||||
if cafile:
|
||||
context.load_verify_locations(cafile=cafile)
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
self.socket = context.wrap_socket(
|
||||
self.socket, server_side=True, do_handshake_on_connect=False)
|
||||
@ -249,57 +269,36 @@ def serve(configuration, shutdown_socket=None):
|
||||
configuration.update({"internal": {"internal_server": "True"}}, "server",
|
||||
internal=True)
|
||||
|
||||
# Create collection servers
|
||||
servers = {}
|
||||
if configuration.get("server", "ssl"):
|
||||
server_class = ParallelHTTPSServer
|
||||
else:
|
||||
server_class = ParallelHTTPServer
|
||||
|
||||
class ServerCopy(server_class):
|
||||
"""Copy, avoids overriding the original class attributes."""
|
||||
ServerCopy.client_timeout = configuration.get("server", "timeout")
|
||||
ServerCopy.max_connections = configuration.get("server", "max_connections")
|
||||
if configuration.get("server", "ssl"):
|
||||
ServerCopy.certificate = configuration.get("server", "certificate")
|
||||
ServerCopy.key = configuration.get("server", "key")
|
||||
ServerCopy.certificate_authority = configuration.get(
|
||||
"server", "certificate_authority")
|
||||
# Test if the SSL files can be read
|
||||
for name in ["certificate", "key"] + (
|
||||
["certificate_authority"]
|
||||
if ServerCopy.certificate_authority else []):
|
||||
filename = getattr(ServerCopy, name)
|
||||
try:
|
||||
open(filename, "r").close()
|
||||
except OSError as e:
|
||||
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
||||
(name, filename, e)) from e
|
||||
|
||||
# Create server sockets
|
||||
server_addresses_or_sockets = [] # [((host, port) or socket, family)]
|
||||
if systemd:
|
||||
listen_fds = systemd.daemon.listen_fds()
|
||||
else:
|
||||
listen_fds = []
|
||||
|
||||
server_addresses = []
|
||||
if listen_fds:
|
||||
logger.info("Using socket activation")
|
||||
ServerCopy.address_family = socket.AF_UNIX
|
||||
for fd in listen_fds:
|
||||
server_addresses.append(socket.fromfd(
|
||||
fd, ServerCopy.address_family, ServerCopy.socket_type))
|
||||
server_addresses_or_sockets.append((socket.fromfd(
|
||||
fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX))
|
||||
else:
|
||||
for address, port in configuration.get("server", "hosts"):
|
||||
server_addresses.append((address, port))
|
||||
|
||||
server_addresses_or_sockets.append(
|
||||
((address, port), socket.AF_INET))
|
||||
if configuration.get("server", "ssl"):
|
||||
server_class = ParallelHTTPSServer
|
||||
else:
|
||||
server_class = ParallelHTTPServer
|
||||
application = Application(configuration)
|
||||
for server_address in server_addresses:
|
||||
servers = {}
|
||||
for server_address_or_socket, family in server_addresses_or_sockets:
|
||||
try:
|
||||
server = ServerCopy(server_address, RequestHandler)
|
||||
server = server_class(configuration, family,
|
||||
server_address_or_socket, RequestHandler)
|
||||
server.set_app(application)
|
||||
except OSError as e:
|
||||
raise RuntimeError(
|
||||
"Failed to start server %r: %s" % (server_address, e)) from e
|
||||
"Failed to start server %r: %s" % (
|
||||
server_address_or_socket, e)) from e
|
||||
servers[server.socket] = server
|
||||
logger.info("Listening to %r on port %d%s",
|
||||
server.server_name, server.server_port, " using SSL"
|
||||
|
Loading…
Reference in New Issue
Block a user