Use forking for internal server when available
This commit is contained in:
parent
ddd99a5329
commit
30a9ecc06b
@ -25,6 +25,7 @@ http://docs.python.org/library/logging.config.html
|
|||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
@ -35,7 +36,7 @@ except ImportError:
|
|||||||
journal = None
|
journal = None
|
||||||
|
|
||||||
LOGGER_NAME = "radicale"
|
LOGGER_NAME = "radicale"
|
||||||
LOGGER_FORMAT = "[%(processName)s/%(threadName)s] %(levelname)s: %(message)s"
|
LOGGER_FORMAT = "[%(ident)s] %(levelname)s: %(message)s"
|
||||||
|
|
||||||
root_logger = logging.getLogger()
|
root_logger = logging.getLogger()
|
||||||
logger = logging.getLogger(LOGGER_NAME)
|
logger = logging.getLogger(LOGGER_NAME)
|
||||||
@ -50,6 +51,27 @@ class RemoveTracebackFilter(logging.Filter):
|
|||||||
removeTracebackFilter = RemoveTracebackFilter()
|
removeTracebackFilter = RemoveTracebackFilter()
|
||||||
|
|
||||||
|
|
||||||
|
class IdentLogRecordFactory:
|
||||||
|
"""LogRecordFactory that adds ``ident`` attribute."""
|
||||||
|
|
||||||
|
def __init__(self, upstream_factory):
|
||||||
|
self.upstream_factory = upstream_factory
|
||||||
|
self.main_pid = os.getpid()
|
||||||
|
self.main_thread_name = threading.current_thread().name
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
record = self.upstream_factory(*args, **kwargs)
|
||||||
|
pid = os.getpid()
|
||||||
|
thread_name = threading.current_thread().name
|
||||||
|
ident = "%x" % self.main_pid
|
||||||
|
if pid != self.main_pid:
|
||||||
|
ident += "%+x" % (pid - self.main_pid)
|
||||||
|
if thread_name != self.main_thread_name:
|
||||||
|
ident += "/%s" % thread_name
|
||||||
|
record.ident = ident
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
class ThreadStreamsHandler(logging.Handler):
|
class ThreadStreamsHandler(logging.Handler):
|
||||||
|
|
||||||
terminator = "\n"
|
terminator = "\n"
|
||||||
@ -60,6 +82,9 @@ class ThreadStreamsHandler(logging.Handler):
|
|||||||
self.fallback_stream = fallback_stream
|
self.fallback_stream = fallback_stream
|
||||||
self.fallback_handler = fallback_handler
|
self.fallback_handler = fallback_handler
|
||||||
|
|
||||||
|
def createLock(self):
|
||||||
|
self.lock = multiprocessing.Lock()
|
||||||
|
|
||||||
def setFormatter(self, form):
|
def setFormatter(self, form):
|
||||||
super().setFormatter(form)
|
super().setFormatter(form)
|
||||||
self.fallback_handler.setFormatter(form)
|
self.fallback_handler.setFormatter(form)
|
||||||
@ -116,6 +141,8 @@ def setup():
|
|||||||
handler = ThreadStreamsHandler(sys.stderr, get_default_handler())
|
handler = ThreadStreamsHandler(sys.stderr, get_default_handler())
|
||||||
logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler])
|
logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler])
|
||||||
register_stream = handler.register_stream
|
register_stream = handler.register_stream
|
||||||
|
log_record_factory = IdentLogRecordFactory(logging.getLogRecordFactory())
|
||||||
|
logging.setLogRecordFactory(log_record_factory)
|
||||||
set_level(logging.DEBUG)
|
set_level(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ Radicale WSGI server.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import select
|
import select
|
||||||
import signal
|
import signal
|
||||||
@ -29,16 +30,20 @@ import socket
|
|||||||
import socketserver
|
import socketserver
|
||||||
import ssl
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import wsgiref.simple_server
|
import wsgiref.simple_server
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from radicale import Application
|
from radicale import Application
|
||||||
from radicale.log import logger
|
from radicale.log import logger
|
||||||
|
|
||||||
|
if hasattr(socketserver, "ForkingMixIn"):
|
||||||
|
ParallelizationMixIn = socketserver.ForkingMixIn
|
||||||
|
else:
|
||||||
|
ParallelizationMixIn = socketserver.ThreadingMixIn
|
||||||
|
|
||||||
class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|
||||||
"""HTTP server."""
|
class ParallelHTTPServer(ParallelizationMixIn,
|
||||||
|
wsgiref.simple_server.WSGIServer):
|
||||||
|
|
||||||
# These class attributes must be set before creating instance
|
# These class attributes must be set before creating instance
|
||||||
client_timeout = None
|
client_timeout = None
|
||||||
@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|||||||
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
|
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
|
||||||
|
|
||||||
if self.max_connections:
|
if self.max_connections:
|
||||||
self.connections_guard = threading.BoundedSemaphore(
|
self.connections_guard = multiprocessing.BoundedSemaphore(
|
||||||
self.max_connections)
|
self.max_connections)
|
||||||
else:
|
else:
|
||||||
# use dummy context manager
|
# use dummy context manager
|
||||||
@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|||||||
|
|
||||||
def get_request(self):
|
def get_request(self):
|
||||||
# Set timeout for client
|
# Set timeout for client
|
||||||
_socket, address = super().get_request()
|
socket_, address = super().get_request()
|
||||||
if self.client_timeout:
|
if self.client_timeout:
|
||||||
_socket.settimeout(self.client_timeout)
|
socket_.settimeout(self.client_timeout)
|
||||||
return _socket, address
|
return socket_, address
|
||||||
|
|
||||||
|
def finish_request(self, request, client_address):
|
||||||
|
with self.connections_guard:
|
||||||
|
return super().finish_request(request, client_address)
|
||||||
|
|
||||||
def handle_error(self, request, client_address):
|
def handle_error(self, request, client_address):
|
||||||
if issubclass(sys.exc_info()[0], socket.timeout):
|
if issubclass(sys.exc_info()[0], socket.timeout):
|
||||||
@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|||||||
sys.exc_info()[1], exc_info=True)
|
sys.exc_info()[1], exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
class HTTPSServer(HTTPServer):
|
class ParallelHTTPSServer(ParallelHTTPServer):
|
||||||
"""HTTPS server."""
|
|
||||||
|
|
||||||
# These class attributes must be set before creating instance
|
# These class attributes must be set before creating instance
|
||||||
certificate = None
|
certificate = None
|
||||||
@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer):
|
|||||||
ciphers = None
|
ciphers = None
|
||||||
certificate_authority = None
|
certificate_authority = None
|
||||||
|
|
||||||
def __init__(self, address, handler):
|
def __init__(self, address, handler, bind_and_activate=True):
|
||||||
"""Create server by wrapping HTTP socket in an SSL socket."""
|
"""Create server by wrapping HTTP socket in an SSL socket."""
|
||||||
super().__init__(address, handler, bind_and_activate=False)
|
|
||||||
|
# Do not bind and activate, as we change the socket
|
||||||
|
super().__init__(address, handler, False)
|
||||||
|
|
||||||
self.socket = ssl.wrap_socket(
|
self.socket = ssl.wrap_socket(
|
||||||
self.socket, self.key, self.certificate, server_side=True,
|
self.socket, self.key, self.certificate, server_side=True,
|
||||||
@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer):
|
|||||||
ssl_version=self.protocol, ciphers=self.ciphers,
|
ssl_version=self.protocol, ciphers=self.ciphers,
|
||||||
do_handshake_on_connect=False)
|
do_handshake_on_connect=False)
|
||||||
|
|
||||||
self.server_bind()
|
if bind_and_activate:
|
||||||
self.server_activate()
|
try:
|
||||||
|
self.server_bind()
|
||||||
|
self.server_activate()
|
||||||
|
except BaseException:
|
||||||
|
self.server_close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def finish_request(self, request, client_address):
|
||||||
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
|
|
||||||
def process_request_thread(self, request, client_address):
|
|
||||||
with self.connections_guard:
|
|
||||||
return super().process_request_thread(request, client_address)
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
|
||||||
def process_request_thread(self, request, client_address):
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
request.do_handshake()
|
request.do_handshake()
|
||||||
@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
|||||||
finally:
|
finally:
|
||||||
self.shutdown_request(request)
|
self.shutdown_request(request)
|
||||||
return
|
return
|
||||||
with self.connections_guard:
|
return super().finish_request(request, client_address)
|
||||||
return super().process_request_thread(request, client_address)
|
|
||||||
|
|
||||||
|
|
||||||
class ServerHandler(wsgiref.simple_server.ServerHandler):
|
class ServerHandler(wsgiref.simple_server.ServerHandler):
|
||||||
@ -197,7 +203,7 @@ def serve(configuration):
|
|||||||
# Create collection servers
|
# Create collection servers
|
||||||
servers = {}
|
servers = {}
|
||||||
if configuration.getboolean("server", "ssl"):
|
if configuration.getboolean("server", "ssl"):
|
||||||
server_class = ThreadedHTTPSServer
|
server_class = ParallelHTTPSServer
|
||||||
server_class.certificate = configuration.get("server", "certificate")
|
server_class.certificate = configuration.get("server", "certificate")
|
||||||
server_class.key = configuration.get("server", "key")
|
server_class.key = configuration.get("server", "key")
|
||||||
server_class.certificate_authority = configuration.get(
|
server_class.certificate_authority = configuration.get(
|
||||||
@ -216,7 +222,7 @@ def serve(configuration):
|
|||||||
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
raise RuntimeError("Failed to read SSL %s %r: %s" %
|
||||||
(name, filename, e)) from e
|
(name, filename, e)) from e
|
||||||
else:
|
else:
|
||||||
server_class = ThreadedHTTPServer
|
server_class = ParallelHTTPServer
|
||||||
server_class.client_timeout = configuration.getint("server", "timeout")
|
server_class.client_timeout = configuration.getint("server", "timeout")
|
||||||
server_class.max_connections = configuration.getint(
|
server_class.max_connections = configuration.getint(
|
||||||
"server", "max_connections")
|
"server", "max_connections")
|
||||||
|
Loading…
Reference in New Issue
Block a user