Merge pull request #402 from Unrud/locking

Implement locking of whole storage
This commit is contained in:
Guillaume Ayoub 2016-05-25 14:03:48 +02:00
commit a3c32ee77f
3 changed files with 175 additions and 18 deletions

View File

@ -30,6 +30,7 @@ import os
import pprint import pprint
import base64 import base64
import socket import socket
import socketserver
import ssl import ssl
import wsgiref.simple_server import wsgiref.simple_server
import re import re
@ -93,6 +94,14 @@ class HTTPSServer(HTTPServer):
self.server_activate() self.server_activate()
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
pass
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
pass
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
"""HTTP requests handler.""" """HTTP requests handler."""
def log_message(self, *args, **kwargs): def log_message(self, *args, **kwargs):
@ -274,14 +283,6 @@ class Application:
is_authenticated = self.is_authenticated(user, password) is_authenticated = self.is_authenticated(user, password)
is_valid_user = is_authenticated or not user is_valid_user = is_authenticated or not user
if is_valid_user:
items = self.Collection.discover(
path, environ.get("HTTP_DEPTH", "0"))
read_allowed_items, write_allowed_items = (
self.collect_allowed_items(items, user))
else:
read_allowed_items, write_allowed_items = None, None
# Get content # Get content
content_length = int(environ.get("CONTENT_LENGTH") or 0) content_length = int(environ.get("CONTENT_LENGTH") or 0)
if content_length: if content_length:
@ -291,13 +292,26 @@ class Application:
else: else:
content = None content = None
if is_valid_user and ( if is_valid_user:
(read_allowed_items or write_allowed_items) or if function in (self.do_GET, self.do_HEAD,
(is_authenticated and function == self.do_PROPFIND) or self.do_OPTIONS, self.do_PROPFIND,
self.do_REPORT):
lock_mode = "r"
else:
lock_mode = "w"
with self.Collection.acquire_lock(lock_mode):
items = self.Collection.discover(
path, environ.get("HTTP_DEPTH", "0"))
read_allowed_items, write_allowed_items = (
self.collect_allowed_items(items, user))
if (read_allowed_items or write_allowed_items or
is_authenticated and function == self.do_PROPFIND or
function == self.do_OPTIONS): function == self.do_OPTIONS):
status, headers, answer = function( status, headers, answer = function(
environ, read_allowed_items, write_allowed_items, content, environ, read_allowed_items, write_allowed_items,
user) content, user)
else:
status, headers, answer = NOT_ALLOWED
else: else:
status, headers, answer = NOT_ALLOWED status, headers, answer = NOT_ALLOWED

View File

@ -33,7 +33,8 @@ import ssl
from wsgiref.simple_server import make_server from wsgiref.simple_server import make_server
from . import ( from . import (
Application, config, HTTPServer, HTTPSServer, log, RequestHandler, VERSION) Application, config, ThreadedHTTPServer, ThreadedHTTPSServer, log,
RequestHandler, VERSION)
# This is a script, many branches and variables # This is a script, many branches and variables
@ -152,7 +153,7 @@ def run():
# Create collection servers # Create collection servers
servers = {} servers = {}
if configuration.getboolean("server", "ssl"): if configuration.getboolean("server", "ssl"):
server_class = HTTPSServer server_class = ThreadedHTTPSServer
server_class.certificate = configuration.get("server", "certificate") server_class.certificate = configuration.get("server", "certificate")
server_class.key = configuration.get("server", "key") server_class.key = configuration.get("server", "key")
server_class.cyphers = configuration.get("server", "cyphers") server_class.cyphers = configuration.get("server", "cyphers")
@ -168,7 +169,7 @@ def run():
"Error while reading SSL %s %r: %s" % ( "Error while reading SSL %s %r: %s" % (
name, filename, exception)) name, filename, exception))
else: else:
server_class = HTTPServer server_class = ThreadedHTTPServer
if not configuration.getboolean("server", "dns_lookup"): if not configuration.getboolean("server", "dns_lookup"):
RequestHandler.address_string = lambda self: self.client_address[0] RequestHandler.address_string = lambda self: self.client_address[0]

View File

