Merge pull request #402 from Unrud/locking
Implement locking of whole storage
This commit is contained in:
commit
a3c32ee77f
@ -30,6 +30,7 @@ import os
|
|||||||
import pprint
|
import pprint
|
||||||
import base64
|
import base64
|
||||||
import socket
|
import socket
|
||||||
|
import socketserver
|
||||||
import ssl
|
import ssl
|
||||||
import wsgiref.simple_server
|
import wsgiref.simple_server
|
||||||
import re
|
import re
|
||||||
@ -93,6 +94,14 @@ class HTTPSServer(HTTPServer):
|
|||||||
self.server_activate()
|
self.server_activate()
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
||||||
"""HTTP requests handler."""
|
"""HTTP requests handler."""
|
||||||
def log_message(self, *args, **kwargs):
|
def log_message(self, *args, **kwargs):
|
||||||
@ -274,14 +283,6 @@ class Application:
|
|||||||
is_authenticated = self.is_authenticated(user, password)
|
is_authenticated = self.is_authenticated(user, password)
|
||||||
is_valid_user = is_authenticated or not user
|
is_valid_user = is_authenticated or not user
|
||||||
|
|
||||||
if is_valid_user:
|
|
||||||
items = self.Collection.discover(
|
|
||||||
path, environ.get("HTTP_DEPTH", "0"))
|
|
||||||
read_allowed_items, write_allowed_items = (
|
|
||||||
self.collect_allowed_items(items, user))
|
|
||||||
else:
|
|
||||||
read_allowed_items, write_allowed_items = None, None
|
|
||||||
|
|
||||||
# Get content
|
# Get content
|
||||||
content_length = int(environ.get("CONTENT_LENGTH") or 0)
|
content_length = int(environ.get("CONTENT_LENGTH") or 0)
|
||||||
if content_length:
|
if content_length:
|
||||||
@ -291,13 +292,26 @@ class Application:
|
|||||||
else:
|
else:
|
||||||
content = None
|
content = None
|
||||||
|
|
||||||
if is_valid_user and (
|
if is_valid_user:
|
||||||
(read_allowed_items or write_allowed_items) or
|
if function in (self.do_GET, self.do_HEAD,
|
||||||
(is_authenticated and function == self.do_PROPFIND) or
|
self.do_OPTIONS, self.do_PROPFIND,
|
||||||
function == self.do_OPTIONS):
|
self.do_REPORT):
|
||||||
status, headers, answer = function(
|
lock_mode = "r"
|
||||||
environ, read_allowed_items, write_allowed_items, content,
|
else:
|
||||||
user)
|
lock_mode = "w"
|
||||||
|
with self.Collection.acquire_lock(lock_mode):
|
||||||
|
items = self.Collection.discover(
|
||||||
|
path, environ.get("HTTP_DEPTH", "0"))
|
||||||
|
read_allowed_items, write_allowed_items = (
|
||||||
|
self.collect_allowed_items(items, user))
|
||||||
|
if (read_allowed_items or write_allowed_items or
|
||||||
|
is_authenticated and function == self.do_PROPFIND or
|
||||||
|
function == self.do_OPTIONS):
|
||||||
|
status, headers, answer = function(
|
||||||
|
environ, read_allowed_items, write_allowed_items,
|
||||||
|
content, user)
|
||||||
|
else:
|
||||||
|
status, headers, answer = NOT_ALLOWED
|
||||||
else:
|
else:
|
||||||
status, headers, answer = NOT_ALLOWED
|
status, headers, answer = NOT_ALLOWED
|
||||||
|
|
||||||
|
@ -33,7 +33,8 @@ import ssl
|
|||||||
from wsgiref.simple_server import make_server
|
from wsgiref.simple_server import make_server
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
Application, config, HTTPServer, HTTPSServer, log, RequestHandler, VERSION)
|
Application, config, ThreadedHTTPServer, ThreadedHTTPSServer, log,
|
||||||
|
RequestHandler, VERSION)
|
||||||
|
|
||||||
|
|
||||||
# This is a script, many branches and variables
|
# This is a script, many branches and variables
|
||||||
@ -152,7 +153,7 @@ def run():
|
|||||||
# Create collection servers
|
# Create collection servers
|
||||||
servers = {}
|
servers = {}
|
||||||
if configuration.getboolean("server", "ssl"):
|
if configuration.getboolean("server", "ssl"):
|
||||||
server_class = HTTPSServer
|
server_class = ThreadedHTTPSServer
|
||||||
server_class.certificate = configuration.get("server", "certificate")
|
server_class.certificate = configuration.get("server", "certificate")
|
||||||
server_class.key = configuration.get("server", "key")
|
server_class.key = configuration.get("server", "key")
|
||||||
server_class.cyphers = configuration.get("server", "cyphers")
|
server_class.cyphers = configuration.get("server", "cyphers")
|
||||||
@ -168,7 +169,7 @@ def run():
|
|||||||
"Error while reading SSL %s %r: %s" % (
|
"Error while reading SSL %s %r: %s" % (
|
||||||
name, filename, exception))
|
name, filename, exception))
|
||||||
else:
|
else:
|
||||||
server_class = HTTPServer
|
server_class = ThreadedHTTPServer
|
||||||
|
|
||||||
if not configuration.getboolean("server", "dns_lookup"):
|
if not configuration.getboolean("server", "dns_lookup"):
|
||||||
RequestHandler.address_string = lambda self: self.client_address[0]
|
RequestHandler.address_string = lambda self: self.client_address[0]
|
||||||
|
@ -29,6 +29,8 @@ import json
|
|||||||
import os
|
import os
|
||||||
import posixpath
|
import posixpath
|
||||||
import shutil
|
import shutil
|
||||||
|
import stat
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
@ -37,6 +39,42 @@ from uuid import uuid4
|
|||||||
|
|
||||||
import vobject
|
import vobject
|
||||||
|
|
||||||
|
if os.name == "nt":
|
||||||
|
import ctypes
|
||||||
|
import ctypes.wintypes
|
||||||
|
import msvcrt
|
||||||
|
|
||||||
|
LOCKFILE_EXCLUSIVE_LOCK = 2
|
||||||
|
if ctypes.sizeof(ctypes.c_void_p) == 4:
|
||||||
|
ULONG_PTR = ctypes.c_uint32
|
||||||
|
else:
|
||||||
|
ULONG_PTR = ctypes.c_uint64
|
||||||
|
|
||||||
|
class Overlapped(ctypes.Structure):
|
||||||
|
_fields_ = [("internal", ULONG_PTR),
|
||||||
|
("internal_high", ULONG_PTR),
|
||||||
|
("offset", ctypes.wintypes.DWORD),
|
||||||
|
("offset_high", ctypes.wintypes.DWORD),
|
||||||
|
("h_event", ctypes.wintypes.HANDLE)]
|
||||||
|
|
||||||
|
lock_file_ex = ctypes.windll.kernel32.LockFileEx
|
||||||
|
lock_file_ex.argtypes = [ctypes.wintypes.HANDLE,
|
||||||
|
ctypes.wintypes.DWORD,
|
||||||
|
ctypes.wintypes.DWORD,
|
||||||
|
ctypes.wintypes.DWORD,
|
||||||
|
ctypes.wintypes.DWORD,
|
||||||
|
ctypes.POINTER(Overlapped)]
|
||||||
|
lock_file_ex.restype = ctypes.wintypes.BOOL
|
||||||
|
unlock_file_ex = ctypes.windll.kernel32.UnlockFileEx
|
||||||
|
unlock_file_ex.argtypes = [ctypes.wintypes.HANDLE,
|
||||||
|
ctypes.wintypes.DWORD,
|
||||||
|
ctypes.wintypes.DWORD,
|
||||||
|
ctypes.wintypes.DWORD,
|
||||||
|
ctypes.POINTER(Overlapped)]
|
||||||
|
unlock_file_ex.restype = ctypes.wintypes.BOOL
|
||||||
|
elif os.name == "posix":
|
||||||
|
import fcntl
|
||||||
|
|
||||||
|
|
||||||
def load(configuration, logger):
|
def load(configuration, logger):
|
||||||
"""Load the storage manager chosen in configuration."""
|
"""Load the storage manager chosen in configuration."""
|
||||||
@ -54,6 +92,7 @@ def load(configuration, logger):
|
|||||||
|
|
||||||
|
|
||||||
MIMETYPES = {"VADDRESSBOOK": "text/vcard", "VCALENDAR": "text/calendar"}
|
MIMETYPES = {"VADDRESSBOOK": "text/vcard", "VCALENDAR": "text/calendar"}
|
||||||
|
MAX_FILE_LOCK_DURATION = 0.25
|
||||||
|
|
||||||
|
|
||||||
def get_etag(text):
|
def get_etag(text):
|
||||||
@ -246,6 +285,17 @@ class BaseCollection:
|
|||||||
"""Get the unicode string representing the whole collection."""
|
"""Get the unicode string representing the whole collection."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def acquire_lock(cls, mode):
|
||||||
|
"""Set a context manager to lock the whole storage.
|
||||||
|
|
||||||
|
``mode`` must either be "r" for shared access or "w" for exclusive
|
||||||
|
access.
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class Collection(BaseCollection):
|
class Collection(BaseCollection):
|
||||||
"""Collection stored in several files per calendar."""
|
"""Collection stored in several files per calendar."""
|
||||||
@ -475,3 +525,95 @@ class Collection(BaseCollection):
|
|||||||
elif self.get_meta("tag") == "VADDRESSBOOK":
|
elif self.get_meta("tag") == "VADDRESSBOOK":
|
||||||
return "".join([item.serialize() for item in items])
|
return "".join([item.serialize() for item in items])
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_waiters = []
|
||||||
|
_lock_file = None
|
||||||
|
_lock_file_locked = False
|
||||||
|
_lock_file_time = 0
|
||||||
|
_readers = 0
|
||||||
|
_writer = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def acquire_lock(cls, mode):
|
||||||
|
def condition():
|
||||||
|
# Prevent starvation of writers in other processes
|
||||||
|
if cls._lock_file_locked:
|
||||||
|
time_delta = time.time() - cls._lock_file_time
|
||||||
|
if time_delta < 0 or time_delta > MAX_FILE_LOCK_DURATION:
|
||||||
|
return False
|
||||||
|
if mode == "r":
|
||||||
|
return not cls._writer
|
||||||
|
else:
|
||||||
|
return not cls._writer and cls._readers == 0
|
||||||
|
|
||||||
|
if mode not in ("r", "w"):
|
||||||
|
raise ValueError("Invalid lock mode: %s" % mode)
|
||||||
|
# Use a primitive lock which only works within one process as a
|
||||||
|
# precondition for inter-process file-based locking
|
||||||
|
with cls._lock:
|
||||||
|
if cls._waiters or not condition():
|
||||||
|
# use FIFO for access requests
|
||||||
|
waiter = threading.Condition(lock=cls._lock)
|
||||||
|
cls._waiters.append(waiter)
|
||||||
|
while True:
|
||||||
|
waiter.wait()
|
||||||
|
if condition():
|
||||||
|
break
|
||||||
|
cls._waiters.pop(0)
|
||||||
|
if mode == "r":
|
||||||
|
cls._readers += 1
|
||||||
|
# notify additional potential readers
|
||||||
|
if cls._waiters:
|
||||||
|
cls._waiters[0].notify()
|
||||||
|
else:
|
||||||
|
cls._writer = True
|
||||||
|
if not cls._lock_file:
|
||||||
|
folder = os.path.expanduser(
|
||||||
|
cls.configuration.get("storage", "filesystem_folder"))
|
||||||
|
if not os.path.exists(folder):
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
lock_path = os.path.join(folder, "Radicale.lock")
|
||||||
|
cls._lock_file = open(lock_path, "w+")
|
||||||
|
# set access rights to a necessary minimum to prevent locking
|
||||||
|
# by arbitrary users
|
||||||
|
try:
|
||||||
|
os.chmod(lock_path, stat.S_IWUSR | stat.S_IRUSR)
|
||||||
|
except OSError:
|
||||||
|
cls.logger.debug("Failed to set permissions on lock file")
|
||||||
|
if not cls._lock_file_locked:
|
||||||
|
if os.name == "nt":
|
||||||
|
handle = msvcrt.get_osfhandle(cls._lock_file.fileno())
|
||||||
|
flags = LOCKFILE_EXCLUSIVE_LOCK if mode == "w" else 0
|
||||||
|
overlapped = Overlapped()
|
||||||
|
if not lock_file_ex(handle, flags, 0, 1, 0, overlapped):
|
||||||
|
cls.logger.debug("Locking not supported")
|
||||||
|
elif os.name == "posix":
|
||||||
|
_cmd = fcntl.LOCK_EX if mode == "w" else fcntl.LOCK_SH
|
||||||
|
try:
|
||||||
|
fcntl.lockf(cls._lock_file.fileno(), _cmd)
|
||||||
|
except OSError:
|
||||||
|
cls.logger.debug("Locking not supported")
|
||||||
|
cls._lock_file_locked = True
|
||||||
|
cls._lock_file_time = time.time()
|
||||||
|
yield
|
||||||
|
with cls._lock:
|
||||||
|
if mode == "r":
|
||||||
|
cls._readers -= 1
|
||||||
|
else:
|
||||||
|
cls._writer = False
|
||||||
|
if cls._readers == 0:
|
||||||
|
if os.name == "nt":
|
||||||
|
handle = msvcrt.get_osfhandle(cls._lock_file.fileno())
|
||||||
|
overlapped = Overlapped()
|
||||||
|
if not unlock_file_ex(handle, 0, 1, 0, overlapped):
|
||||||
|
cls.logger.debug("Unlocking not supported")
|
||||||
|
elif os.name == "posix":
|
||||||
|
try:
|
||||||
|
fcntl.lockf(cls._lock_file.fileno(), fcntl.LOCK_UN)
|
||||||
|
except OSError:
|
||||||
|
cls.logger.debug("Unlocking not supported")
|
||||||
|
cls._lock_file_locked = False
|
||||||
|
if cls._waiters:
|
||||||
|
cls._waiters[0].notify()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user