Use forking for internal server when available
This commit is contained in:
parent
ddd99a5329
commit
30a9ecc06b
@ -25,6 +25,7 @@ http://docs.python.org/library/logging.config.html
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
@ -35,7 +36,7 @@ except ImportError:
|
||||
journal = None
|
||||
|
||||
LOGGER_NAME = "radicale"
|
||||
LOGGER_FORMAT = "[%(processName)s/%(threadName)s] %(levelname)s: %(message)s"
|
||||
LOGGER_FORMAT = "[%(ident)s] %(levelname)s: %(message)s"
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
logger = logging.getLogger(LOGGER_NAME)
|
||||
@ -50,6 +51,27 @@ class RemoveTracebackFilter(logging.Filter):
|
||||
removeTracebackFilter = RemoveTracebackFilter()
|
||||
|
||||
|
||||
class IdentLogRecordFactory:
|
||||
"""LogRecordFactory that adds ``ident`` attribute."""
|
||||
|
||||
def __init__(self, upstream_factory):
|
||||
self.upstream_factory = upstream_factory
|
||||
self.main_pid = os.getpid()
|
||||
self.main_thread_name = threading.current_thread().name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
record = self.upstream_factory(*args, **kwargs)
|
||||
pid = os.getpid()
|
||||
thread_name = threading.current_thread().name
|
||||
ident = "%x" % self.main_pid
|
||||
if pid != self.main_pid:
|
||||
ident += "%+x" % (pid - self.main_pid)
|
||||
if thread_name != self.main_thread_name:
|
||||
ident += "/%s" % thread_name
|
||||
record.ident = ident
|
||||
return record
|
||||
|
||||
|
||||
class ThreadStreamsHandler(logging.Handler):
|
||||
|
||||
terminator = "\n"
|
||||
@ -60,6 +82,9 @@ class ThreadStreamsHandler(logging.Handler):
|
||||
self.fallback_stream = fallback_stream
|
||||
self.fallback_handler = fallback_handler
|
||||
|
||||
def createLock(self):
|
||||
self.lock = multiprocessing.Lock()
|
||||
|
||||
def setFormatter(self, form):
|
||||
super().setFormatter(form)
|
||||
self.fallback_handler.setFormatter(form)
|
||||
@ -116,6 +141,8 @@ def setup():
|
||||
handler = ThreadStreamsHandler(sys.stderr, get_default_handler())
|
||||
logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler])
|
||||
register_stream = handler.register_stream
|
||||
log_record_factory = IdentLogRecordFactory(logging.getLogRecordFactory())
|
||||
logging.setLogRecordFactory(log_record_factory)
|
||||
set_level(logging.DEBUG)
|
||||
|
||||
|
||||
|
@ -22,6 +22,7 @@ Radicale WSGI server.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
@ -29,16 +30,20 @@ import socket
|
||||
import socketserver
|
||||
import ssl
|
||||
import sys
|
||||
import threading
|
||||
import wsgiref.simple_server
|
||||
from urllib.parse import unquote
|
||||
|
||||
from radicale import Application
|
||||
from radicale.log import logger
|
||||
|
||||
if hasattr(socketserver, "ForkingMixIn"):
|
||||
ParallelizationMixIn = socketserver.ForkingMixIn
|
||||
else:
|
||||
ParallelizationMixIn = socketserver.ThreadingMixIn
|
||||
|
||||
class HTTPServer(wsgiref.simple_server.WSGIServer):
|
||||
"""HTTP server."""
|
||||
|
||||
class ParallelHTTPServer(ParallelizationMixIn,
|
||||
wsgiref.simple_server.WSGIServer):
|
||||
|
||||
# These class attributes must be set before creating instance
|
||||
client_timeout = None
|
||||
@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
||||
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
|
||||
|
||||
if self.max_connections:
|
||||
self.connections_guard = threading.BoundedSemaphore(
|
||||
self.connections_guard = multiprocessing.BoundedSemaphore(
|
||||
self.max_connections)
|
||||
else:
|
||||
# use dummy context manager
|
||||
@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
||||
|
||||
def get_request(self):
|
||||
# Set timeout for client
|
||||
_socket, address = super().get_request()
|
||||
socket_, address = super().get_request()
|
||||
if self.client_timeout:
|
||||
_socket.settimeout(self.client_timeout)
|
||||
return _socket, address
|
||||
socket_.settimeout(self.client_timeout)
|
||||
return socket_, address
|
||||
|
||||
def finish_request(self, request, client_address):
|
||||
with self.connections_guard:
|
||||
return super().finish_request(request, client_address)
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
if issubclass(sys.exc_info()[0], socket.timeout):
|
||||
@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
||||
sys.exc_info()[1], exc_info=True)
|
||||
|
||||
|
||||
class HTTPSServer(HTTPServer):
|
||||
"""HTTPS server."""
|
||||
class ParallelHTTPSServer(ParallelHTTPServer):
|
||||
|
||||
# These class attributes must be set before creating instance
|
||||
certificate = None
|
||||
@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer):
|
||||
ciphers = None
|
||||
certificate_authority = None
|
||||
|
||||
def __init__(self, address, handler):
|
||||
def __init__(self, address, handler, bind_and_activate=True):
|
||||
"""Create server by wrapping HTTP socket in an SSL socket."""
|
||||
super().__init__(address, handler, bind_and_activate=False)
|
||||
|
||||
# Do not bind and activate, as we change the socket
|
||||
super().__init__(address, handler, False)
|
||||
|
||||
self.socket = ssl.wrap_socket(
|
||||
self.socket, self.key, self.certificate, server_side=True,
|
||||
@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer):
|
||||
ssl_version=self.protocol, ciphers=self.ciphers,
|
||||
do_handshake_on_connect=False)
|
||||
|
||||
self.server_bind()
|
||||
self.server_activate()
|
||||
if bind_and_activate:
|
||||
try:
|
||||
self.server_bind()
|
||||
self.server_activate()
|
||||
except BaseException:
|
||||
self.server_close()
|
||||
raise
|
||||
|
||||
|
||||
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
|
||||
def process_request_thread(self, request, client_address):
|
||||
with self.connections_guard:
|
||||
return super().process_request_thread(request, client_address)
|
||||
|
||||
|
||||
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
||||
def process_request_thread(self, request, client_address):
|
||||
def finish_request(self, request, client_address):
|
||||
try:
|
||||
try:
|
||||
request.do_handshake()
|
||||
@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
||||
finally:
|
||||
self.shutdown_request(request)
|
||||
return
|
||||
with self.connections_guard:
|
||||
return super().process_request_thread(request, client_address)
|
||||
return super().finish_request(request, client_address)
|
||||
|
||||
|
||||
class ServerHandler(wsgiref.simple_server.ServerHandler):
|
||||
@ -197,7 +203,7 @@ def serve(configuration):
|
||||
# Create collection servers
|
||||
servers = {}
|
||||
if configuration.getboolean("server", "ssl"):
|
||||
server_class = ThreadedHTTPSServer
|
||||
server_class = ParallelHTTPSServer
|
||||
server_class.certificate = configuration.get("server", "certificate")
|
||||
server_class.key = configuration.get("server", "key")
|
||||
server_class.certificate_authority = configuration.get(
|
||||
@ -216,7 +222,7 @@ def serve(configuration):
|
||||
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
||||
(name, filename, e)) from e
|
||||
else:
|
||||
server_class = ThreadedHTTPServer
|
||||
server_class = ParallelHTTPServer
|
||||
server_class.client_timeout = configuration.getint("server", "timeout")
|
||||
server_class.max_connections = configuration.getint(
|
||||
"server", "max_connections")
|
||||
|
Loading…
Reference in New Issue
Block a user