@ -29,6 +29,8 @@ import json
import os import os
import posixpath import posixpath
import shutil import shutil
import stat
import threading
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from hashlib import md5 from hashlib import md5
@ -37,6 +39,42 @@ from uuid import uuid4
import vobject import vobject
if os.name == "nt":
import ctypes
import ctypes.wintypes
import msvcrt
LOCKFILE_EXCLUSIVE_LOCK = 2
if ctypes.sizeof(ctypes.c_void_p) == 4:
ULONG_PTR = ctypes.c_uint32
else:
ULONG_PTR = ctypes.c_uint64
class Overlapped(ctypes.Structure):
_fields_ = [("internal", ULONG_PTR),
("internal_high", ULONG_PTR),
("offset", ctypes.wintypes.DWORD),
("offset_high", ctypes.wintypes.DWORD),
("h_event", ctypes.wintypes.HANDLE)]
lock_file_ex = ctypes.windll.kernel32.LockFileEx
lock_file_ex.argtypes = [ctypes.wintypes.HANDLE,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.POINTER(Overlapped)]
lock_file_ex.restype = ctypes.wintypes.BOOL
unlock_file_ex = ctypes.windll.kernel32.UnlockFileEx
unlock_file_ex.argtypes = [ctypes.wintypes.HANDLE,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.POINTER(Overlapped)]
unlock_file_ex.restype = ctypes.wintypes.BOOL
elif os.name == "posix":
import fcntl
def load(configuration, logger): def load(configuration, logger):
"""Load the storage manager chosen in configuration.""" """Load the storage manager chosen in configuration."""
@ -54,6 +92,7 @@ def load(configuration, logger):
MIMETYPES = {"VADDRESSBOOK": "text/vcard", "VCALENDAR": "text/calendar"} MIMETYPES = {"VADDRESSBOOK": "text/vcard", "VCALENDAR": "text/calendar"}
MAX_FILE_LOCK_DURATION = 0.25
def get_etag(text): def get_etag(text):
@ -246,6 +285,17 @@ class BaseCollection:
"""Get the unicode string representing the whole collection.""" """Get the unicode string representing the whole collection."""
raise NotImplementedError raise NotImplementedError
@classmethod
@contextmanager
def acquire_lock(cls, mode):
"""Set a context manager to lock the whole storage.
``mode`` must either be "r" for shared access or "w" for exclusive
access.
"""
raise NotImplementedError
class Collection(BaseCollection): class Collection(BaseCollection):
"""Collection stored in several files per calendar.""" """Collection stored in several files per calendar."""
@ -475,3 +525,95 @@ class Collection(BaseCollection):
elif self.get_meta("tag") == "VADDRESSBOOK": elif self.get_meta("tag") == "VADDRESSBOOK":
return "".join([item.serialize() for item in items]) return "".join([item.serialize() for item in items])
return "" return ""
_lock = threading.Lock()
_waiters = []
_lock_file = None
_lock_file_locked = False
_lock_file_time = 0
_readers = 0
_writer = False
@classmethod
@contextmanager
def acquire_lock(cls, mode):
def condition():
# Prevent starvation of writers in other processes
if cls._lock_file_locked:
time_delta = time.time() - cls._lock_file_time
if time_delta < 0 or time_delta > MAX_FILE_LOCK_DURATION:
return False
if mode == "r":
return not cls._writer
else:
return not cls._writer and cls._readers == 0
if mode not in ("r", "w"):
raise ValueError("Invalid lock mode: %s" % mode)
# Use a primitive lock which only works within one process as a
# precondition for inter-process file-based locking
with cls._lock:
if cls._waiters or not condition():
# use FIFO for access requests
waiter = threading.Condition(lock=cls._lock)
cls._waiters.append(waiter)
while True:
waiter.wait()
if condition():
break
cls._waiters.pop(0)
if mode == "r":
cls._readers += 1
# notify additional potential readers
if cls._waiters:
cls._waiters[0].notify()
else:
cls._writer = True
if not cls._lock_file:
folder = os.path.expanduser(
cls.configuration.get("storage", "filesystem_folder"))
if not os.path.exists(folder):
os.makedirs(folder, exist_ok=True)
lock_path = os.path.join(folder, "Radicale.lock")
cls._lock_file = open(lock_path, "w+")
# set access rights to a necessary minimum to prevent locking
# by arbitrary users
try:
os.chmod(lock_path, stat.S_IWUSR | stat.S_IRUSR)
except OSError:
cls.logger.debug("Failed to set permissions on lock file")
if not cls._lock_file_locked:
if os.name == "nt":
handle = msvcrt.get_osfhandle(cls._lock_file.fileno())
flags = LOCKFILE_EXCLUSIVE_LOCK if mode == "w" else 0
overlapped = Overlapped()
if not lock_file_ex(handle, flags, 0, 1, 0, overlapped):
cls.logger.debug("Locking not supported")
elif os.name == "posix":
_cmd = fcntl.LOCK_EX if mode == "w" else fcntl.LOCK_SH
try:
fcntl.lockf(cls._lock_file.fileno(), _cmd)
except OSError:
cls.logger.debug("Locking not supported")
cls._lock_file_locked = True
cls._lock_file_time = time.time()
yield
with cls._lock:
if mode == "r":
cls._readers -= 1
else:
cls._writer = False
if cls._readers == 0:
if os.name == "nt":
handle = msvcrt.get_osfhandle(cls._lock_file.fileno())
overlapped = Overlapped()
if not unlock_file_ex(handle, 0, 1, 0, overlapped):
cls.logger.debug("Unlocking not supported")
elif os.name == "posix":
try:
fcntl.lockf(cls._lock_file.fileno(), fcntl.LOCK_UN)
except OSError:
cls.logger.debug("Unlocking not supported")
cls._lock_file_locked = False
if cls._waiters:
cls._waiters[0].notify()