Atomic writes (fix #440)

This commit is contained in:
Guillaume Ayoub 2016-07-14 01:14:42 +02:00
parent 5e5427f987
commit 4c91ee8906
2 changed files with 20 additions and 4 deletions

View File

@ -38,6 +38,7 @@ from importlib import import_module
from itertools import groupby from itertools import groupby
from random import getrandbits from random import getrandbits
from atomicwrites import AtomicWriter
import vobject import vobject
if os.name == "nt": if os.name == "nt":
@ -151,6 +152,15 @@ def path_to_filesystem(root, *paths):
return safe_path return safe_path
class _EncodedAtomicWriter(AtomicWriter):
def __init__(self, path, encoding, mode="w", overwrite=True):
self._encoding = encoding
return super().__init__(path, mode, overwrite=True)
def get_fileobject(self, **kwargs):
return super().get_fileobject(encoding=self._encoding, **kwargs)
class Item: class Item:
def __init__(self, collection, item, href, last_modified=None): def __init__(self, collection, item, href, last_modified=None):
self.collection = collection self.collection = collection
@ -319,6 +329,12 @@ class Collection(BaseCollection):
self.owner = None self.owner = None
self.is_principal = principal self.is_principal = principal
@contextmanager
def _atomic_write(self, path, mode="w"):
with _EncodedAtomicWriter(
path, self.storage_encoding, mode).open() as fd:
yield fd
@classmethod @classmethod
def discover(cls, path, depth="1"): def discover(cls, path, depth="1"):
# path == None means wrong URL # path == None means wrong URL
@ -447,7 +463,7 @@ class Collection(BaseCollection):
path = path_to_filesystem(self._filesystem_path, href) path = path_to_filesystem(self._filesystem_path, href)
if not os.path.exists(path): if not os.path.exists(path):
item = Item(self, vobject_item, href) item = Item(self, vobject_item, href)
with open(path, "w", encoding=self.storage_encoding) as fd: with self._atomic_write(path) as fd:
fd.write(item.serialize()) fd.write(item.serialize())
return item return item
else: else:
@ -465,7 +481,7 @@ class Collection(BaseCollection):
text = fd.read() text = fd.read()
if not etag or etag == get_etag(text): if not etag or etag == get_etag(text):
item = Item(self, vobject_item, href) item = Item(self, vobject_item, href)
with open(path, "w", encoding=self.storage_encoding) as fd: with self._atomic_write(path) as fd:
fd.write(item.serialize()) fd.write(item.serialize())
return item return item
else: else:
@ -516,7 +532,7 @@ class Collection(BaseCollection):
else: else:
properties.pop(key, None) properties.pop(key, None)
with open(props_path, "w+", encoding=self.storage_encoding) as prop: with self._atomic_write(props_path, "w+") as prop:
json.dump(properties, prop) json.dump(properties, prop)
@property @property

View File

@ -62,7 +62,7 @@ setup(
packages=["radicale"], packages=["radicale"],
provides=["radicale"], provides=["radicale"],
scripts=["bin/radicale"], scripts=["bin/radicale"],
install_requires=["vobject"], install_requires=["vobject", "atomicwrites"],
setup_requires=["pytest-runner"], setup_requires=["pytest-runner"],
tests_require=["pytest-cov", "pytest-flake8", "pytest-isort", "pytest"], tests_require=["pytest-cov", "pytest-flake8", "pytest-isort", "pytest"],
extras_require={ extras_require={