Use forking for internal server when available
This commit is contained in:
		| @@ -25,6 +25,7 @@ http://docs.python.org/library/logging.config.html | |||||||
| import contextlib | import contextlib | ||||||
| import io | import io | ||||||
| import logging | import logging | ||||||
|  | import multiprocessing | ||||||
| import os | import os | ||||||
| import sys | import sys | ||||||
| import threading | import threading | ||||||
| @@ -35,7 +36,7 @@ except ImportError: | |||||||
|     journal = None |     journal = None | ||||||
|  |  | ||||||
| LOGGER_NAME = "radicale" | LOGGER_NAME = "radicale" | ||||||
| LOGGER_FORMAT = "[%(processName)s/%(threadName)s] %(levelname)s: %(message)s" | LOGGER_FORMAT = "[%(ident)s] %(levelname)s: %(message)s" | ||||||
|  |  | ||||||
| root_logger = logging.getLogger() | root_logger = logging.getLogger() | ||||||
| logger = logging.getLogger(LOGGER_NAME) | logger = logging.getLogger(LOGGER_NAME) | ||||||
| @@ -50,6 +51,27 @@ class RemoveTracebackFilter(logging.Filter): | |||||||
| removeTracebackFilter = RemoveTracebackFilter() | removeTracebackFilter = RemoveTracebackFilter() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class IdentLogRecordFactory: | ||||||
|  |     """LogRecordFactory that adds ``ident`` attribute.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, upstream_factory): | ||||||
|  |         self.upstream_factory = upstream_factory | ||||||
|  |         self.main_pid = os.getpid() | ||||||
|  |         self.main_thread_name = threading.current_thread().name | ||||||
|  |  | ||||||
|  |     def __call__(self, *args, **kwargs): | ||||||
|  |         record = self.upstream_factory(*args, **kwargs) | ||||||
|  |         pid = os.getpid() | ||||||
|  |         thread_name = threading.current_thread().name | ||||||
|  |         ident = "%x" % self.main_pid | ||||||
|  |         if pid != self.main_pid: | ||||||
|  |             ident += "%+x" % (pid - self.main_pid) | ||||||
|  |         if thread_name != self.main_thread_name: | ||||||
|  |             ident += "/%s" % thread_name | ||||||
|  |         record.ident = ident | ||||||
|  |         return record | ||||||
|  |  | ||||||
|  |  | ||||||
| class ThreadStreamsHandler(logging.Handler): | class ThreadStreamsHandler(logging.Handler): | ||||||
|  |  | ||||||
|     terminator = "\n" |     terminator = "\n" | ||||||
| @@ -60,6 +82,9 @@ class ThreadStreamsHandler(logging.Handler): | |||||||
|         self.fallback_stream = fallback_stream |         self.fallback_stream = fallback_stream | ||||||
|         self.fallback_handler = fallback_handler |         self.fallback_handler = fallback_handler | ||||||
|  |  | ||||||
|  |     def createLock(self): | ||||||
|  |         self.lock = multiprocessing.Lock() | ||||||
|  |  | ||||||
|     def setFormatter(self, form): |     def setFormatter(self, form): | ||||||
|         super().setFormatter(form) |         super().setFormatter(form) | ||||||
|         self.fallback_handler.setFormatter(form) |         self.fallback_handler.setFormatter(form) | ||||||
| @@ -116,6 +141,8 @@ def setup(): | |||||||
|     handler = ThreadStreamsHandler(sys.stderr, get_default_handler()) |     handler = ThreadStreamsHandler(sys.stderr, get_default_handler()) | ||||||
|     logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler]) |     logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler]) | ||||||
|     register_stream = handler.register_stream |     register_stream = handler.register_stream | ||||||
|  |     log_record_factory = IdentLogRecordFactory(logging.getLogRecordFactory()) | ||||||
|  |     logging.setLogRecordFactory(log_record_factory) | ||||||
|     set_level(logging.DEBUG) |     set_level(logging.DEBUG) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -22,6 +22,7 @@ Radicale WSGI server. | |||||||
| """ | """ | ||||||
|  |  | ||||||
| import contextlib | import contextlib | ||||||
|  | import multiprocessing | ||||||
| import os | import os | ||||||
| import select | import select | ||||||
| import signal | import signal | ||||||
| @@ -29,16 +30,20 @@ import socket | |||||||
| import socketserver | import socketserver | ||||||
| import ssl | import ssl | ||||||
| import sys | import sys | ||||||
| import threading |  | ||||||
| import wsgiref.simple_server | import wsgiref.simple_server | ||||||
| from urllib.parse import unquote | from urllib.parse import unquote | ||||||
|  |  | ||||||
| from radicale import Application | from radicale import Application | ||||||
| from radicale.log import logger | from radicale.log import logger | ||||||
|  |  | ||||||
|  | if hasattr(socketserver, "ForkingMixIn"): | ||||||
|  |     ParallelizationMixIn = socketserver.ForkingMixIn | ||||||
|  | else: | ||||||
|  |     ParallelizationMixIn = socketserver.ThreadingMixIn | ||||||
|  |  | ||||||
| class HTTPServer(wsgiref.simple_server.WSGIServer): |  | ||||||
|     """HTTP server.""" | class ParallelHTTPServer(ParallelizationMixIn, | ||||||
|  |                          wsgiref.simple_server.WSGIServer): | ||||||
|  |  | ||||||
|     # These class attributes must be set before creating instance |     # These class attributes must be set before creating instance | ||||||
|     client_timeout = None |     client_timeout = None | ||||||
| @@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer): | |||||||
|             self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) |             self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) | ||||||
|  |  | ||||||
|         if self.max_connections: |         if self.max_connections: | ||||||
|             self.connections_guard = threading.BoundedSemaphore( |             self.connections_guard = multiprocessing.BoundedSemaphore( | ||||||
|                 self.max_connections) |                 self.max_connections) | ||||||
|         else: |         else: | ||||||
|             # use dummy context manager |             # use dummy context manager | ||||||
| @@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer): | |||||||
|  |  | ||||||
|     def get_request(self): |     def get_request(self): | ||||||
|         # Set timeout for client |         # Set timeout for client | ||||||
|         _socket, address = super().get_request() |         socket_, address = super().get_request() | ||||||
|         if self.client_timeout: |         if self.client_timeout: | ||||||
|             _socket.settimeout(self.client_timeout) |             socket_.settimeout(self.client_timeout) | ||||||
|         return _socket, address |         return socket_, address | ||||||
|  |  | ||||||
|  |     def finish_request(self, request, client_address): | ||||||
|  |         with self.connections_guard: | ||||||
|  |             return super().finish_request(request, client_address) | ||||||
|  |  | ||||||
|     def handle_error(self, request, client_address): |     def handle_error(self, request, client_address): | ||||||
|         if issubclass(sys.exc_info()[0], socket.timeout): |         if issubclass(sys.exc_info()[0], socket.timeout): | ||||||
| @@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer): | |||||||
|                          sys.exc_info()[1], exc_info=True) |                          sys.exc_info()[1], exc_info=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| class HTTPSServer(HTTPServer): | class ParallelHTTPSServer(ParallelHTTPServer): | ||||||
|     """HTTPS server.""" |  | ||||||
|  |  | ||||||
|     # These class attributes must be set before creating instance |     # These class attributes must be set before creating instance | ||||||
|     certificate = None |     certificate = None | ||||||
| @@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer): | |||||||
|     ciphers = None |     ciphers = None | ||||||
|     certificate_authority = None |     certificate_authority = None | ||||||
|  |  | ||||||
|     def __init__(self, address, handler): |     def __init__(self, address, handler, bind_and_activate=True): | ||||||
|         """Create server by wrapping HTTP socket in an SSL socket.""" |         """Create server by wrapping HTTP socket in an SSL socket.""" | ||||||
|         super().__init__(address, handler, bind_and_activate=False) |  | ||||||
|  |         # Do not bind and activate, as we change the socket | ||||||
|  |         super().__init__(address, handler, False) | ||||||
|  |  | ||||||
|         self.socket = ssl.wrap_socket( |         self.socket = ssl.wrap_socket( | ||||||
|             self.socket, self.key, self.certificate, server_side=True, |             self.socket, self.key, self.certificate, server_side=True, | ||||||
| @@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer): | |||||||
|             ssl_version=self.protocol, ciphers=self.ciphers, |             ssl_version=self.protocol, ciphers=self.ciphers, | ||||||
|             do_handshake_on_connect=False) |             do_handshake_on_connect=False) | ||||||
|  |  | ||||||
|         self.server_bind() |         if bind_and_activate: | ||||||
|         self.server_activate() |             try: | ||||||
|  |                 self.server_bind() | ||||||
|  |                 self.server_activate() | ||||||
|  |             except BaseException: | ||||||
|  |                 self.server_close() | ||||||
|  |                 raise | ||||||
|  |  | ||||||
|  |     def finish_request(self, request, client_address): | ||||||
| class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer): |  | ||||||
|     def process_request_thread(self, request, client_address): |  | ||||||
|         with self.connections_guard: |  | ||||||
|             return super().process_request_thread(request, client_address) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): |  | ||||||
|     def process_request_thread(self, request, client_address): |  | ||||||
|         try: |         try: | ||||||
|             try: |             try: | ||||||
|                 request.do_handshake() |                 request.do_handshake() | ||||||
| @@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): | |||||||
|             finally: |             finally: | ||||||
|                 self.shutdown_request(request) |                 self.shutdown_request(request) | ||||||
|             return |             return | ||||||
|         with self.connections_guard: |         return super().finish_request(request, client_address) | ||||||
|             return super().process_request_thread(request, client_address) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ServerHandler(wsgiref.simple_server.ServerHandler): | class ServerHandler(wsgiref.simple_server.ServerHandler): | ||||||
| @@ -197,7 +203,7 @@ def serve(configuration): | |||||||
|     # Create collection servers |     # Create collection servers | ||||||
|     servers = {} |     servers = {} | ||||||
|     if configuration.getboolean("server", "ssl"): |     if configuration.getboolean("server", "ssl"): | ||||||
|         server_class = ThreadedHTTPSServer |         server_class = ParallelHTTPSServer | ||||||
|         server_class.certificate = configuration.get("server", "certificate") |         server_class.certificate = configuration.get("server", "certificate") | ||||||
|         server_class.key = configuration.get("server", "key") |         server_class.key = configuration.get("server", "key") | ||||||
|         server_class.certificate_authority = configuration.get( |         server_class.certificate_authority = configuration.get( | ||||||
| @@ -216,7 +222,7 @@ def serve(configuration): | |||||||
|                 raise RuntimeError("Failed to read SSL %s %r: %s" % |                 raise RuntimeError("Failed to read SSL %s %r: %s" % | ||||||
|                                    (name, filename, e)) from e |                                    (name, filename, e)) from e | ||||||
|     else: |     else: | ||||||
|         server_class = ThreadedHTTPServer |         server_class = ParallelHTTPServer | ||||||
|     server_class.client_timeout = configuration.getint("server", "timeout") |     server_class.client_timeout = configuration.getint("server", "timeout") | ||||||
|     server_class.max_connections = configuration.getint( |     server_class.max_connections = configuration.getint( | ||||||
|         "server", "max_connections") |         "server", "max_connections") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Unrud
					Unrud