Refactor: Remove class attributes and subclassing
This commit is contained in:
parent
a872b633fb
commit
36483670d4
@ -87,32 +87,43 @@ class ParallelHTTPServer(ParallelizationMixIn,
|
|||||||
# wait for child processes/threads
|
# wait for child processes/threads
|
||||||
_block_on_close = True
|
_block_on_close = True
|
||||||
|
|
||||||
# These class attributes must be set before creating instance
|
def __init__(self, configuration, address_family,
|
||||||
client_timeout = None
|
server_address_or_socket, RequestHandlerClass):
|
||||||
max_connections = None
|
self.configuration = configuration
|
||||||
|
self.address_family = address_family
|
||||||
def __init__(self, *args, **kwargs):
|
if isinstance(server_address_or_socket, socket.socket):
|
||||||
super().__init__(*args, **kwargs)
|
override_socket = server_address_or_socket
|
||||||
|
server_address = override_socket.getsockname()
|
||||||
|
else:
|
||||||
|
override_socket = None
|
||||||
|
server_address = server_address_or_socket
|
||||||
|
super().__init__(server_address, RequestHandlerClass,
|
||||||
|
bind_and_activate=False)
|
||||||
if USE_FORKING:
|
if USE_FORKING:
|
||||||
sema_class = multiprocessing.BoundedSemaphore
|
sema_class = multiprocessing.BoundedSemaphore
|
||||||
else:
|
else:
|
||||||
sema_class = threading.BoundedSemaphore
|
sema_class = threading.BoundedSemaphore
|
||||||
if self.max_connections:
|
max_connections = self.configuration.get("server", "max_connections")
|
||||||
self.connections_guard = sema_class(self.max_connections)
|
if max_connections:
|
||||||
|
self.connections_guard = sema_class(max_connections)
|
||||||
else:
|
else:
|
||||||
# use dummy context manager
|
# use dummy context manager
|
||||||
self.connections_guard = contextlib.ExitStack()
|
self.connections_guard = contextlib.ExitStack()
|
||||||
|
if override_socket:
|
||||||
def server_bind(self):
|
self.socket = override_socket
|
||||||
if isinstance(self.server_address, socket.socket):
|
|
||||||
# Socket activation
|
|
||||||
self.socket = self.server_address
|
|
||||||
self.server_address = self.socket.getsockname()
|
|
||||||
host, port = self.server_address[:2]
|
host, port = self.server_address[:2]
|
||||||
self.server_name = socket.getfqdn(host)
|
self.server_name = socket.getfqdn(host)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
self.setup_environ()
|
self.setup_environ()
|
||||||
return
|
return
|
||||||
|
try:
|
||||||
|
self.server_bind()
|
||||||
|
self.server_activate()
|
||||||
|
except BaseException:
|
||||||
|
self.server_close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def server_bind(self):
|
||||||
try:
|
try:
|
||||||
super().server_bind()
|
super().server_bind()
|
||||||
except socket.gaierror as e:
|
except socket.gaierror as e:
|
||||||
@ -129,8 +140,9 @@ class ParallelHTTPServer(ParallelizationMixIn,
|
|||||||
def get_request(self):
|
def get_request(self):
|
||||||
# Set timeout for client
|
# Set timeout for client
|
||||||
socket_, address = super().get_request()
|
socket_, address = super().get_request()
|
||||||
if self.client_timeout:
|
timeout = self.configuration.get("server", "timeout")
|
||||||
socket_.settimeout(self.client_timeout)
|
if timeout:
|
||||||
|
socket_.settimeout(timeout)
|
||||||
return socket_, address
|
return socket_, address
|
||||||
|
|
||||||
def process_request(self, request, client_address):
|
def process_request(self, request, client_address):
|
||||||
@ -159,18 +171,26 @@ class ParallelHTTPServer(ParallelizationMixIn,
|
|||||||
|
|
||||||
class ParallelHTTPSServer(ParallelHTTPServer):
|
class ParallelHTTPSServer(ParallelHTTPServer):
|
||||||
|
|
||||||
# These class attributes must be set before creating instance
|
|
||||||
certificate = None
|
|
||||||
key = None
|
|
||||||
certificate_authority = None
|
|
||||||
|
|
||||||
def server_bind(self):
|
def server_bind(self):
|
||||||
super().server_bind()
|
super().server_bind()
|
||||||
# Wrap the TCP socket in an SSL socket
|
# Wrap the TCP socket in an SSL socket
|
||||||
|
certfile = self.configuration.get("server", "certificate")
|
||||||
|
keyfile = self.configuration.get("server", "key")
|
||||||
|
cafile = self.configuration.get("server", "certificate_authority")
|
||||||
|
# Test if the files can be read
|
||||||
|
for name, filename in [("certificate", certfile), ("key", keyfile),
|
||||||
|
("certificate_authority", cafile)]:
|
||||||
|
if name == "certificate_authority" and not filename:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
open(filename, "r").close()
|
||||||
|
except OSError as e:
|
||||||
|
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
||||||
|
(name, filename, e)) from e
|
||||||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
context.load_cert_chain(certfile=self.certificate, keyfile=self.key)
|
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
|
||||||
if self.certificate_authority:
|
if cafile:
|
||||||
context.load_verify_locations(cafile=self.certificate_authority)
|
context.load_verify_locations(cafile=cafile)
|
||||||
context.verify_mode = ssl.CERT_REQUIRED
|
context.verify_mode = ssl.CERT_REQUIRED
|
||||||
self.socket = context.wrap_socket(
|
self.socket = context.wrap_socket(
|
||||||
self.socket, server_side=True, do_handshake_on_connect=False)
|
self.socket, server_side=True, do_handshake_on_connect=False)
|
||||||
@ -249,57 +269,36 @@ def serve(configuration, shutdown_socket=None):
|
|||||||
configuration.update({"internal": {"internal_server": "True"}}, "server",
|
configuration.update({"internal": {"internal_server": "True"}}, "server",
|
||||||
internal=True)
|
internal=True)
|
||||||
|
|
||||||
# Create collection servers
|
# Create server sockets
|
||||||
servers = {}
|
server_addresses_or_sockets = [] # [((host, port) or socket, family)]
|
||||||
if configuration.get("server", "ssl"):
|
|
||||||
server_class = ParallelHTTPSServer
|
|
||||||
else:
|
|
||||||
server_class = ParallelHTTPServer
|
|
||||||
|
|
||||||
class ServerCopy(server_class):
|
|
||||||
"""Copy, avoids overriding the original class attributes."""
|
|
||||||
ServerCopy.client_timeout = configuration.get("server", "timeout")
|
|
||||||
ServerCopy.max_connections = configuration.get("server", "max_connections")
|
|
||||||
if configuration.get("server", "ssl"):
|
|
||||||
ServerCopy.certificate = configuration.get("server", "certificate")
|
|
||||||
ServerCopy.key = configuration.get("server", "key")
|
|
||||||
ServerCopy.certificate_authority = configuration.get(
|
|
||||||
"server", "certificate_authority")
|
|
||||||
# Test if the SSL files can be read
|
|
||||||
for name in ["certificate", "key"] + (
|
|
||||||
["certificate_authority"]
|
|
||||||
if ServerCopy.certificate_authority else []):
|
|
||||||
filename = getattr(ServerCopy, name)
|
|
||||||
try:
|
|
||||||
open(filename, "r").close()
|
|
||||||
except OSError as e:
|
|
||||||
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
|
||||||
(name, filename, e)) from e
|
|
||||||
|
|
||||||
if systemd:
|
if systemd:
|
||||||
listen_fds = systemd.daemon.listen_fds()
|
listen_fds = systemd.daemon.listen_fds()
|
||||||
else:
|
else:
|
||||||
listen_fds = []
|
listen_fds = []
|
||||||
|
|
||||||
server_addresses = []
|
|
||||||
if listen_fds:
|
if listen_fds:
|
||||||
logger.info("Using socket activation")
|
logger.info("Using socket activation")
|
||||||
ServerCopy.address_family = socket.AF_UNIX
|
|
||||||
for fd in listen_fds:
|
for fd in listen_fds:
|
||||||
server_addresses.append(socket.fromfd(
|
server_addresses_or_sockets.append((socket.fromfd(
|
||||||
fd, ServerCopy.address_family, ServerCopy.socket_type))
|
fd, socket.AF_UNIX, socket.SOCK_STREAM), socket.AF_UNIX))
|
||||||
else:
|
else:
|
||||||
for address, port in configuration.get("server", "hosts"):
|
for address, port in configuration.get("server", "hosts"):
|
||||||
server_addresses.append((address, port))
|
server_addresses_or_sockets.append(
|
||||||
|
((address, port), socket.AF_INET))
|
||||||
|
if configuration.get("server", "ssl"):
|
||||||
|
server_class = ParallelHTTPSServer
|
||||||
|
else:
|
||||||
|
server_class = ParallelHTTPServer
|
||||||
application = Application(configuration)
|
application = Application(configuration)
|
||||||
for server_address in server_addresses:
|
servers = {}
|
||||||
|
for server_address_or_socket, family in server_addresses_or_sockets:
|
||||||
try:
|
try:
|
||||||
server = ServerCopy(server_address, RequestHandler)
|
server = server_class(configuration, family,
|
||||||
|
server_address_or_socket, RequestHandler)
|
||||||
server.set_app(application)
|
server.set_app(application)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Failed to start server %r: %s" % (server_address, e)) from e
|
"Failed to start server %r: %s" % (
|
||||||
|
server_address_or_socket, e)) from e
|
||||||
servers[server.socket] = server
|
servers[server.socket] = server
|
||||||
logger.info("Listening to %r on port %d%s",
|
logger.info("Listening to %r on port %d%s",
|
||||||
server.server_name, server.server_port, " using SSL"
|
server.server_name, server.server_port, " using SSL"
|
||||||
|
Loading…
Reference in New Issue
Block a user