diff --git a/radicale/__init__.py b/radicale/__init__.py index 5251bfc..f679304 100644 --- a/radicale/__init__.py +++ b/radicale/__init__.py @@ -30,6 +30,7 @@ import os import pprint import base64 import socket +import socketserver import ssl import wsgiref.simple_server import re @@ -93,6 +94,14 @@ class HTTPSServer(HTTPServer): self.server_activate() +class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer): + pass + + +class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer): + pass + + class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): """HTTP requests handler.""" def log_message(self, *args, **kwargs): @@ -274,14 +283,6 @@ class Application: is_authenticated = self.is_authenticated(user, password) is_valid_user = is_authenticated or not user - if is_valid_user: - items = self.Collection.discover( - path, environ.get("HTTP_DEPTH", "0")) - read_allowed_items, write_allowed_items = ( - self.collect_allowed_items(items, user)) - else: - read_allowed_items, write_allowed_items = None, None - # Get content content_length = int(environ.get("CONTENT_LENGTH") or 0) if content_length: @@ -291,13 +292,26 @@ class Application: else: content = None - if is_valid_user and ( - (read_allowed_items or write_allowed_items) or - (is_authenticated and function == self.do_PROPFIND) or - function == self.do_OPTIONS): - status, headers, answer = function( - environ, read_allowed_items, write_allowed_items, content, - user) + if is_valid_user: + if function in (self.do_GET, self.do_HEAD, + self.do_OPTIONS, self.do_PROPFIND, + self.do_REPORT): + lock_mode = "r" + else: + lock_mode = "w" + with self.Collection.acquire_lock(lock_mode): + items = self.Collection.discover( + path, environ.get("HTTP_DEPTH", "0")) + read_allowed_items, write_allowed_items = ( + self.collect_allowed_items(items, user)) + if (read_allowed_items or write_allowed_items or + is_authenticated and function == self.do_PROPFIND or + function == self.do_OPTIONS): + status, headers, answer = function( + environ, read_allowed_items, write_allowed_items, + content, user) + else: + status, headers, answer = NOT_ALLOWED else: status, headers, answer = NOT_ALLOWED diff --git a/radicale/__main__.py b/radicale/__main__.py index 9672214..cda628b 100644 --- a/radicale/__main__.py +++ b/radicale/__main__.py @@ -33,7 +33,8 @@ import ssl from wsgiref.simple_server import make_server from . import ( - Application, config, HTTPServer, HTTPSServer, log, RequestHandler, VERSION) + Application, config, ThreadedHTTPServer, ThreadedHTTPSServer, log, + RequestHandler, VERSION) # This is a script, many branches and variables @@ -152,7 +153,7 @@ def run(): # Create collection servers servers = {} if configuration.getboolean("server", "ssl"): - server_class = HTTPSServer + server_class = ThreadedHTTPSServer server_class.certificate = configuration.get("server", "certificate") server_class.key = configuration.get("server", "key") server_class.cyphers = configuration.get("server", "cyphers") @@ -168,7 +169,7 @@ def run(): "Error while reading SSL %s %r: %s" % ( name, filename, exception)) else: - server_class = HTTPServer + server_class = ThreadedHTTPServer if not configuration.getboolean("server", "dns_lookup"): RequestHandler.address_string = lambda self: self.client_address[0] diff --git a/radicale/storage.py b/radicale/storage.py index d68431e..fec418b 100644 --- a/radicale/storage.py +++ b/radicale/storage.py @@ -29,6 +29,8 @@ import json import os import posixpath import shutil +import stat +import threading import time from contextlib import contextmanager from hashlib import md5 @@ -37,6 +39,42 @@ from uuid import uuid4 import vobject +if os.name == "nt": + import ctypes + import ctypes.wintypes + import msvcrt + + LOCKFILE_EXCLUSIVE_LOCK = 2 + if ctypes.sizeof(ctypes.c_void_p) == 4: + ULONG_PTR = ctypes.c_uint32 + else: + ULONG_PTR = ctypes.c_uint64 + + class Overlapped(ctypes.Structure): + _fields_ = [("internal", ULONG_PTR), + ("internal_high", ULONG_PTR), + ("offset", ctypes.wintypes.DWORD), + ("offset_high", ctypes.wintypes.DWORD), + ("h_event", ctypes.wintypes.HANDLE)] + + lock_file_ex = ctypes.windll.kernel32.LockFileEx + lock_file_ex.argtypes = [ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.POINTER(Overlapped)] + lock_file_ex.restype = ctypes.wintypes.BOOL + unlock_file_ex = ctypes.windll.kernel32.UnlockFileEx + unlock_file_ex.argtypes = [ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.POINTER(Overlapped)] + unlock_file_ex.restype = ctypes.wintypes.BOOL +elif os.name == "posix": + import fcntl + def load(configuration, logger): """Load the storage manager chosen in configuration.""" @@ -54,6 +92,7 @@ def load(configuration, logger): MIMETYPES = {"VADDRESSBOOK": "text/vcard", "VCALENDAR": "text/calendar"} +MAX_FILE_LOCK_DURATION = 0.25 def get_etag(text): @@ -246,6 +285,17 @@ class BaseCollection: """Get the unicode string representing the whole collection.""" raise NotImplementedError + @classmethod + @contextmanager + def acquire_lock(cls, mode): + """Set a context manager to lock the whole storage. + + ``mode`` must either be "r" for shared access or "w" for exclusive + access. + + """ + raise NotImplementedError + class Collection(BaseCollection): """Collection stored in several files per calendar.""" @@ -475,3 +525,95 @@ class Collection(BaseCollection): elif self.get_meta("tag") == "VADDRESSBOOK": return "".join([item.serialize() for item in items]) return "" + + _lock = threading.Lock() + _waiters = [] + _lock_file = None + _lock_file_locked = False + _lock_file_time = 0 + _readers = 0 + _writer = False + + @classmethod + @contextmanager + def acquire_lock(cls, mode): + def condition(): + # Prevent starvation of writers in other processes + if cls._lock_file_locked: + time_delta = time.time() - cls._lock_file_time + if time_delta < 0 or time_delta > MAX_FILE_LOCK_DURATION: + return False + if mode == "r": + return not cls._writer + else: + return not cls._writer and cls._readers == 0 + + if mode not in ("r", "w"): + raise ValueError("Invalid lock mode: %s" % mode) + # Use a primitive lock which only works within one process as a + # precondition for inter-process file-based locking + with cls._lock: + if cls._waiters or not condition(): + # use FIFO for access requests + waiter = threading.Condition(lock=cls._lock) + cls._waiters.append(waiter) + while True: + waiter.wait() + if condition(): + break + cls._waiters.pop(0) + if mode == "r": + cls._readers += 1 + # notify additional potential readers + if cls._waiters: + cls._waiters[0].notify() + else: + cls._writer = True + if not cls._lock_file: + folder = os.path.expanduser( + cls.configuration.get("storage", "filesystem_folder")) + if not os.path.exists(folder): + os.makedirs(folder, exist_ok=True) + lock_path = os.path.join(folder, "Radicale.lock") + cls._lock_file = open(lock_path, "w+") + # set access rights to a necessary minimum to prevent locking + # by arbitrary users + try: + os.chmod(lock_path, stat.S_IWUSR | stat.S_IRUSR) + except OSError: + cls.logger.debug("Failed to set permissions on lock file") + if not cls._lock_file_locked: + if os.name == "nt": + handle = msvcrt.get_osfhandle(cls._lock_file.fileno()) + flags = LOCKFILE_EXCLUSIVE_LOCK if mode == "w" else 0 + overlapped = Overlapped() + if not lock_file_ex(handle, flags, 0, 1, 0, overlapped): + cls.logger.debug("Locking not supported") + elif os.name == "posix": + _cmd = fcntl.LOCK_EX if mode == "w" else fcntl.LOCK_SH + try: + fcntl.lockf(cls._lock_file.fileno(), _cmd) + except OSError: + cls.logger.debug("Locking not supported") + cls._lock_file_locked = True + cls._lock_file_time = time.time() + yield + with cls._lock: + if mode == "r": + cls._readers -= 1 + else: + cls._writer = False + if cls._readers == 0: + if os.name == "nt": + handle = msvcrt.get_osfhandle(cls._lock_file.fileno()) + overlapped = Overlapped() + if not unlock_file_ex(handle, 0, 1, 0, overlapped): + cls.logger.debug("Unlocking not supported") + elif os.name == "posix": + try: + fcntl.lockf(cls._lock_file.fileno(), fcntl.LOCK_UN) + except OSError: + cls.logger.debug("Unlocking not supported") + cls._lock_file_locked = False + if cls._waiters: + cls._waiters[0].notify()