More type hints

This commit is contained in:
Unrud 2021-07-26 20:56:46 +02:00 committed by Unrud
parent 12fe5ce637
commit cecb17df03
51 changed files with 1374 additions and 957 deletions

View File

@ -27,46 +27,48 @@ Configuration files can be specified in the environment variable
import os import os
import threading import threading
from typing import Iterable, Optional, cast
import pkg_resources import pkg_resources
from radicale import config, log from radicale import config, log, types
from radicale.app import Application from radicale.app import Application
from radicale.log import logger from radicale.log import logger
VERSION = pkg_resources.get_distribution("radicale").version VERSION: str = pkg_resources.get_distribution("radicale").version
_application = None _application_instance: Optional[Application] = None
_application_config_path = None _application_config_path: Optional[str] = None
_application_lock = threading.Lock() _application_lock = threading.Lock()
def _init_application(config_path, wsgi_errors): def _get_application_instance(config_path: str, wsgi_errors: types.ErrorStream
global _application, _application_config_path ) -> Application:
global _application_instance, _application_config_path
with _application_lock: with _application_lock:
if _application is not None: if _application_instance is None:
return log.setup()
log.setup() with log.register_stream(wsgi_errors):
with log.register_stream(wsgi_errors): _application_config_path = config_path
_application_config_path = config_path configuration = config.load(config.parse_compound_paths(
configuration = config.load(config.parse_compound_paths( config.DEFAULT_CONFIG_PATH,
config.DEFAULT_CONFIG_PATH, config_path))
config_path)) log.set_level(cast(str, configuration.get("logging", "level")))
log.set_level(configuration.get("logging", "level")) # Log configuration after logger is configured
# Log configuration after logger is configured for source, miss in configuration.sources():
for source, miss in configuration.sources(): logger.info("%s %s", "Skipped missing" if miss
logger.info("%s %s", "Skipped missing" if miss else "Loaded", else "Loaded", source)
source) _application_instance = Application(configuration)
_application = Application(configuration) if _application_config_path != config_path:
raise ValueError("RADICALE_CONFIG must not change: %r != %r" %
(config_path, _application_config_path))
return _application_instance
def application(environ, start_response): def application(environ: types.WSGIEnviron,
start_response: types.WSGIStartResponse) -> Iterable[bytes]:
"""Entry point for external WSGI servers.""" """Entry point for external WSGI servers."""
config_path = environ.get("RADICALE_CONFIG", config_path = environ.get("RADICALE_CONFIG",
os.environ.get("RADICALE_CONFIG")) os.environ.get("RADICALE_CONFIG"))
if _application is None: app = _get_application_instance(config_path, environ["wsgi.errors"])
_init_application(config_path, environ["wsgi.errors"]) return app(environ, start_response)
if _application_config_path != config_path:
raise ValueError("RADICALE_CONFIG must not change: %s != %s" %
(repr(config_path), repr(_application_config_path)))
return _application(environ, start_response)

View File

@ -29,24 +29,27 @@ import os
import signal import signal
import socket import socket
import sys import sys
from types import FrameType
from typing import Dict, List, cast
from radicale import VERSION, config, log, server, storage from radicale import VERSION, config, log, server, storage
from radicale.log import logger from radicale.log import logger
def run(): def run() -> None:
"""Run Radicale as a standalone server.""" """Run Radicale as a standalone server."""
exit_signal_numbers = [signal.SIGTERM, signal.SIGINT] exit_signal_numbers = [signal.SIGTERM, signal.SIGINT]
if os.name == "posix": if os.name == "posix":
exit_signal_numbers.append(signal.SIGHUP) exit_signal_numbers.append(signal.SIGHUP)
exit_signal_numbers.append(signal.SIGQUIT) exit_signal_numbers.append(signal.SIGQUIT)
elif os.name == "nt": if sys.platform == "win32":
exit_signal_numbers.append(signal.SIGBREAK) exit_signal_numbers.append(signal.SIGBREAK)
# Raise SystemExit when signal arrives to run cleanup code # Raise SystemExit when signal arrives to run cleanup code
# (like destructors, try-finish etc.), otherwise the process exits # (like destructors, try-finish etc.), otherwise the process exits
# without running any of them # without running any of them
def exit_signal_handler(signal_number, stack_frame): def exit_signal_handler(signal_number: "signal.Signals",
stack_frame: FrameType) -> None:
sys.exit(1) sys.exit(1)
for signal_number in exit_signal_numbers: for signal_number in exit_signal_numbers:
signal.signal(signal_number, exit_signal_handler) signal.signal(signal_number, exit_signal_handler)
@ -60,12 +63,12 @@ def run():
parser.add_argument("--version", action="version", version=VERSION) parser.add_argument("--version", action="version", version=VERSION)
parser.add_argument("--verify-storage", action="store_true", parser.add_argument("--verify-storage", action="store_true",
help="check the storage for errors and exit") help="check the storage for errors and exit")
parser.add_argument( parser.add_argument("-C", "--config",
"-C", "--config", help="use specific configuration files", nargs="*") help="use specific configuration files", nargs="*")
parser.add_argument("-D", "--debug", action="store_true", parser.add_argument("-D", "--debug", action="store_true",
help="print debug information") help="print debug information")
groups = {} groups: Dict["argparse._ArgumentGroup", List[str]] = {}
for section, values in config.DEFAULT_CONFIG_SCHEMA.items(): for section, values in config.DEFAULT_CONFIG_SCHEMA.items():
if section.startswith("_"): if section.startswith("_"):
continue continue
@ -76,7 +79,7 @@ def run():
continue continue
kwargs = data.copy() kwargs = data.copy()
long_name = "--%s-%s" % (section, option.replace("_", "-")) long_name = "--%s-%s" % (section, option.replace("_", "-"))
args = list(kwargs.pop("aliases", ())) args: List[str] = list(kwargs.pop("aliases", ()))
args.append(long_name) args.append(long_name)
kwargs["dest"] = "%s_%s" % (section, option) kwargs["dest"] = "%s_%s" % (section, option)
groups[group].append(kwargs["dest"]) groups[group].append(kwargs["dest"])
@ -100,22 +103,22 @@ def run():
del kwargs["type"] del kwargs["type"]
group.add_argument(*args, **kwargs) group.add_argument(*args, **kwargs)
args = parser.parse_args() args_ns = parser.parse_args()
# Preliminary configure logging # Preliminary configure logging
if args.debug: if args_ns.debug:
args.logging_level = "debug" args_ns.logging_level = "debug"
with contextlib.suppress(ValueError): with contextlib.suppress(ValueError):
log.set_level(config.DEFAULT_CONFIG_SCHEMA["logging"]["level"]["type"]( log.set_level(config.DEFAULT_CONFIG_SCHEMA["logging"]["level"]["type"](
args.logging_level)) args_ns.logging_level))
# Update Radicale configuration according to arguments # Update Radicale configuration according to arguments
arguments_config = {} arguments_config = {}
for group, actions in groups.items(): for group, actions in groups.items():
section = group.title section = group.title or ""
section_config = {} section_config = {}
for action in actions: for action in actions:
value = getattr(args, action) value = getattr(args_ns, action)
if value is not None: if value is not None:
section_config[action.split('_', 1)[1]] = value section_config[action.split('_', 1)[1]] = value
if section_config: if section_config:
@ -125,31 +128,31 @@ def run():
configuration = config.load(config.parse_compound_paths( configuration = config.load(config.parse_compound_paths(
config.DEFAULT_CONFIG_PATH, config.DEFAULT_CONFIG_PATH,
os.environ.get("RADICALE_CONFIG"), os.environ.get("RADICALE_CONFIG"),
os.pathsep.join(args.config) if args.config else None)) os.pathsep.join(args_ns.config) if args_ns.config else None))
if arguments_config: if arguments_config:
configuration.update(arguments_config, "arguments") configuration.update(arguments_config, "command line arguments")
except Exception as e: except Exception as e:
logger.fatal("Invalid configuration: %s", e, exc_info=True) logger.critical("Invalid configuration: %s", e, exc_info=True)
sys.exit(1) sys.exit(1)
# Configure logging # Configure logging
log.set_level(configuration.get("logging", "level")) log.set_level(cast(str, configuration.get("logging", "level")))
# Log configuration after logger is configured # Log configuration after logger is configured
for source, miss in configuration.sources(): for source, miss in configuration.sources():
logger.info("%s %s", "Skipped missing" if miss else "Loaded", source) logger.info("%s %s", "Skipped missing" if miss else "Loaded", source)
if args.verify_storage: if args_ns.verify_storage:
logger.info("Verifying storage") logger.info("Verifying storage")
try: try:
storage_ = storage.load(configuration) storage_ = storage.load(configuration)
with storage_.acquire_lock("r"): with storage_.acquire_lock("r"):
if not storage_.verify(): if not storage_.verify():
logger.fatal("Storage verifcation failed") logger.critical("Storage verifcation failed")
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
logger.fatal("An exception occurred during storage verification: " logger.critical("An exception occurred during storage "
"%s", e, exc_info=True) "verification: %s", e, exc_info=True)
sys.exit(1) sys.exit(1)
return return
@ -157,7 +160,8 @@ def run():
shutdown_socket, shutdown_socket_out = socket.socketpair() shutdown_socket, shutdown_socket_out = socket.socketpair()
# Shutdown server when signal arrives # Shutdown server when signal arrives
def shutdown_signal_handler(signal_number, stack_frame): def shutdown_signal_handler(signal_number: "signal.Signals",
stack_frame: FrameType) -> None:
shutdown_socket.close() shutdown_socket.close()
for signal_number in exit_signal_numbers: for signal_number in exit_signal_numbers:
signal.signal(signal_number, shutdown_signal_handler) signal.signal(signal_number, shutdown_signal_handler)
@ -165,8 +169,8 @@ def run():
try: try:
server.serve(configuration, shutdown_socket_out) server.serve(configuration, shutdown_socket_out)
except Exception as e: except Exception as e:
logger.fatal("An exception occurred during server startup: %s", e, logger.critical("An exception occurred during server startup: %s", e,
exc_info=True) exc_info=True)
sys.exit(1) sys.exit(1)

View File

@ -27,53 +27,53 @@ the built-in server (see ``radicale.server`` module).
import base64 import base64
import datetime import datetime
import io
import logging
import posixpath
import pprint import pprint
import random import random
import sys
import time import time
import xml.etree.ElementTree as ET
import zlib import zlib
from http import client from http import client
from typing import Iterable, List, Mapping, Tuple, Union
import pkg_resources import pkg_resources
from radicale import (auth, httputils, log, pathutils, rights, storage, web, from radicale import config, httputils, log, pathutils, types
xmlutils) from radicale.app.base import ApplicationBase
from radicale.app.delete import ApplicationDeleteMixin from radicale.app.delete import ApplicationPartDelete
from radicale.app.get import ApplicationGetMixin from radicale.app.get import ApplicationPartGet
from radicale.app.head import ApplicationHeadMixin from radicale.app.head import ApplicationPartHead
from radicale.app.mkcalendar import ApplicationMkcalendarMixin from radicale.app.mkcalendar import ApplicationPartMkcalendar
from radicale.app.mkcol import ApplicationMkcolMixin from radicale.app.mkcol import ApplicationPartMkcol
from radicale.app.move import ApplicationMoveMixin from radicale.app.move import ApplicationPartMove
from radicale.app.options import ApplicationOptionsMixin from radicale.app.options import ApplicationPartOptions
from radicale.app.post import ApplicationPostMixin from radicale.app.post import ApplicationPartPost
from radicale.app.propfind import ApplicationPropfindMixin from radicale.app.propfind import ApplicationPartPropfind
from radicale.app.proppatch import ApplicationProppatchMixin from radicale.app.proppatch import ApplicationPartProppatch
from radicale.app.put import ApplicationPutMixin from radicale.app.put import ApplicationPartPut
from radicale.app.report import ApplicationReportMixin from radicale.app.report import ApplicationPartReport
from radicale.log import logger from radicale.log import logger
# WORKAROUND: https://github.com/tiran/defusedxml/issues/54 VERSION: str = pkg_resources.get_distribution("radicale").version
import defusedxml.ElementTree as DefusedET # isort: skip
sys.modules["xml.etree"].ElementTree = ET # type: ignore[attr-defined]
VERSION = pkg_resources.get_distribution("radicale").version # Combination of types.WSGIStartResponse and WSGI application return value
_IntermediateResponse = Tuple[str, List[Tuple[str, str]], Iterable[bytes]]
class Application( class Application(ApplicationPartDelete, ApplicationPartHead,
ApplicationDeleteMixin, ApplicationGetMixin, ApplicationHeadMixin, ApplicationPartGet, ApplicationPartMkcalendar,
ApplicationMkcalendarMixin, ApplicationMkcolMixin, ApplicationPartMkcol, ApplicationPartMove,
ApplicationMoveMixin, ApplicationOptionsMixin, ApplicationPartOptions, ApplicationPartPropfind,
ApplicationPropfindMixin, ApplicationProppatchMixin, ApplicationPartProppatch, ApplicationPartPost,
ApplicationPostMixin, ApplicationPutMixin, ApplicationPartPut, ApplicationPartReport, ApplicationBase):
ApplicationReportMixin):
"""WSGI application.""" """WSGI application."""
def __init__(self, configuration): _mask_passwords: bool
_auth_delay: float
_internal_server: bool
_max_content_length: int
_auth_realm: str
_extra_headers: Mapping[str, str]
def __init__(self, configuration: config.Configuration) -> None:
"""Initialize Application. """Initialize Application.
``configuration`` see ``radicale.config`` module. ``configuration`` see ``radicale.config`` module.
@ -81,60 +81,59 @@ class Application(
this object, it is kept as an internal reference. this object, it is kept as an internal reference.
""" """
super().__init__() super().__init__(configuration)
self.configuration = configuration self._mask_passwords = configuration.get("logging", "mask_passwords")
self._auth = auth.load(configuration) self._auth_delay = configuration.get("auth", "delay")
self._storage = storage.load(configuration) self._internal_server = configuration.get("server", "_internal_server")
self._rights = rights.load(configuration) self._max_content_length = configuration.get(
self._web = web.load(configuration) "server", "max_content_length")
self._encoding = configuration.get("encoding", "request") self._auth_realm = configuration.get("auth", "realm")
self._extra_headers = dict()
for key in self.configuration.options("headers"):
self._extra_headers[key] = configuration.get("headers", key)
def _headers_log(self, environ): def _scrub_headers(self, environ: types.WSGIEnviron) -> types.WSGIEnviron:
"""Sanitize headers for logging.""" """Mask passwords and cookies."""
request_environ = dict(environ) headers = dict(environ)
if (self._mask_passwords and
headers.get("HTTP_AUTHORIZATION", "").startswith("Basic")):
headers["HTTP_AUTHORIZATION"] = "Basic **masked**"
if headers.get("HTTP_COOKIE"):
headers["HTTP_COOKIE"] = "**masked**"
return headers
# Mask passwords def __call__(self, environ: types.WSGIEnviron, start_response:
mask_passwords = self.configuration.get("logging", "mask_passwords") types.WSGIStartResponse) -> Iterable[bytes]:
authorization = request_environ.get("HTTP_AUTHORIZATION", "")
if mask_passwords and authorization.startswith("Basic"):
request_environ["HTTP_AUTHORIZATION"] = "Basic **masked**"
if request_environ.get("HTTP_COOKIE"):
request_environ["HTTP_COOKIE"] = "**masked**"
return request_environ
def __call__(self, environ, start_response):
with log.register_stream(environ["wsgi.errors"]): with log.register_stream(environ["wsgi.errors"]):
try: try:
status, headers, answers = self._handle_request(environ) status_text, headers, answers = self._handle_request(environ)
except Exception as e: except Exception as e:
try:
method = str(environ["REQUEST_METHOD"])
except Exception:
method = "unknown"
try:
path = str(environ.get("PATH_INFO", ""))
except Exception:
path = ""
logger.error("An exception occurred during %s request on %r: " logger.error("An exception occurred during %s request on %r: "
"%s", method, path, e, exc_info=True) "%s", environ.get("REQUEST_METHOD", "unknown"),
status, headers, answer = httputils.INTERNAL_SERVER_ERROR environ.get("PATH_INFO", ""), e, exc_info=True)
answer = answer.encode("ascii") # Make minimal response
status = "%d %s" % ( status, raw_headers, raw_answer = (
status.value, client.responses.get(status, "Unknown")) httputils.INTERNAL_SERVER_ERROR)
headers = [ assert isinstance(raw_answer, str)
("Content-Length", str(len(answer)))] + list(headers) answer = raw_answer.encode("ascii")
status_text = "%d %s" % (
status, client.responses.get(status, "Unknown"))
headers = [*raw_headers, ("Content-Length", str(len(answer)))]
answers = [answer] answers = [answer]
start_response(status, headers) start_response(status_text, headers)
return answers return answers
def _handle_request(self, environ): def _handle_request(self, environ: types.WSGIEnviron
) -> _IntermediateResponse:
"""Manage a request.""" """Manage a request."""
def response(status, headers=(), answer=None): def response(status: int, headers: types.WSGIResponseHeaders,
answer: Union[None, str, bytes]) -> _IntermediateResponse:
"""Helper to create response from internal types.WSGIResponse"""
headers = dict(headers) headers = dict(headers)
# Set content length # Set content length
if answer: answers = []
if hasattr(answer, "encode"): if answer is not None:
if isinstance(answer, str):
logger.debug("Response content:\n%s", answer) logger.debug("Response content:\n%s", answer)
headers["Content-Type"] += "; charset=%s" % self._encoding headers["Content-Type"] += "; charset=%s" % self._encoding
answer = answer.encode(self._encoding) answer = answer.encode(self._encoding)
@ -149,21 +148,22 @@ class Application(
headers["Content-Encoding"] = "gzip" headers["Content-Encoding"] = "gzip"
headers["Content-Length"] = str(len(answer)) headers["Content-Length"] = str(len(answer))
answers.append(answer)
# Add extra headers set in configuration # Add extra headers set in configuration
for key in self.configuration.options("headers"): headers.update(self._extra_headers)
headers[key] = self.configuration.get("headers", key)
# Start response # Start response
time_end = datetime.datetime.now() time_end = datetime.datetime.now()
status = "%d %s" % ( status_text = "%d %s" % (
status, client.responses.get(status, "Unknown")) status, client.responses.get(status, "Unknown"))
logger.info( logger.info(
"%s response status for %r%s in %.3f seconds: %s", "%s response status for %r%s in %.3f seconds: %s",
environ["REQUEST_METHOD"], environ.get("PATH_INFO", ""), environ["REQUEST_METHOD"], environ.get("PATH_INFO", ""),
depthinfo, (time_end - time_begin).total_seconds(), status) depthinfo, (time_end - time_begin).total_seconds(),
status_text)
# Return response content # Return response content
return status, list(headers.items()), [answer] if answer else [] return status_text, list(headers.items()), answers
remote_host = "unknown" remote_host = "unknown"
if environ.get("REMOTE_HOST"): if environ.get("REMOTE_HOST"):
@ -184,8 +184,8 @@ class Application(
"%s request for %r%s received from %s%s", "%s request for %r%s received from %s%s",
environ["REQUEST_METHOD"], environ.get("PATH_INFO", ""), depthinfo, environ["REQUEST_METHOD"], environ.get("PATH_INFO", ""), depthinfo,
remote_host, remote_useragent) remote_host, remote_useragent)
headers = pprint.pformat(self._headers_log(environ)) logger.debug("Request headers:\n%s",
logger.debug("Request headers:\n%s", headers) pprint.pformat(self._scrub_headers(environ)))
# Let reverse proxies overwrite SCRIPT_NAME # Let reverse proxies overwrite SCRIPT_NAME
if "HTTP_X_SCRIPT_NAME" in environ: if "HTTP_X_SCRIPT_NAME" in environ:
@ -237,9 +237,8 @@ class Application(
logger.warning("Failed login attempt from %s: %r", logger.warning("Failed login attempt from %s: %r",
remote_host, login) remote_host, login)
# Random delay to avoid timing oracles and bruteforce attacks # Random delay to avoid timing oracles and bruteforce attacks
delay = self.configuration.get("auth", "delay") if self._auth_delay > 0:
if delay > 0: random_delay = self._auth_delay * (0.5 + random.random())
random_delay = delay * (0.5 + random.random())
logger.debug("Sleeping %.3f seconds", random_delay) logger.debug("Sleeping %.3f seconds", random_delay)
time.sleep(random_delay) time.sleep(random_delay)
@ -252,8 +251,8 @@ class Application(
if user: if user:
principal_path = "/%s/" % user principal_path = "/%s/" % user
with self._storage.acquire_lock("r", user): with self._storage.acquire_lock("r", user):
principal = next(self._storage.discover( principal = next(iter(self._storage.discover(
principal_path, depth="1"), None) principal_path, depth="1")), None)
if not principal: if not principal:
if "W" in self._rights.authorization(user, principal_path): if "W" in self._rights.authorization(user, principal_path):
with self._storage.acquire_lock("w", user): with self._storage.acquire_lock("w", user):
@ -267,13 +266,12 @@ class Application(
logger.warning("Access to principal path %r denied by " logger.warning("Access to principal path %r denied by "
"rights backend", principal_path) "rights backend", principal_path)
if self.configuration.get("server", "_internal_server"): if self._internal_server:
# Verify content length # Verify content length
content_length = int(environ.get("CONTENT_LENGTH") or 0) content_length = int(environ.get("CONTENT_LENGTH") or 0)
if content_length: if content_length:
max_content_length = self.configuration.get( if (self._max_content_length > 0 and
"server", "max_content_length") content_length > self._max_content_length):
if max_content_length and content_length > max_content_length:
logger.info("Request body too large: %d", content_length) logger.info("Request body too large: %d", content_length)
return response(*httputils.REQUEST_ENTITY_TOO_LARGE) return response(*httputils.REQUEST_ENTITY_TOO_LARGE)
@ -291,82 +289,9 @@ class Application(
# Unknown or unauthorized user # Unknown or unauthorized user
logger.debug("Asking client for authentication") logger.debug("Asking client for authentication")
status = client.UNAUTHORIZED status = client.UNAUTHORIZED
realm = self.configuration.get("auth", "realm")
headers = dict(headers) headers = dict(headers)
headers.update({ headers.update({
"WWW-Authenticate": "WWW-Authenticate":
"Basic realm=\"%s\"" % realm}) "Basic realm=\"%s\"" % self._auth_realm})
return response(status, headers, answer) return response(status, headers, answer)
def _read_xml_request_body(self, environ):
content = httputils.decode_request(
self.configuration, environ,
httputils.read_raw_request_body(self.configuration, environ))
if not content:
return None
try:
xml_content = DefusedET.fromstring(content)
except ET.ParseError as e:
logger.debug("Request content (Invalid XML):\n%s", content)
raise RuntimeError("Failed to parse XML: %s" % e) from e
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Request content:\n%s",
xmlutils.pretty_xml(xml_content))
return xml_content
def _xml_response(self, xml_content):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Response content:\n%s",
xmlutils.pretty_xml(xml_content))
f = io.BytesIO()
ET.ElementTree(xml_content).write(f, encoding=self._encoding,
xml_declaration=True)
return f.getvalue()
def _webdav_error_response(self, status, human_tag):
"""Generate XML error response."""
headers = {"Content-Type": "text/xml; charset=%s" % self._encoding}
content = self._xml_response(xmlutils.webdav_error(human_tag))
return status, headers, content
class Access:
"""Helper class to check access rights of an item"""
def __init__(self, rights, user, path):
self._rights = rights
self.user = user
self.path = path
self.parent_path = pathutils.unstrip_path(
posixpath.dirname(pathutils.strip_path(path)), True)
self.permissions = self._rights.authorization(self.user, self.path)
self._parent_permissions = None
@property
def parent_permissions(self):
if self.path == self.parent_path:
return self.permissions
if self._parent_permissions is None:
self._parent_permissions = self._rights.authorization(
self.user, self.parent_path)
return self._parent_permissions
def check(self, permission, item=None):
if permission not in "rw":
raise ValueError("Invalid permission argument: %r" % permission)
if not item:
permissions = permission + permission.upper()
parent_permissions = permission
elif isinstance(item, storage.BaseCollection):
if item.get_meta("tag"):
permissions = permission
else:
permissions = permission.upper()
parent_permissions = ""
else:
permissions = ""
parent_permissions = permission
return bool(rights.intersect(self.permissions, permissions) or (
self.path != self.parent_path and
rights.intersect(self.parent_permissions, parent_permissions)))

131
radicale/app/base.py Normal file
View File

@ -0,0 +1,131 @@
# This file is part of Radicale Server - Calendar Server
# Copyright © 2020 Unrud <unrud@outlook.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>.
import io
import logging
import posixpath
import sys
import xml.etree.ElementTree as ET
from typing import Optional
from radicale import (auth, config, httputils, pathutils, rights, storage,
types, web, xmlutils)
from radicale.log import logger
# HACK: https://github.com/tiran/defusedxml/issues/54
import defusedxml.ElementTree as DefusedET # isort:skip
sys.modules["xml.etree"].ElementTree = ET # type:ignore[attr-defined]
class ApplicationBase:
configuration: config.Configuration
_auth: auth.BaseAuth
_storage: storage.BaseStorage
_rights: rights.BaseRights
_web: web.BaseWeb
_encoding: str
def __init__(self, configuration: config.Configuration) -> None:
self.configuration = configuration
self._auth = auth.load(configuration)
self._storage = storage.load(configuration)
self._rights = rights.load(configuration)
self._web = web.load(configuration)
self._encoding = configuration.get("encoding", "request")
def _read_xml_request_body(self, environ: types.WSGIEnviron
) -> Optional[ET.Element]:
content = httputils.decode_request(
self.configuration, environ,
httputils.read_raw_request_body(self.configuration, environ))
if not content:
return None
try:
xml_content = DefusedET.fromstring(content)
except ET.ParseError as e:
logger.debug("Request content (Invalid XML):\n%s", content)
raise RuntimeError("Failed to parse XML: %s" % e) from e
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Request content:\n%s",
xmlutils.pretty_xml(xml_content))
return xml_content
def _xml_response(self, xml_content: ET.Element) -> bytes:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Response content:\n%s",
xmlutils.pretty_xml(xml_content))
f = io.BytesIO()
ET.ElementTree(xml_content).write(f, encoding=self._encoding,
xml_declaration=True)
return f.getvalue()
def _webdav_error_response(self, status: int, human_tag: str
) -> types.WSGIResponse:
"""Generate XML error response."""
headers = {"Content-Type": "text/xml; charset=%s" % self._encoding}
content = self._xml_response(xmlutils.webdav_error(human_tag))
return status, headers, content
class Access:
"""Helper class to check access rights of an item"""
user: str
path: str
parent_path: str
permissions: str
_rights: rights.BaseRights
_parent_permissions: Optional[str]
def __init__(self, rights: rights.BaseRights, user: str, path: str
) -> None:
self._rights = rights
self.user = user
self.path = path
self.parent_path = pathutils.unstrip_path(
posixpath.dirname(pathutils.strip_path(path)), True)
self.permissions = self._rights.authorization(self.user, self.path)
self._parent_permissions = None
@property
def parent_permissions(self) -> str:
if self.path == self.parent_path:
return self.permissions
if self._parent_permissions is None:
self._parent_permissions = self._rights.authorization(
self.user, self.parent_path)
return self._parent_permissions
def check(self, permission: str,
item: Optional[types.CollectionOrItem] = None) -> bool:
if permission not in "rw":
raise ValueError("Invalid permission argument: %r" % permission)
if not item:
permissions = permission + permission.upper()
parent_permissions = permission
elif isinstance(item, storage.BaseCollection):
if item.tag:
permissions = permission
else:
permissions = permission.upper()
parent_permissions = ""
else:
permissions = ""
parent_permissions = permission
return bool(rights.intersect(self.permissions, permissions) or (
self.path != self.parent_path and
rights.intersect(self.parent_permissions, parent_permissions)))

View File

@ -19,25 +19,28 @@
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import client from http import client
from typing import Optional
from radicale import app, httputils, storage, xmlutils from radicale import httputils, storage, types, xmlutils
from radicale.app.base import Access, ApplicationBase
def xml_delete(base_prefix, path, collection, href=None): def xml_delete(base_prefix: str, path: str, collection: storage.BaseCollection,
item_href: Optional[str] = None) -> ET.Element:
"""Read and answer DELETE requests. """Read and answer DELETE requests.
Read rfc4918-9.6 for info. Read rfc4918-9.6 for info.
""" """
collection.delete(href) collection.delete(item_href)
multistatus = ET.Element(xmlutils.make_clark("D:multistatus")) multistatus = ET.Element(xmlutils.make_clark("D:multistatus"))
response = ET.Element(xmlutils.make_clark("D:response")) response = ET.Element(xmlutils.make_clark("D:response"))
multistatus.append(response) multistatus.append(response)
href = ET.Element(xmlutils.make_clark("D:href")) href_element = ET.Element(xmlutils.make_clark("D:href"))
href.text = xmlutils.make_href(base_prefix, path) href_element.text = xmlutils.make_href(base_prefix, path)
response.append(href) response.append(href_element)
status = ET.Element(xmlutils.make_clark("D:status")) status = ET.Element(xmlutils.make_clark("D:status"))
status.text = xmlutils.make_response(200) status.text = xmlutils.make_response(200)
@ -46,14 +49,16 @@ def xml_delete(base_prefix, path, collection, href=None):
return multistatus return multistatus
class ApplicationDeleteMixin: class ApplicationPartDelete(ApplicationBase):
def do_DELETE(self, environ, base_prefix, path, user):
def do_DELETE(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage DELETE request.""" """Manage DELETE request."""
access = app.Access(self._rights, user, path) access = Access(self._rights, user, path)
if not access.check("w"): if not access.check("w"):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
with self._storage.acquire_lock("w", user): with self._storage.acquire_lock("w", user):
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
if not item: if not item:
return httputils.NOT_FOUND return httputils.NOT_FOUND
if not access.check("w", item): if not access.check("w", item):
@ -65,6 +70,8 @@ class ApplicationDeleteMixin:
if isinstance(item, storage.BaseCollection): if isinstance(item, storage.BaseCollection):
xml_answer = xml_delete(base_prefix, path, item) xml_answer = xml_delete(base_prefix, path, item)
else: else:
assert item.collection is not None
assert item.href is not None
xml_answer = xml_delete( xml_answer = xml_delete(
base_prefix, path, item.collection, item.href) base_prefix, path, item.collection, item.href)
headers = {"Content-Type": "text/xml; charset=%s" % self._encoding} headers = {"Content-Type": "text/xml; charset=%s" % self._encoding}

View File

@ -21,17 +21,17 @@ import posixpath
from http import client from http import client
from urllib.parse import quote from urllib.parse import quote
from radicale import app, httputils, pathutils, storage, xmlutils from radicale import httputils, pathutils, storage, types, xmlutils
from radicale.app.base import Access, ApplicationBase
from radicale.log import logger from radicale.log import logger
def propose_filename(collection): def propose_filename(collection: storage.BaseCollection) -> str:
"""Propose a filename for a collection.""" """Propose a filename for a collection."""
tag = collection.get_meta("tag") if collection.tag == "VADDRESSBOOK":
if tag == "VADDRESSBOOK":
fallback_title = "Address book" fallback_title = "Address book"
suffix = ".vcf" suffix = ".vcf"
elif tag == "VCALENDAR": elif collection.tag == "VCALENDAR":
fallback_title = "Calendar" fallback_title = "Calendar"
suffix = ".ics" suffix = ".ics"
else: else:
@ -43,8 +43,9 @@ def propose_filename(collection):
return title return title
class ApplicationGetMixin: class ApplicationPartGet(ApplicationBase):
def _content_disposition_attachement(self, filename):
def _content_disposition_attachement(self, filename: str) -> str:
value = "attachement" value = "attachement"
try: try:
encoded_filename = quote(filename, encoding=self._encoding) encoded_filename = quote(filename, encoding=self._encoding)
@ -56,7 +57,8 @@ class ApplicationGetMixin:
value += "; filename*=%s''%s" % (self._encoding, encoded_filename) value += "; filename*=%s''%s" % (self._encoding, encoded_filename)
return value return value
def do_GET(self, environ, base_prefix, path, user): def do_GET(self, environ: types.WSGIEnviron, base_prefix: str, path: str,
user: str) -> types.WSGIResponse:
"""Manage GET request.""" """Manage GET request."""
# Redirect to .web if the root URL is requested # Redirect to .web if the root URL is requested
if not pathutils.strip_path(path): if not pathutils.strip_path(path):
@ -70,11 +72,11 @@ class ApplicationGetMixin:
# Dispatch .web URL to web module # Dispatch .web URL to web module
if path == "/.web" or path.startswith("/.web/"): if path == "/.web" or path.startswith("/.web/"):
return self._web.get(environ, base_prefix, path, user) return self._web.get(environ, base_prefix, path, user)
access = app.Access(self._rights, user, path) access = Access(self._rights, user, path)
if not access.check("r") and "i" not in access.permissions: if not access.check("r") and "i" not in access.permissions:
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
with self._storage.acquire_lock("r", user): with self._storage.acquire_lock("r", user):
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
if not item: if not item:
return httputils.NOT_FOUND return httputils.NOT_FOUND
if access.check("r", item): if access.check("r", item):
@ -84,11 +86,10 @@ class ApplicationGetMixin:
else: else:
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
if isinstance(item, storage.BaseCollection): if isinstance(item, storage.BaseCollection):
tag = item.get_meta("tag") if not item.tag:
if not tag:
return (httputils.NOT_ALLOWED if limited_access else return (httputils.NOT_ALLOWED if limited_access else
httputils.DIRECTORY_LISTING) httputils.DIRECTORY_LISTING)
content_type = xmlutils.MIMETYPES[tag] content_type = xmlutils.MIMETYPES[item.tag]
content_disposition = self._content_disposition_attachement( content_disposition = self._content_disposition_attachement(
propose_filename(item)) propose_filename(item))
elif limited_access: elif limited_access:
@ -96,6 +97,7 @@ class ApplicationGetMixin:
else: else:
content_type = xmlutils.OBJECT_MIMETYPES[item.name] content_type = xmlutils.OBJECT_MIMETYPES[item.name]
content_disposition = "" content_disposition = ""
assert item.last_modified
headers = { headers = {
"Content-Type": content_type, "Content-Type": content_type,
"Last-Modified": item.last_modified, "Last-Modified": item.last_modified,

View File

@ -17,9 +17,15 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
from radicale import types
from radicale.app.base import ApplicationBase
from radicale.app.get import ApplicationPartGet
class ApplicationHeadMixin:
def do_HEAD(self, environ, base_prefix, path, user): class ApplicationPartHead(ApplicationPartGet, ApplicationBase):
def do_HEAD(self, environ: types.WSGIEnviron, base_prefix: str, path: str,
user: str) -> types.WSGIResponse:
"""Manage HEAD request.""" """Manage HEAD request."""
status, headers, _ = self.do_GET(environ, base_prefix, path, user) status, headers, _ = self.do_GET(environ, base_prefix, path, user)
return status, headers, None return status, headers, None

View File

@ -21,14 +21,16 @@ import posixpath
import socket import socket
from http import client from http import client
from radicale import httputils import radicale.item as radicale_item
from radicale import item as radicale_item from radicale import httputils, pathutils, storage, types, xmlutils
from radicale import pathutils, storage, xmlutils from radicale.app.base import ApplicationBase
from radicale.log import logger from radicale.log import logger
class ApplicationMkcalendarMixin: class ApplicationPartMkcalendar(ApplicationBase):
def do_MKCALENDAR(self, environ, base_prefix, path, user):
def do_MKCALENDAR(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage MKCALENDAR request.""" """Manage MKCALENDAR request."""
if "w" not in self._rights.authorization(user, path): if "w" not in self._rights.authorization(user, path):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
@ -42,29 +44,28 @@ class ApplicationMkcalendarMixin:
logger.debug("Client timed out", exc_info=True) logger.debug("Client timed out", exc_info=True)
return httputils.REQUEST_TIMEOUT return httputils.REQUEST_TIMEOUT
# Prepare before locking # Prepare before locking
props = xmlutils.props_from_request(xml_content) props_with_remove = xmlutils.props_from_request(xml_content)
props = {k: v for k, v in props.items() if v is not None} props_with_remove["tag"] = "VCALENDAR"
props["tag"] = "VCALENDAR"
# TODO: use this?
# timezone = props.get("C:calendar-timezone")
try: try:
radicale_item.check_and_sanitize_props(props) props = radicale_item.check_and_sanitize_props(props_with_remove)
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
"Bad MKCALENDAR request on %r: %s", path, e, exc_info=True) "Bad MKCALENDAR request on %r: %s", path, e, exc_info=True)
return httputils.BAD_REQUEST return httputils.BAD_REQUEST
# TODO: use this?
# timezone = props.get("C:calendar-timezone")
with self._storage.acquire_lock("w", user): with self._storage.acquire_lock("w", user):
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
if item: if item:
return self._webdav_error_response( return self._webdav_error_response(
client.CONFLICT, "D:resource-must-be-null") client.CONFLICT, "D:resource-must-be-null")
parent_path = pathutils.unstrip_path( parent_path = pathutils.unstrip_path(
posixpath.dirname(pathutils.strip_path(path)), True) posixpath.dirname(pathutils.strip_path(path)), True)
parent_item = next(self._storage.discover(parent_path), None) parent_item = next(iter(self._storage.discover(parent_path)), None)
if not parent_item: if not parent_item:
return httputils.CONFLICT return httputils.CONFLICT
if (not isinstance(parent_item, storage.BaseCollection) or if (not isinstance(parent_item, storage.BaseCollection) or
parent_item.get_meta("tag")): parent_item.tag):
return httputils.FORBIDDEN return httputils.FORBIDDEN
try: try:
self._storage.create_collection(path, props=props) self._storage.create_collection(path, props=props)

View File

@ -21,14 +21,16 @@ import posixpath
import socket import socket
from http import client from http import client
from radicale import httputils import radicale.item as radicale_item
from radicale import item as radicale_item from radicale import httputils, pathutils, rights, storage, types, xmlutils
from radicale import pathutils, rights, storage, xmlutils from radicale.app.base import ApplicationBase
from radicale.log import logger from radicale.log import logger
class ApplicationMkcolMixin: class ApplicationPartMkcol(ApplicationBase):
def do_MKCOL(self, environ, base_prefix, path, user):
def do_MKCOL(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage MKCOL request.""" """Manage MKCOL request."""
permissions = self._rights.authorization(user, path) permissions = self._rights.authorization(user, path)
if not rights.intersect(permissions, "Ww"): if not rights.intersect(permissions, "Ww"):
@ -43,10 +45,9 @@ class ApplicationMkcolMixin:
logger.debug("Client timed out", exc_info=True) logger.debug("Client timed out", exc_info=True)
return httputils.REQUEST_TIMEOUT return httputils.REQUEST_TIMEOUT
# Prepare before locking # Prepare before locking
props = xmlutils.props_from_request(xml_content) props_with_remove = xmlutils.props_from_request(xml_content)
props = {k: v for k, v in props.items() if v is not None}
try: try:
radicale_item.check_and_sanitize_props(props) props = radicale_item.check_and_sanitize_props(props_with_remove)
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
"Bad MKCOL request on %r: %s", path, e, exc_info=True) "Bad MKCOL request on %r: %s", path, e, exc_info=True)
@ -55,16 +56,16 @@ class ApplicationMkcolMixin:
not props.get("tag") and "W" not in permissions): not props.get("tag") and "W" not in permissions):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
with self._storage.acquire_lock("w", user): with self._storage.acquire_lock("w", user):
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
if item: if item:
return httputils.METHOD_NOT_ALLOWED return httputils.METHOD_NOT_ALLOWED
parent_path = pathutils.unstrip_path( parent_path = pathutils.unstrip_path(
posixpath.dirname(pathutils.strip_path(path)), True) posixpath.dirname(pathutils.strip_path(path)), True)
parent_item = next(self._storage.discover(parent_path), None) parent_item = next(iter(self._storage.discover(parent_path)), None)
if not parent_item: if not parent_item:
return httputils.CONFLICT return httputils.CONFLICT
if (not isinstance(parent_item, storage.BaseCollection) or if (not isinstance(parent_item, storage.BaseCollection) or
parent_item.get_meta("tag")): parent_item.tag):
return httputils.FORBIDDEN return httputils.FORBIDDEN
try: try:
self._storage.create_collection(path, props=props) self._storage.create_collection(path, props=props)

View File

@ -21,12 +21,15 @@ import posixpath
from http import client from http import client
from urllib.parse import urlparse from urllib.parse import urlparse
from radicale import app, httputils, pathutils, storage from radicale import httputils, pathutils, storage, types
from radicale.app.base import Access, ApplicationBase
from radicale.log import logger from radicale.log import logger
class ApplicationMoveMixin: class ApplicationPartMove(ApplicationBase):
def do_MOVE(self, environ, base_prefix, path, user):
def do_MOVE(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage MOVE request.""" """Manage MOVE request."""
raw_dest = environ.get("HTTP_DESTINATION", "") raw_dest = environ.get("HTTP_DESTINATION", "")
to_url = urlparse(raw_dest) to_url = urlparse(raw_dest)
@ -34,7 +37,7 @@ class ApplicationMoveMixin:
logger.info("Unsupported destination address: %r", raw_dest) logger.info("Unsupported destination address: %r", raw_dest)
# Remote destination server, not supported # Remote destination server, not supported
return httputils.REMOTE_DESTINATION return httputils.REMOTE_DESTINATION
access = app.Access(self._rights, user, path) access = Access(self._rights, user, path)
if not access.check("w"): if not access.check("w"):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
to_path = pathutils.sanitize_path(to_url.path) to_path = pathutils.sanitize_path(to_url.path)
@ -43,12 +46,12 @@ class ApplicationMoveMixin:
"start with base prefix", to_path, path) "start with base prefix", to_path, path)
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
to_path = to_path[len(base_prefix):] to_path = to_path[len(base_prefix):]
to_access = app.Access(self._rights, user, to_path) to_access = Access(self._rights, user, to_path)
if not to_access.check("w"): if not to_access.check("w"):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
with self._storage.acquire_lock("w", user): with self._storage.acquire_lock("w", user):
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
if not item: if not item:
return httputils.NOT_FOUND return httputils.NOT_FOUND
if (not access.check("w", item) or if (not access.check("w", item) or
@ -58,17 +61,19 @@ class ApplicationMoveMixin:
# TODO: support moving collections # TODO: support moving collections
return httputils.METHOD_NOT_ALLOWED return httputils.METHOD_NOT_ALLOWED
to_item = next(self._storage.discover(to_path), None) to_item = next(iter(self._storage.discover(to_path)), None)
if isinstance(to_item, storage.BaseCollection): if isinstance(to_item, storage.BaseCollection):
return httputils.FORBIDDEN return httputils.FORBIDDEN
to_parent_path = pathutils.unstrip_path( to_parent_path = pathutils.unstrip_path(
posixpath.dirname(pathutils.strip_path(to_path)), True) posixpath.dirname(pathutils.strip_path(to_path)), True)
to_collection = next( to_collection = next(iter(
self._storage.discover(to_parent_path), None) self._storage.discover(to_parent_path)), None)
if not to_collection: if not to_collection:
return httputils.CONFLICT return httputils.CONFLICT
tag = item.collection.get_meta("tag") assert isinstance(to_collection, storage.BaseCollection)
if not tag or tag != to_collection.get_meta("tag"): assert item.collection is not None
collection_tag = item.collection.tag
if not collection_tag or collection_tag != to_collection.tag:
return httputils.FORBIDDEN return httputils.FORBIDDEN
if to_item and environ.get("HTTP_OVERWRITE", "F") != "T": if to_item and environ.get("HTTP_OVERWRITE", "F") != "T":
return httputils.PRECONDITION_FAILED return httputils.PRECONDITION_FAILED
@ -78,7 +83,7 @@ class ApplicationMoveMixin:
to_collection.has_uid(item.uid)): to_collection.has_uid(item.uid)):
return self._webdav_error_response( return self._webdav_error_response(
client.CONFLICT, "%s:no-uid-conflict" % ( client.CONFLICT, "%s:no-uid-conflict" % (
"C" if tag == "VCALENDAR" else "CR")) "C" if collection_tag == "VCALENDAR" else "CR"))
to_href = posixpath.basename(pathutils.strip_path(to_path)) to_href = posixpath.basename(pathutils.strip_path(to_path))
try: try:
self._storage.move(item, to_collection, to_href) self._storage.move(item, to_collection, to_href)

View File

@ -19,11 +19,14 @@
from http import client from http import client
from radicale import httputils from radicale import httputils, types
from radicale.app.base import ApplicationBase
class ApplicationOptionsMixin: class ApplicationPartOptions(ApplicationBase):
def do_OPTIONS(self, environ, base_prefix, path, user):
def do_OPTIONS(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage OPTIONS request.""" """Manage OPTIONS request."""
headers = { headers = {
"Allow": ", ".join( "Allow": ", ".join(

View File

@ -18,11 +18,14 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
from radicale import httputils from radicale import httputils, types
from radicale.app.base import ApplicationBase
class ApplicationPostMixin: class ApplicationPartPost(ApplicationBase):
def do_POST(self, environ, base_prefix, path, user):
def do_POST(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage POST request.""" """Manage POST request."""
if path == "/.web" or path.startswith("/.web/"): if path == "/.web" or path.startswith("/.web/"):
return self._web.post(environ, base_prefix, path, user) return self._web.post(environ, base_prefix, path, user)

View File

@ -23,13 +23,17 @@ import posixpath
import socket import socket
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import client from http import client
from typing import Dict, Iterable, Iterator, List, Optional, Sequence, Tuple
from radicale import app, httputils, pathutils, rights, storage, xmlutils from radicale import httputils, pathutils, rights, storage, types, xmlutils
from radicale.app.base import Access, ApplicationBase
from radicale.log import logger from radicale.log import logger
def xml_propfind(base_prefix, path, xml_request, allowed_items, user, def xml_propfind(base_prefix: str, path: str,
encoding): xml_request: Optional[ET.Element],
allowed_items: Iterable[Tuple[types.CollectionOrItem, str]],
user: str, encoding: str) -> Optional[ET.Element]:
"""Read and answer PROPFIND requests. """Read and answer PROPFIND requests.
Read rfc4918-9.1 for info. Read rfc4918-9.1 for info.
@ -43,7 +47,7 @@ def xml_propfind(base_prefix, path, xml_request, allowed_items, user,
top_element = (xml_request[0] if xml_request is not None else top_element = (xml_request[0] if xml_request is not None else
ET.Element(xmlutils.make_clark("D:allprop"))) ET.Element(xmlutils.make_clark("D:allprop")))
props = () props: List[str] = []
allprop = False allprop = False
propname = False propname = False
if top_element.tag == xmlutils.make_clark("D:allprop"): if top_element.tag == xmlutils.make_clark("D:allprop"):
@ -51,13 +55,13 @@ def xml_propfind(base_prefix, path, xml_request, allowed_items, user,
elif top_element.tag == xmlutils.make_clark("D:propname"): elif top_element.tag == xmlutils.make_clark("D:propname"):
propname = True propname = True
elif top_element.tag == xmlutils.make_clark("D:prop"): elif top_element.tag == xmlutils.make_clark("D:prop"):
props = [prop.tag for prop in top_element] props.extend(prop.tag for prop in top_element)
if xmlutils.make_clark("D:current-user-principal") in props and not user: if xmlutils.make_clark("D:current-user-principal") in props and not user:
# Ask for authentication # Ask for authentication
# Returning the DAV:unauthenticated pseudo-principal as specified in # Returning the DAV:unauthenticated pseudo-principal as specified in
# RFC 5397 doesn't seem to work with DAVx5. # RFC 5397 doesn't seem to work with DAVx5.
return client.FORBIDDEN, None return None
# Writing answer # Writing answer
multistatus = ET.Element(xmlutils.make_clark("D:multistatus")) multistatus = ET.Element(xmlutils.make_clark("D:multistatus"))
@ -68,29 +72,32 @@ def xml_propfind(base_prefix, path, xml_request, allowed_items, user,
base_prefix, path, item, props, user, encoding, write=write, base_prefix, path, item, props, user, encoding, write=write,
allprop=allprop, propname=propname)) allprop=allprop, propname=propname))
return client.MULTI_STATUS, multistatus return multistatus
def xml_propfind_response(base_prefix, path, item, props, user, encoding, def xml_propfind_response(
write=False, propname=False, allprop=False): base_prefix: str, path: str, item: types.CollectionOrItem,
props: Sequence[str], user: str, encoding: str, write: bool = False,
propname: bool = False, allprop: bool = False) -> ET.Element:
"""Build and return a PROPFIND response.""" """Build and return a PROPFIND response."""
if propname and allprop or (props and (propname or allprop)): if propname and allprop or (props and (propname or allprop)):
raise ValueError("Only use one of props, propname and allprops") raise ValueError("Only use one of props, propname and allprops")
is_collection = isinstance(item, storage.BaseCollection)
if is_collection:
is_leaf = item.get_meta("tag") in ("VADDRESSBOOK", "VCALENDAR")
collection = item
else:
collection = item.collection
response = ET.Element(xmlutils.make_clark("D:response")) if isinstance(item, storage.BaseCollection):
href = ET.Element(xmlutils.make_clark("D:href")) is_collection = True
if is_collection: is_leaf = item.tag in ("VADDRESSBOOK", "VCALENDAR")
# Some clients expect collections to end with / collection = item
# Some clients expect collections to end with `/`
uri = pathutils.unstrip_path(item.path, True) uri = pathutils.unstrip_path(item.path, True)
else: else:
uri = pathutils.unstrip_path( is_collection = is_leaf = False
posixpath.join(collection.path, item.href)) assert item.collection is not None
assert item.href
collection = item.collection
uri = pathutils.unstrip_path(posixpath.join(
collection.path, item.href))
response = ET.Element(xmlutils.make_clark("D:response"))
href = ET.Element(xmlutils.make_clark("D:href"))
href.text = xmlutils.make_href(base_prefix, uri) href.text = xmlutils.make_href(base_prefix, uri)
response.append(href) response.append(href)
@ -120,12 +127,12 @@ def xml_propfind_response(base_prefix, path, item, props, user, encoding,
if is_leaf: if is_leaf:
props.append(xmlutils.make_clark("D:displayname")) props.append(xmlutils.make_clark("D:displayname"))
props.append(xmlutils.make_clark("D:sync-token")) props.append(xmlutils.make_clark("D:sync-token"))
if collection.get_meta("tag") == "VCALENDAR": if collection.tag == "VCALENDAR":
props.append(xmlutils.make_clark("CS:getctag")) props.append(xmlutils.make_clark("CS:getctag"))
props.append( props.append(
xmlutils.make_clark("C:supported-calendar-component-set")) xmlutils.make_clark("C:supported-calendar-component-set"))
meta = item.get_meta() meta = collection.get_meta()
for tag in meta: for tag in meta:
if tag == "tag": if tag == "tag":
continue continue
@ -133,11 +140,11 @@ def xml_propfind_response(base_prefix, path, item, props, user, encoding,
if clark_tag not in props: if clark_tag not in props:
props.append(clark_tag) props.append(clark_tag)
responses = collections.defaultdict(list) responses: Dict[int, List[ET.Element]] = collections.defaultdict(list)
if propname: if propname:
for tag in props: for tag in props:
responses[200].append(ET.Element(tag)) responses[200].append(ET.Element(tag))
props = () props = []
for tag in props: for tag in props:
element = ET.Element(tag) element = ET.Element(tag)
is404 = False is404 = False
@ -159,18 +166,18 @@ def xml_propfind_response(base_prefix, path, item, props, user, encoding,
xmlutils.make_clark("D:principal-URL"), xmlutils.make_clark("D:principal-URL"),
xmlutils.make_clark("CR:addressbook-home-set"), xmlutils.make_clark("CR:addressbook-home-set"),
xmlutils.make_clark("C:calendar-home-set")) and xmlutils.make_clark("C:calendar-home-set")) and
collection.is_principal and is_collection): is_collection and collection.is_principal):
child_element = ET.Element(xmlutils.make_clark("D:href")) child_element = ET.Element(xmlutils.make_clark("D:href"))
child_element.text = xmlutils.make_href(base_prefix, path) child_element.text = xmlutils.make_href(base_prefix, path)
element.append(child_element) element.append(child_element)
elif tag == xmlutils.make_clark("C:supported-calendar-component-set"): elif tag == xmlutils.make_clark("C:supported-calendar-component-set"):
human_tag = xmlutils.make_human_tag(tag) human_tag = xmlutils.make_human_tag(tag)
if is_collection and is_leaf: if is_collection and is_leaf:
meta = item.get_meta(human_tag) components_text = collection.get_meta(human_tag)
if meta: if components_text:
components = meta.split(",") components = components_text.split(",")
else: else:
components = ("VTODO", "VEVENT", "VJOURNAL") components = ["VTODO", "VEVENT", "VJOURNAL"]
for component in components: for component in components:
comp = ET.Element(xmlutils.make_clark("C:comp")) comp = ET.Element(xmlutils.make_clark("C:comp"))
comp.set("name", component) comp.set("name", component)
@ -205,10 +212,10 @@ def xml_propfind_response(base_prefix, path, item, props, user, encoding,
"D:principal-property-search"] "D:principal-property-search"]
if is_collection and is_leaf: if is_collection and is_leaf:
reports.append("D:sync-collection") reports.append("D:sync-collection")
if item.get_meta("tag") == "VADDRESSBOOK": if collection.tag == "VADDRESSBOOK":
reports.append("CR:addressbook-multiget") reports.append("CR:addressbook-multiget")
reports.append("CR:addressbook-query") reports.append("CR:addressbook-query")
elif item.get_meta("tag") == "VCALENDAR": elif collection.tag == "VCALENDAR":
reports.append("C:calendar-multiget") reports.append("C:calendar-multiget")
reports.append("C:calendar-query") reports.append("C:calendar-query")
for human_tag in reports: for human_tag in reports:
@ -234,20 +241,21 @@ def xml_propfind_response(base_prefix, path, item, props, user, encoding,
elif is_collection: elif is_collection:
if tag == xmlutils.make_clark("D:getcontenttype"): if tag == xmlutils.make_clark("D:getcontenttype"):
if is_leaf: if is_leaf:
element.text = xmlutils.MIMETYPES[item.get_meta("tag")] element.text = xmlutils.MIMETYPES[
collection.tag]
else: else:
is404 = True is404 = True
elif tag == xmlutils.make_clark("D:resourcetype"): elif tag == xmlutils.make_clark("D:resourcetype"):
if item.is_principal: if collection.is_principal:
child_element = ET.Element( child_element = ET.Element(
xmlutils.make_clark("D:principal")) xmlutils.make_clark("D:principal"))
element.append(child_element) element.append(child_element)
if is_leaf: if is_leaf:
if item.get_meta("tag") == "VADDRESSBOOK": if collection.tag == "VADDRESSBOOK":
child_element = ET.Element( child_element = ET.Element(
xmlutils.make_clark("CR:addressbook")) xmlutils.make_clark("CR:addressbook"))
element.append(child_element) element.append(child_element)
elif item.get_meta("tag") == "VCALENDAR": elif collection.tag == "VCALENDAR":
child_element = ET.Element( child_element = ET.Element(
xmlutils.make_clark("C:calendar")) xmlutils.make_clark("C:calendar"))
element.append(child_element) element.append(child_element)
@ -255,38 +263,39 @@ def xml_propfind_response(base_prefix, path, item, props, user, encoding,
element.append(child_element) element.append(child_element)
elif tag == xmlutils.make_clark("RADICALE:displayname"): elif tag == xmlutils.make_clark("RADICALE:displayname"):
# Only for internal use by the web interface # Only for internal use by the web interface
displayname = item.get_meta("D:displayname") displayname = collection.get_meta("D:displayname")
if displayname is not None: if displayname is not None:
element.text = displayname element.text = displayname
else: else:
is404 = True is404 = True
elif tag == xmlutils.make_clark("D:displayname"): elif tag == xmlutils.make_clark("D:displayname"):
displayname = item.get_meta("D:displayname") displayname = collection.get_meta("D:displayname")
if not displayname and is_leaf: if not displayname and is_leaf:
displayname = item.path displayname = collection.path
if displayname is not None: if displayname is not None:
element.text = displayname element.text = displayname
else: else:
is404 = True is404 = True
elif tag == xmlutils.make_clark("CS:getctag"): elif tag == xmlutils.make_clark("CS:getctag"):
if is_leaf: if is_leaf:
element.text = item.etag element.text = collection.etag
else: else:
is404 = True is404 = True
elif tag == xmlutils.make_clark("D:sync-token"): elif tag == xmlutils.make_clark("D:sync-token"):
if is_leaf: if is_leaf:
element.text, _ = item.sync() element.text, _ = collection.sync()
else: else:
is404 = True is404 = True
else: else:
human_tag = xmlutils.make_human_tag(tag) human_tag = xmlutils.make_human_tag(tag)
meta = item.get_meta(human_tag) tag_text = collection.get_meta(human_tag)
if meta is not None: if tag_text is not None:
element.text = meta element.text = tag_text
else: else:
is404 = True is404 = True
# Not for collections # Not for collections
elif tag == xmlutils.make_clark("D:getcontenttype"): elif tag == xmlutils.make_clark("D:getcontenttype"):
assert not isinstance(item, storage.BaseCollection)
element.text = xmlutils.get_content_type(item, encoding) element.text = xmlutils.get_content_type(item, encoding)
elif tag == xmlutils.make_clark("D:resourcetype"): elif tag == xmlutils.make_clark("D:resourcetype"):
# resourcetype must be returned empty for non-collection elements # resourcetype must be returned empty for non-collection elements
@ -311,13 +320,16 @@ def xml_propfind_response(base_prefix, path, item, props, user, encoding,
return response return response
class ApplicationPropfindMixin: class ApplicationPartPropfind(ApplicationBase):
def _collect_allowed_items(self, items, user):
def _collect_allowed_items(
self, items: Iterable[types.CollectionOrItem], user: str
) -> Iterator[Tuple[types.CollectionOrItem, str]]:
"""Get items from request that user is allowed to access.""" """Get items from request that user is allowed to access."""
for item in items: for item in items:
if isinstance(item, storage.BaseCollection): if isinstance(item, storage.BaseCollection):
path = pathutils.unstrip_path(item.path, True) path = pathutils.unstrip_path(item.path, True)
if item.get_meta("tag"): if item.tag:
permissions = rights.intersect( permissions = rights.intersect(
self._rights.authorization(user, path), "rw") self._rights.authorization(user, path), "rw")
target = "collection with tag %r" % item.path target = "collection with tag %r" % item.path
@ -326,6 +338,7 @@ class ApplicationPropfindMixin:
self._rights.authorization(user, path), "RW") self._rights.authorization(user, path), "RW")
target = "collection %r" % item.path target = "collection %r" % item.path
else: else:
assert item.collection is not None
path = pathutils.unstrip_path(item.collection.path, True) path = pathutils.unstrip_path(item.collection.path, True)
permissions = rights.intersect( permissions = rights.intersect(
self._rights.authorization(user, path), "rw") self._rights.authorization(user, path), "rw")
@ -345,9 +358,10 @@ class ApplicationPropfindMixin:
if permission: if permission:
yield item, permission yield item, permission
def do_PROPFIND(self, environ, base_prefix, path, user): def do_PROPFIND(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage PROPFIND request.""" """Manage PROPFIND request."""
access = app.Access(self._rights, user, path) access = Access(self._rights, user, path)
if not access.check("r"): if not access.check("r"):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
try: try:
@ -360,22 +374,21 @@ class ApplicationPropfindMixin:
logger.debug("Client timed out", exc_info=True) logger.debug("Client timed out", exc_info=True)
return httputils.REQUEST_TIMEOUT return httputils.REQUEST_TIMEOUT
with self._storage.acquire_lock("r", user): with self._storage.acquire_lock("r", user):
items = self._storage.discover( items_iter = iter(self._storage.discover(
path, environ.get("HTTP_DEPTH", "0")) path, environ.get("HTTP_DEPTH", "0")))
# take root item for rights checking # take root item for rights checking
item = next(items, None) item = next(items_iter, None)
if not item: if not item:
return httputils.NOT_FOUND return httputils.NOT_FOUND
if not access.check("r", item): if not access.check("r", item):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
# put item back # put item back
items = itertools.chain([item], items) items_iter = itertools.chain([item], items_iter)
allowed_items = self._collect_allowed_items(items, user) allowed_items = self._collect_allowed_items(items_iter, user)
headers = {"DAV": httputils.DAV_HEADERS, headers = {"DAV": httputils.DAV_HEADERS,
"Content-Type": "text/xml; charset=%s" % self._encoding} "Content-Type": "text/xml; charset=%s" % self._encoding}
status, xml_answer = xml_propfind( xml_answer = xml_propfind(base_prefix, path, xml_content,
base_prefix, path, xml_content, allowed_items, user, allowed_items, user, self._encoding)
self._encoding) if xml_answer is None:
if status == client.FORBIDDEN and xml_answer is None:
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
return status, headers, self._xml_response(xml_answer) return client.MULTI_STATUS, headers, self._xml_response(xml_answer)

View File

@ -17,18 +17,20 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import socket import socket
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import client from http import client
from typing import Dict, Optional, cast
from radicale import app, httputils import radicale.item as radicale_item
from radicale import item as radicale_item from radicale import httputils, storage, types, xmlutils
from radicale import storage, xmlutils from radicale.app.base import Access, ApplicationBase
from radicale.log import logger from radicale.log import logger
def xml_proppatch(base_prefix, path, xml_request, collection): def xml_proppatch(base_prefix: str, path: str,
xml_request: Optional[ET.Element],
collection: storage.BaseCollection) -> ET.Element:
"""Read and answer PROPPATCH requests. """Read and answer PROPPATCH requests.
Read rfc4918-9.2 for info. Read rfc4918-9.2 for info.
@ -49,24 +51,24 @@ def xml_proppatch(base_prefix, path, xml_request, collection):
propstat.append(status) propstat.append(status)
response.append(propstat) response.append(propstat)
new_props = collection.get_meta() props_with_remove = xmlutils.props_from_request(xml_request)
for short_name, value in xmlutils.props_from_request(xml_request).items(): all_props_with_remove = cast(Dict[str, Optional[str]],
if value is None: dict(collection.get_meta()))
with contextlib.suppress(KeyError): all_props_with_remove.update(props_with_remove)
del new_props[short_name] all_props = radicale_item.check_and_sanitize_props(all_props_with_remove)
else: collection.set_meta(all_props)
new_props[short_name] = value for short_name in props_with_remove:
props_ok.append(ET.Element(xmlutils.make_clark(short_name))) props_ok.append(ET.Element(xmlutils.make_clark(short_name)))
radicale_item.check_and_sanitize_props(new_props)
collection.set_meta(new_props)
return multistatus return multistatus
class ApplicationProppatchMixin: class ApplicationPartProppatch(ApplicationBase):
def do_PROPPATCH(self, environ, base_prefix, path, user):
def do_PROPPATCH(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage PROPPATCH request.""" """Manage PROPPATCH request."""
access = app.Access(self._rights, user, path) access = Access(self._rights, user, path)
if not access.check("w"): if not access.check("w"):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
try: try:
@ -79,7 +81,7 @@ class ApplicationProppatchMixin:
logger.debug("Client timed out", exc_info=True) logger.debug("Client timed out", exc_info=True)
return httputils.REQUEST_TIMEOUT return httputils.REQUEST_TIMEOUT
with self._storage.acquire_lock("w", user): with self._storage.acquire_lock("w", user):
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
if not item: if not item:
return httputils.NOT_FOUND return httputils.NOT_FOUND
if not access.check("w", item): if not access.check("w", item):

View File

@ -22,20 +22,30 @@ import posixpath
import socket import socket
import sys import sys
from http import client from http import client
from types import TracebackType
from typing import Iterator, List, Mapping, MutableMapping, Optional, Tuple
import vobject import vobject
from radicale import app, httputils import radicale.item as radicale_item
from radicale import item as radicale_item from radicale import httputils, pathutils, rights, storage, types, xmlutils
from radicale import pathutils, rights, storage, xmlutils from radicale.app.base import Access, ApplicationBase
from radicale.log import logger from radicale.log import logger
MIMETYPE_TAGS = {value: key for key, value in xmlutils.MIMETYPES.items()} MIMETYPE_TAGS: Mapping[str, str] = {value: key for key, value in
xmlutils.MIMETYPES.items()}
def prepare(vobject_items, path, content_type, permissions, parent_permissions, def prepare(vobject_items: List[vobject.base.Component], path: str,
tag=None, write_whole_collection=None): content_type: str, permission: bool, parent_permission: bool,
if (write_whole_collection or permissions and not parent_permissions): tag: Optional[str] = None,
write_whole_collection: Optional[bool] = None) -> Tuple[
Iterator[radicale_item.Item], # items
Optional[str], # tag
Optional[bool], # write_whole_collection
Optional[MutableMapping[str, str]], # props
Optional[Tuple[type, BaseException, Optional[TracebackType]]]]:
if (write_whole_collection or permission and not parent_permission):
write_whole_collection = True write_whole_collection = True
tag = radicale_item.predict_tag_of_whole_collection( tag = radicale_item.predict_tag_of_whole_collection(
vobject_items, MIMETYPE_TAGS.get(content_type)) vobject_items, MIMETYPE_TAGS.get(content_type))
@ -43,20 +53,20 @@ def prepare(vobject_items, path, content_type, permissions, parent_permissions,
raise ValueError("Can't determine collection tag") raise ValueError("Can't determine collection tag")
collection_path = pathutils.strip_path(path) collection_path = pathutils.strip_path(path)
elif (write_whole_collection is not None and not write_whole_collection or elif (write_whole_collection is not None and not write_whole_collection or
not permissions and parent_permissions): not permission and parent_permission):
write_whole_collection = False write_whole_collection = False
if tag is None: if tag is None:
tag = radicale_item.predict_tag_of_parent_collection(vobject_items) tag = radicale_item.predict_tag_of_parent_collection(vobject_items)
collection_path = posixpath.dirname(pathutils.strip_path(path)) collection_path = posixpath.dirname(pathutils.strip_path(path))
props = None props: Optional[MutableMapping[str, str]] = None
stored_exc_info = None stored_exc_info = None
items = [] items = []
try: try:
if tag: if tag and write_whole_collection is not None:
radicale_item.check_and_sanitize_items( radicale_item.check_and_sanitize_items(
vobject_items, is_collection=write_whole_collection, tag=tag) vobject_items, is_collection=write_whole_collection, tag=tag)
if write_whole_collection and tag == "VCALENDAR": if write_whole_collection and tag == "VCALENDAR":
vobject_components = [] vobject_components: List[vobject.base.Component] = []
vobject_item, = vobject_items vobject_item, = vobject_items
for content in ("vevent", "vtodo", "vjournal"): for content in ("vevent", "vtodo", "vjournal"):
vobject_components.extend( vobject_components.extend(
@ -98,23 +108,25 @@ def prepare(vobject_items, path, content_type, permissions, parent_permissions,
caldesc = vobject_items[0].x_wr_caldesc.value caldesc = vobject_items[0].x_wr_caldesc.value
if caldesc: if caldesc:
props["C:calendar-description"] = caldesc props["C:calendar-description"] = caldesc
radicale_item.check_and_sanitize_props(props) props = radicale_item.check_and_sanitize_props(props)
except Exception: except Exception:
stored_exc_info = sys.exc_info() exc_info_or_none_tuple = sys.exc_info()
assert exc_info_or_none_tuple[0] is not None
stored_exc_info = exc_info_or_none_tuple
# Use generator for items and delete references to free memory # Use iterator for items and delete references to free memory early
# early def items_iter() -> Iterator[radicale_item.Item]:
def items_generator():
while items: while items:
yield items.pop(0) yield items.pop(0)
return (items_generator(), tag, write_whole_collection, props, return items_iter(), tag, write_whole_collection, props, stored_exc_info
stored_exc_info)
class ApplicationPutMixin: class ApplicationPartPut(ApplicationBase):
def do_PUT(self, environ, base_prefix, path, user):
def do_PUT(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage PUT request.""" """Manage PUT request."""
access = app.Access(self._rights, user, path) access = Access(self._rights, user, path)
if not access.check("w"): if not access.check("w"):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
try: try:
@ -126,9 +138,10 @@ class ApplicationPutMixin:
logger.debug("Client timed out", exc_info=True) logger.debug("Client timed out", exc_info=True)
return httputils.REQUEST_TIMEOUT return httputils.REQUEST_TIMEOUT
# Prepare before locking # Prepare before locking
content_type = environ.get("CONTENT_TYPE", "").split(";")[0] content_type = environ.get("CONTENT_TYPE", "").split(";",
maxsplit=1)[0]
try: try:
vobject_items = tuple(vobject.readComponents(content or "")) vobject_items = list(vobject.readComponents(content or ""))
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Bad PUT request on %r: %s", path, e, exc_info=True) "Bad PUT request on %r: %s", path, e, exc_info=True)
@ -140,20 +153,20 @@ class ApplicationPutMixin:
bool(rights.intersect(access.parent_permissions, "w"))) bool(rights.intersect(access.parent_permissions, "w")))
with self._storage.acquire_lock("w", user): with self._storage.acquire_lock("w", user):
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
parent_item = next( parent_item = next(iter(
self._storage.discover(access.parent_path), None) self._storage.discover(access.parent_path)), None)
if not parent_item: if not isinstance(parent_item, storage.BaseCollection):
return httputils.CONFLICT return httputils.CONFLICT
write_whole_collection = ( write_whole_collection = (
isinstance(item, storage.BaseCollection) or isinstance(item, storage.BaseCollection) or
not parent_item.get_meta("tag")) not parent_item.tag)
if write_whole_collection: if write_whole_collection:
tag = prepared_tag tag = prepared_tag
else: else:
tag = parent_item.get_meta("tag") tag = parent_item.tag
if write_whole_collection: if write_whole_collection:
if ("w" if tag else "W") not in access.permissions: if ("w" if tag else "W") not in access.permissions:
@ -198,6 +211,7 @@ class ApplicationPutMixin:
"Bad PUT request on %r: %s", path, e, exc_info=True) "Bad PUT request on %r: %s", path, e, exc_info=True)
return httputils.BAD_REQUEST return httputils.BAD_REQUEST
else: else:
assert not isinstance(item, storage.BaseCollection)
prepared_item, = prepared_items prepared_item, = prepared_items
if (item and item.uid != prepared_item.uid or if (item and item.uid != prepared_item.uid or
not item and parent_item.has_uid(prepared_item.uid)): not item and parent_item.has_uid(prepared_item.uid)):

View File

@ -22,15 +22,20 @@ import posixpath
import socket import socket
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import client from http import client
from typing import Callable, Iterable, Iterator, Optional, Sequence, Tuple
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
from radicale import app, httputils, pathutils, storage, xmlutils import radicale.item as radicale_item
from radicale import httputils, pathutils, storage, types, xmlutils
from radicale.app.base import Access, ApplicationBase
from radicale.item import filter as radicale_filter from radicale.item import filter as radicale_filter
from radicale.log import logger from radicale.log import logger
def xml_report(base_prefix, path, xml_request, collection, encoding, def xml_report(base_prefix: str, path: str, xml_request: Optional[ET.Element],
unlock_storage_fn): collection: storage.BaseCollection, encoding: str,
unlock_storage_fn: Callable[[], None]
) -> Tuple[int, ET.Element]:
"""Read and answer REPORT requests. """Read and answer REPORT requests.
Read rfc3253-3.6 for info. Read rfc3253-3.6 for info.
@ -40,10 +45,9 @@ def xml_report(base_prefix, path, xml_request, collection, encoding,
if xml_request is None: if xml_request is None:
return client.MULTI_STATUS, multistatus return client.MULTI_STATUS, multistatus
root = xml_request root = xml_request
if root.tag in ( if root.tag in (xmlutils.make_clark("D:principal-search-property-set"),
xmlutils.make_clark("D:principal-search-property-set"), xmlutils.make_clark("D:principal-property-search"),
xmlutils.make_clark("D:principal-property-search"), xmlutils.make_clark("D:expand-property")):
xmlutils.make_clark("D:expand-property")):
# We don't support searching for principals or indirect retrieving of # We don't support searching for principals or indirect retrieving of
# properties, just return an empty result. # properties, just return an empty result.
# InfCloud asks for expand-property reports (even if we don't announce # InfCloud asks for expand-property reports (even if we don't announce
@ -52,28 +56,28 @@ def xml_report(base_prefix, path, xml_request, collection, encoding,
xmlutils.make_human_tag(root.tag), path) xmlutils.make_human_tag(root.tag), path)
return client.MULTI_STATUS, multistatus return client.MULTI_STATUS, multistatus
if (root.tag == xmlutils.make_clark("C:calendar-multiget") and if (root.tag == xmlutils.make_clark("C:calendar-multiget") and
collection.get_meta("tag") != "VCALENDAR" or collection.tag != "VCALENDAR" or
root.tag == xmlutils.make_clark("CR:addressbook-multiget") and root.tag == xmlutils.make_clark("CR:addressbook-multiget") and
collection.get_meta("tag") != "VADDRESSBOOK" or collection.tag != "VADDRESSBOOK" or
root.tag == xmlutils.make_clark("D:sync-collection") and root.tag == xmlutils.make_clark("D:sync-collection") and
collection.get_meta("tag") not in ("VADDRESSBOOK", "VCALENDAR")): collection.tag not in ("VADDRESSBOOK", "VCALENDAR")):
logger.warning("Invalid REPORT method %r on %r requested", logger.warning("Invalid REPORT method %r on %r requested",
xmlutils.make_human_tag(root.tag), path) xmlutils.make_human_tag(root.tag), path)
return (client.FORBIDDEN, return client.FORBIDDEN, xmlutils.webdav_error("D:supported-report")
xmlutils.webdav_error("D:supported-report"))
prop_element = root.find(xmlutils.make_clark("D:prop")) prop_element = root.find(xmlutils.make_clark("D:prop"))
props = ( props = ([prop.tag for prop in prop_element]
[prop.tag for prop in prop_element] if prop_element is not None else [])
if prop_element is not None else [])
hreferences: Iterable[str]
if root.tag in ( if root.tag in (
xmlutils.make_clark("C:calendar-multiget"), xmlutils.make_clark("C:calendar-multiget"),
xmlutils.make_clark("CR:addressbook-multiget")): xmlutils.make_clark("CR:addressbook-multiget")):
# Read rfc4791-7.9 for info # Read rfc4791-7.9 for info
hreferences = set() hreferences = set()
for href_element in root.findall(xmlutils.make_clark("D:href")): for href_element in root.findall(xmlutils.make_clark("D:href")):
href_path = pathutils.sanitize_path( temp_url_path = urlparse(href_element.text).path
unquote(urlparse(href_element.text).path)) assert isinstance(temp_url_path, str)
href_path = pathutils.sanitize_path(unquote(temp_url_path))
if (href_path + "/").startswith(base_prefix + "/"): if (href_path + "/").startswith(base_prefix + "/"):
hreferences.add(href_path[len(base_prefix):]) hreferences.add(href_path[len(base_prefix):])
else: else:
@ -107,82 +111,13 @@ def xml_report(base_prefix, path, xml_request, collection, encoding,
root.findall(xmlutils.make_clark("C:filter")) + root.findall(xmlutils.make_clark("C:filter")) +
root.findall(xmlutils.make_clark("CR:filter"))) root.findall(xmlutils.make_clark("CR:filter")))
def retrieve_items(collection, hreferences, multistatus):
"""Retrieves all items that are referenced in ``hreferences`` from
``collection`` and adds 404 responses for missing and invalid items
to ``multistatus``."""
collection_requested = False
def get_names():
"""Extracts all names from references in ``hreferences`` and adds
404 responses for invalid references to ``multistatus``.
If the whole collections is referenced ``collection_requested``
gets set to ``True``."""
nonlocal collection_requested
for hreference in hreferences:
try:
name = pathutils.name_from_path(hreference, collection)
except ValueError as e:
logger.warning("Skipping invalid path %r in REPORT request"
" on %r: %s", hreference, path, e)
response = xml_item_response(base_prefix, hreference,
found_item=False)
multistatus.append(response)
continue
if name:
# Reference is an item
yield name
else:
# Reference is a collection
collection_requested = True
for name, item in collection.get_multi(get_names()):
if not item:
uri = pathutils.unstrip_path(
posixpath.join(collection.path, name))
response = xml_item_response(base_prefix, uri,
found_item=False)
multistatus.append(response)
else:
yield item, False
if collection_requested:
yield from collection.get_filtered(filters)
# Retrieve everything required for finishing the request. # Retrieve everything required for finishing the request.
retrieved_items = list(retrieve_items(collection, hreferences, retrieved_items = list(retrieve_items(
multistatus)) base_prefix, path, collection, hreferences, filters, multistatus))
collection_tag = collection.get_meta("tag") collection_tag = collection.tag
# Don't access storage after this! # !!! Don't access storage after this !!!
unlock_storage_fn() unlock_storage_fn()
def match(item, filter_):
tag = collection_tag
if (tag == "VCALENDAR" and
filter_.tag != xmlutils.make_clark("C:%s" % filter_)):
if len(filter_) == 0:
return True
if len(filter_) > 1:
raise ValueError("Filter with %d children" % len(filter_))
if filter_[0].tag != xmlutils.make_clark("C:comp-filter"):
raise ValueError("Unexpected %r in filter" % filter_[0].tag)
return radicale_filter.comp_match(item, filter_[0])
if (tag == "VADDRESSBOOK" and
filter_.tag != xmlutils.make_clark("CR:%s" % filter_)):
for child in filter_:
if child.tag != xmlutils.make_clark("CR:prop-filter"):
raise ValueError("Unexpected %r in filter" % child.tag)
test = filter_.get("test", "anyof")
if test == "anyof":
return any(
radicale_filter.prop_match(item.vobject_item, f, "CR")
for f in filter_)
if test == "allof":
return all(
radicale_filter.prop_match(item.vobject_item, f, "CR")
for f in filter_)
raise ValueError("Unsupported filter test: %r" % test)
raise ValueError("Unsupported filter %r for %r" % (filter_.tag, tag))
while retrieved_items: while retrieved_items:
# ``item.vobject_item`` might be accessed during filtering. # ``item.vobject_item`` might be accessed during filtering.
# Don't keep reference to ``item``, because VObject requires a lot of # Don't keep reference to ``item``, because VObject requires a lot of
@ -190,7 +125,8 @@ def xml_report(base_prefix, path, xml_request, collection, encoding,
item, filters_matched = retrieved_items.pop(0) item, filters_matched = retrieved_items.pop(0)
if filters and not filters_matched: if filters and not filters_matched:
try: try:
if not all(match(item, filter_) for filter_ in filters): if not all(test_filter(collection_tag, item, filter_)
for filter_ in filters):
continue continue
except ValueError as e: except ValueError as e:
raise ValueError("Failed to filter item %r from %r: %s" % raise ValueError("Failed to filter item %r from %r: %s" %
@ -218,6 +154,7 @@ def xml_report(base_prefix, path, xml_request, collection, encoding,
else: else:
not_found_props.append(element) not_found_props.append(element)
assert item.href
uri = pathutils.unstrip_path( uri = pathutils.unstrip_path(
posixpath.join(collection.path, item.href)) posixpath.join(collection.path, item.href))
multistatus.append(xml_item_response( multistatus.append(xml_item_response(
@ -227,8 +164,10 @@ def xml_report(base_prefix, path, xml_request, collection, encoding,
return client.MULTI_STATUS, multistatus return client.MULTI_STATUS, multistatus
def xml_item_response(base_prefix, href, found_props=(), not_found_props=(), def xml_item_response(base_prefix: str, href: str,
found_item=True): found_props: Sequence[ET.Element] = (),
not_found_props: Sequence[ET.Element] = (),
found_item: bool = True) -> ET.Element:
response = ET.Element(xmlutils.make_clark("D:response")) response = ET.Element(xmlutils.make_clark("D:response"))
href_element = ET.Element(xmlutils.make_clark("D:href")) href_element = ET.Element(xmlutils.make_clark("D:href"))
@ -255,24 +194,98 @@ def xml_item_response(base_prefix, href, found_props=(), not_found_props=(),
return response return response
class ApplicationReportMixin: def retrieve_items(
def do_REPORT(self, environ, base_prefix, path, user): base_prefix: str, path: str, collection: storage.BaseCollection,
hreferences: Iterable[str], filters: Sequence[ET.Element],
multistatus: ET.Element) -> Iterator[Tuple[radicale_item.Item, bool]]:
"""Retrieves all items that are referenced in ``hreferences`` from
``collection`` and adds 404 responses for missing and invalid items
to ``multistatus``."""
collection_requested = False
def get_names() -> Iterator[str]:
"""Extracts all names from references in ``hreferences`` and adds
404 responses for invalid references to ``multistatus``.
If the whole collections is referenced ``collection_requested``
gets set to ``True``."""
nonlocal collection_requested
for hreference in hreferences:
try:
name = pathutils.name_from_path(hreference, collection)
except ValueError as e:
logger.warning("Skipping invalid path %r in REPORT request on "
"%r: %s", hreference, path, e)
response = xml_item_response(base_prefix, hreference,
found_item=False)
multistatus.append(response)
continue
if name:
# Reference is an item
yield name
else:
# Reference is a collection
collection_requested = True
for name, item in collection.get_multi(get_names()):
if not item:
uri = pathutils.unstrip_path(posixpath.join(collection.path, name))
response = xml_item_response(base_prefix, uri, found_item=False)
multistatus.append(response)
else:
yield item, False
if collection_requested:
yield from collection.get_filtered(filters)
def test_filter(collection_tag: str, item: radicale_item.Item,
filter_: ET.Element) -> bool:
"""Match an item against a filter."""
if (collection_tag == "VCALENDAR" and
filter_.tag != xmlutils.make_clark("C:%s" % filter_)):
if len(filter_) == 0:
return True
if len(filter_) > 1:
raise ValueError("Filter with %d children" % len(filter_))
if filter_[0].tag != xmlutils.make_clark("C:comp-filter"):
raise ValueError("Unexpected %r in filter" % filter_[0].tag)
return radicale_filter.comp_match(item, filter_[0])
if (collection_tag == "VADDRESSBOOK" and
filter_.tag != xmlutils.make_clark("CR:%s" % filter_)):
for child in filter_:
if child.tag != xmlutils.make_clark("CR:prop-filter"):
raise ValueError("Unexpected %r in filter" % child.tag)
test = filter_.get("test", "anyof")
if test == "anyof":
return any(radicale_filter.prop_match(item.vobject_item, f, "CR")
for f in filter_)
if test == "allof":
return all(radicale_filter.prop_match(item.vobject_item, f, "CR")
for f in filter_)
raise ValueError("Unsupported filter test: %r" % test)
raise ValueError("Unsupported filter %r for %r" %
(filter_.tag, collection_tag))
class ApplicationPartReport(ApplicationBase):
def do_REPORT(self, environ: types.WSGIEnviron, base_prefix: str,
path: str, user: str) -> types.WSGIResponse:
"""Manage REPORT request.""" """Manage REPORT request."""
access = app.Access(self._rights, user, path) access = Access(self._rights, user, path)
if not access.check("r"): if not access.check("r"):
return httputils.NOT_ALLOWED return httputils.NOT_ALLOWED
try: try:
xml_content = self._read_xml_request_body(environ) xml_content = self._read_xml_request_body(environ)
except RuntimeError as e: except RuntimeError as e:
logger.warning( logger.warning("Bad REPORT request on %r: %s", path, e,
"Bad REPORT request on %r: %s", path, e, exc_info=True) exc_info=True)
return httputils.BAD_REQUEST return httputils.BAD_REQUEST
except socket.timeout: except socket.timeout:
logger.debug("Client timed out", exc_info=True) logger.debug("Client timed out", exc_info=True)
return httputils.REQUEST_TIMEOUT return httputils.REQUEST_TIMEOUT
with contextlib.ExitStack() as lock_stack: with contextlib.ExitStack() as lock_stack:
lock_stack.enter_context(self._storage.acquire_lock("r", user)) lock_stack.enter_context(self._storage.acquire_lock("r", user))
item = next(self._storage.discover(path), None) item = next(iter(self._storage.discover(path)), None)
if not item: if not item:
return httputils.NOT_FOUND return httputils.NOT_FOUND
if not access.check("r", item): if not access.check("r", item):
@ -280,8 +293,8 @@ class ApplicationReportMixin:
if isinstance(item, storage.BaseCollection): if isinstance(item, storage.BaseCollection):
collection = item collection = item
else: else:
assert item.collection is not None
collection = item.collection collection = item.collection
headers = {"Content-Type": "text/xml; charset=%s" % self._encoding}
try: try:
status, xml_answer = xml_report( status, xml_answer = xml_report(
base_prefix, path, xml_content, collection, self._encoding, base_prefix, path, xml_content, collection, self._encoding,
@ -290,4 +303,5 @@ class ApplicationReportMixin:
logger.warning( logger.warning(
"Bad REPORT request on %r: %s", path, e, exc_info=True) "Bad REPORT request on %r: %s", path, e, exc_info=True)
return httputils.BAD_REQUEST return httputils.BAD_REQUEST
return status, headers, self._xml_response(xml_answer) headers = {"Content-Type": "text/xml; charset=%s" % self._encoding}
return status, headers, self._xml_response(xml_answer)

View File

@ -28,18 +28,23 @@ Take a look at the class ``BaseAuth`` if you want to implement your own.
""" """
from radicale import utils from typing import Sequence, Tuple, Union
INTERNAL_TYPES = ("none", "remote_user", "http_x_remote_user", "htpasswd") from radicale import config, types, utils
INTERNAL_TYPES: Sequence[str] = ("none", "remote_user", "http_x_remote_user",
"htpasswd")
def load(configuration): def load(configuration: "config.Configuration") -> "BaseAuth":
"""Load the authentication module chosen in configuration.""" """Load the authentication module chosen in configuration."""
return utils.load_plugin(INTERNAL_TYPES, "auth", "Auth", configuration) return utils.load_plugin(INTERNAL_TYPES, "auth", "Auth", BaseAuth,
configuration)
class BaseAuth: class BaseAuth:
def __init__(self, configuration):
def __init__(self, configuration: "config.Configuration") -> None:
"""Initialize BaseAuth. """Initialize BaseAuth.
``configuration`` see ``radicale.config`` module. ``configuration`` see ``radicale.config`` module.
@ -49,7 +54,8 @@ class BaseAuth:
""" """
self.configuration = configuration self.configuration = configuration
def get_external_login(self, environ): def get_external_login(self, environ: types.WSGIEnviron) -> Union[
Tuple[()], Tuple[str, str]]:
"""Optionally provide the login and password externally. """Optionally provide the login and password externally.
``environ`` a dict with the WSGI environment ``environ`` a dict with the WSGI environment
@ -61,7 +67,7 @@ class BaseAuth:
""" """
return () return ()
def login(self, login, password): def login(self, login: str, password: str) -> str:
"""Check credentials and map login to internal user """Check credentials and map login to internal user
``login`` the login name ``login`` the login name

View File

@ -49,18 +49,23 @@ When passlib[bcrypt] is installed:
import functools import functools
import hmac import hmac
from typing import Any
from passlib.hash import apr_md5_crypt from passlib.hash import apr_md5_crypt
from radicale import auth from radicale import auth, config
class Auth(auth.BaseAuth): class Auth(auth.BaseAuth):
def __init__(self, configuration):
_filename: str
_encoding: str
def __init__(self, configuration: config.Configuration) -> None:
super().__init__(configuration) super().__init__(configuration)
self._filename = configuration.get("auth", "htpasswd_filename") self._filename = configuration.get("auth", "htpasswd_filename")
self._encoding = self.configuration.get("encoding", "stock") self._encoding = configuration.get("encoding", "stock")
encryption = configuration.get("auth", "htpasswd_encryption") encryption: str = configuration.get("auth", "htpasswd_encryption")
if encryption == "plain": if encryption == "plain":
self._verify = self._plain self._verify = self._plain
@ -82,17 +87,17 @@ class Auth(auth.BaseAuth):
raise RuntimeError("The htpasswd encryption method %r is not " raise RuntimeError("The htpasswd encryption method %r is not "
"supported." % encryption) "supported." % encryption)
def _plain(self, hash_value, password): def _plain(self, hash_value: str, password: str) -> bool:
"""Check if ``hash_value`` and ``password`` match, plain method.""" """Check if ``hash_value`` and ``password`` match, plain method."""
return hmac.compare_digest(hash_value.encode(), password.encode()) return hmac.compare_digest(hash_value.encode(), password.encode())
def _bcrypt(self, bcrypt, hash_value, password): def _bcrypt(self, bcrypt: Any, hash_value: str, password: str) -> bool:
return bcrypt.verify(password, hash_value.strip()) return bcrypt.verify(password, hash_value.strip())
def _md5apr1(self, hash_value, password): def _md5apr1(self, hash_value: str, password: str) -> bool:
return apr_md5_crypt.verify(password, hash_value.strip()) return apr_md5_crypt.verify(password, hash_value.strip())
def login(self, login, password): def login(self, login: str, password: str) -> str:
"""Validate credentials. """Validate credentials.
Iterate through htpasswd credential file until login matches, extract Iterate through htpasswd credential file until login matches, extract

View File

@ -26,9 +26,14 @@ if the reverse proxy is not configured properly.
""" """
import radicale.auth.none as none from typing import Tuple, Union
from radicale import types
from radicale.auth import none
class Auth(none.Auth): class Auth(none.Auth):
def get_external_login(self, environ):
def get_external_login(self, environ: types.WSGIEnviron) -> Union[
Tuple[()], Tuple[str, str]]:
return environ.get("HTTP_X_REMOTE_USER", ""), "" return environ.get("HTTP_X_REMOTE_USER", ""), ""

View File

@ -26,5 +26,6 @@ from radicale import auth
class Auth(auth.BaseAuth): class Auth(auth.BaseAuth):
def login(self, login, password):
def login(self, login: str, password: str) -> str:
return login return login

View File

@ -25,9 +25,14 @@ It's intended for use with an external WSGI server.
""" """
import radicale.auth.none as none from typing import Tuple, Union
from radicale import types
from radicale.auth import none
class Auth(none.Auth): class Auth(none.Auth):
def get_external_login(self, environ):
def get_external_login(self, environ: types.WSGIEnviron
) -> Union[Tuple[()], Tuple[str, str]]:
return environ.get("REMOTE_USER", ""), "" return environ.get("REMOTE_USER", ""), ""

View File

@ -29,25 +29,27 @@ import contextlib
import math import math
import os import os
import string import string
import sys
from collections import OrderedDict from collections import OrderedDict
from configparser import RawConfigParser from configparser import RawConfigParser
from typing import Any, ClassVar from typing import (Any, Callable, ClassVar, Iterable, List, Optional,
Sequence, Tuple, TypeVar, Union)
from radicale import auth, rights, storage, web from radicale import auth, rights, storage, types, web
DEFAULT_CONFIG_PATH = os.pathsep.join([ DEFAULT_CONFIG_PATH: str = os.pathsep.join([
"?/etc/radicale/config", "?/etc/radicale/config",
"?~/.config/radicale/config"]) "?~/.config/radicale/config"])
def positive_int(value): def positive_int(value: Any) -> int:
value = int(value) value = int(value)
if value < 0: if value < 0:
raise ValueError("value is negative: %d" % value) raise ValueError("value is negative: %d" % value)
return value return value
def positive_float(value): def positive_float(value: Any) -> float:
value = float(value) value = float(value)
if not math.isfinite(value): if not math.isfinite(value):
raise ValueError("value is infinite") raise ValueError("value is infinite")
@ -58,22 +60,22 @@ def positive_float(value):
return value return value
def logging_level(value): def logging_level(value: Any) -> str:
if value not in ("debug", "info", "warning", "error", "critical"): if value not in ("debug", "info", "warning", "error", "critical"):
raise ValueError("unsupported level: %r" % value) raise ValueError("unsupported level: %r" % value)
return value return value
def filepath(value): def filepath(value: Any) -> str:
if not value: if not value:
return "" return ""
value = os.path.expanduser(value) value = os.path.expanduser(value)
if os.name == "nt": if sys.platform == "win32":
value = os.path.expandvars(value) value = os.path.expandvars(value)
return os.path.abspath(value) return os.path.abspath(value)
def list_of_ip_address(value): def list_of_ip_address(value: Any) -> List[Tuple[str, int]]:
def ip_address(value): def ip_address(value):
try: try:
address, port = value.rsplit(":", 1) address, port = value.rsplit(":", 1)
@ -83,25 +85,25 @@ def list_of_ip_address(value):
return [ip_address(s) for s in value.split(",")] return [ip_address(s) for s in value.split(",")]
def str_or_callable(value): def str_or_callable(value: Any) -> Union[str, Callable]:
if callable(value): if callable(value):
return value return value
return str(value) return str(value)
def unspecified_type(value): def unspecified_type(value: Any) -> Any:
return value return value
def _convert_to_bool(value): def _convert_to_bool(value: Any) -> bool:
if value.lower() not in RawConfigParser.BOOLEAN_STATES: if value.lower() not in RawConfigParser.BOOLEAN_STATES:
raise ValueError("not a boolean: %r" % value) raise ValueError("not a boolean: %r" % value)
return RawConfigParser.BOOLEAN_STATES[value.lower()] return RawConfigParser.BOOLEAN_STATES[value.lower()]
INTERNAL_OPTIONS = ("_allow_extra",) INTERNAL_OPTIONS: Sequence[str] = ("_allow_extra",)
# Default configuration # Default configuration
DEFAULT_CONFIG_SCHEMA = OrderedDict([ DEFAULT_CONFIG_SCHEMA: types.CONFIG_SCHEMA = OrderedDict([
("server", OrderedDict([ ("server", OrderedDict([
("hosts", { ("hosts", {
"value": "localhost:5232", "value": "localhost:5232",
@ -227,7 +229,8 @@ DEFAULT_CONFIG_SCHEMA = OrderedDict([
("_allow_extra", str)]))]) ("_allow_extra", str)]))])
def parse_compound_paths(*compound_paths): def parse_compound_paths(*compound_paths: Optional[str]
) -> List[Tuple[str, bool]]:
"""Parse a compound path and return the individual paths. """Parse a compound path and return the individual paths.
Paths in a compound path are joined by ``os.pathsep``. If a path starts Paths in a compound path are joined by ``os.pathsep``. If a path starts
with ``?`` the return value ``IGNORE_IF_MISSING`` is set. with ``?`` the return value ``IGNORE_IF_MISSING`` is set.
@ -253,7 +256,8 @@ def parse_compound_paths(*compound_paths):
return paths return paths
def load(paths=()): def load(paths: Optional[Iterable[Tuple[str, bool]]] = None
) -> "Configuration":
""" """
Create instance of ``Configuration`` for use with Create instance of ``Configuration`` for use with
``radicale.app.Application``. ``radicale.app.Application``.
@ -266,6 +270,8 @@ def load(paths=()):
The configuration can later be changed with ``Configuration.update()``. The configuration can later be changed with ``Configuration.update()``.
""" """
if paths is None:
paths = []
configuration = Configuration(DEFAULT_CONFIG_SCHEMA) configuration = Configuration(DEFAULT_CONFIG_SCHEMA)
for path, ignore_if_missing in paths: for path, ignore_if_missing in paths:
parser = RawConfigParser() parser = RawConfigParser()
@ -279,16 +285,24 @@ def load(paths=()):
config = {s: {o: parser[s][o] for o in parser.options(s)} config = {s: {o: parser[s][o] for o in parser.options(s)}
for s in parser.sections()} for s in parser.sections()}
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError("Failed to load %s: %s" % (config_source, e)
"Failed to load %s: %s" % (config_source, e)) from e ) from e
configuration.update(config, config_source) configuration.update(config, config_source)
return configuration return configuration
class Configuration: _Self = TypeVar("_Self", bound="Configuration")
SOURCE_MISSING: ClassVar[Any] = {}
def __init__(self, schema):
class Configuration:
SOURCE_MISSING: ClassVar[types.CONFIG] = {}
_schema: types.CONFIG_SCHEMA
_values: types.MUTABLE_CONFIG
_configs: List[Tuple[types.CONFIG, str, bool]]
def __init__(self, schema: types.CONFIG_SCHEMA) -> None:
"""Initialize configuration. """Initialize configuration.
``schema`` a dict that describes the configuration format. ``schema`` a dict that describes the configuration format.
@ -309,7 +323,8 @@ class Configuration:
for section in self._schema} for section in self._schema}
self.update(default, "default config", privileged=True) self.update(default, "default config", privileged=True)
def update(self, config, source=None, privileged=False): def update(self, config: types.CONFIG, source: Optional[str] = None,
privileged: bool = False) -> None:
"""Update the configuration. """Update the configuration.
``config`` a dict of the format {SECTION: {OPTION: VALUE, ...}, ...}. ``config`` a dict of the format {SECTION: {OPTION: VALUE, ...}, ...}.
@ -323,8 +338,9 @@ class Configuration:
``privileged`` allows updating sections and options starting with "_". ``privileged`` allows updating sections and options starting with "_".
""" """
source = source or "unspecified config" if source is None:
new_values = {} source = "unspecified config"
new_values: types.MUTABLE_CONFIG = {}
for section in config: for section in config:
if (section not in self._schema or if (section not in self._schema or
section.startswith("_") and not privileged): section.startswith("_") and not privileged):
@ -363,40 +379,41 @@ class Configuration:
self._values[section] = self._values.get(section, {}) self._values[section] = self._values.get(section, {})
self._values[section].update(new_values[section]) self._values[section].update(new_values[section])
def get(self, section, option): def get(self, section: str, option: str) -> Any:
"""Get the value of ``option`` in ``section``.""" """Get the value of ``option`` in ``section``."""
with contextlib.suppress(KeyError): with contextlib.suppress(KeyError):
return self._values[section][option] return self._values[section][option]
raise KeyError(section, option) raise KeyError(section, option)
def get_raw(self, section, option): def get_raw(self, section: str, option: str) -> Any:
"""Get the raw value of ``option`` in ``section``.""" """Get the raw value of ``option`` in ``section``."""
for config, _, _ in reversed(self._configs): for config, _, _ in reversed(self._configs):
if option in config.get(section, {}): if option in config.get(section, {}):
return config[section][option] return config[section][option]
raise KeyError(section, option) raise KeyError(section, option)
def get_source(self, section, option): def get_source(self, section: str, option: str) -> str:
"""Get the source that provides ``option`` in ``section``.""" """Get the source that provides ``option`` in ``section``."""
for config, source, _ in reversed(self._configs): for config, source, _ in reversed(self._configs):
if option in config.get(section, {}): if option in config.get(section, {}):
return source return source
raise KeyError(section, option) raise KeyError(section, option)
def sections(self): def sections(self) -> List[str]:
"""List all sections.""" """List all sections."""
return self._values.keys() return list(self._values.keys())
def options(self, section): def options(self, section: str) -> List[str]:
"""List all options in ``section``""" """List all options in ``section``"""
return self._values[section].keys() return list(self._values[section].keys())
def sources(self): def sources(self) -> List[Tuple[str, bool]]:
"""List all config sources.""" """List all config sources."""
return [(source, config is self.SOURCE_MISSING) for return [(source, config is self.SOURCE_MISSING) for
config, source, _ in self._configs] config, source, _ in self._configs]
def copy(self, plugin_schema=None): def copy(self: _Self, plugin_schema: Optional[types.CONFIG_SCHEMA] = None
) -> _Self:
"""Create a copy of the configuration """Create a copy of the configuration
``plugin_schema`` is a optional dict that contains additional options ``plugin_schema`` is a optional dict that contains additional options
@ -406,20 +423,23 @@ class Configuration:
if plugin_schema is None: if plugin_schema is None:
schema = self._schema schema = self._schema
else: else:
schema = self._schema.copy() new_schema = dict(self._schema)
for section, options in plugin_schema.items(): for section, options in plugin_schema.items():
if (section not in schema or "type" not in schema[section] or if (section not in new_schema or
"internal" not in schema[section]["type"]): "type" not in new_schema[section] or
"internal" not in new_schema[section]["type"]):
raise ValueError("not a plugin section: %r" % section) raise ValueError("not a plugin section: %r" % section)
schema[section] = schema[section].copy() new_section = dict(new_schema[section])
schema[section]["type"] = schema[section]["type"].copy() new_type = dict(new_section["type"])
schema[section]["type"]["internal"] = [ new_type["internal"] = (self.get(section, "type"),)
self.get(section, "type")] new_section["type"] = new_type
for option, value in options.items(): for option, value in options.items():
if option in schema[section]: if option in new_section:
raise ValueError("option already exists in %r: %r" % ( raise ValueError("option already exists in %r: %r" %
section, option)) (section, option))
schema[section][option] = value new_section[option] = value
new_schema[section] = new_section
schema = new_schema
copy = type(self)(schema) copy = type(self)(schema)
for config, source, privileged in self._configs: for config, source, privileged in self._configs:
copy.update(config, source, privileged) copy.update(config, source, privileged)

View File

@ -22,53 +22,57 @@ Helper functions for HTTP.
""" """
import contextlib
from http import client from http import client
from typing import List, cast
from radicale import config, types
from radicale.log import logger from radicale.log import logger
NOT_ALLOWED = ( NOT_ALLOWED: types.WSGIResponse = (
client.FORBIDDEN, (("Content-Type", "text/plain"),), client.FORBIDDEN, (("Content-Type", "text/plain"),),
"Access to the requested resource forbidden.") "Access to the requested resource forbidden.")
FORBIDDEN = ( FORBIDDEN: types.WSGIResponse = (
client.FORBIDDEN, (("Content-Type", "text/plain"),), client.FORBIDDEN, (("Content-Type", "text/plain"),),
"Action on the requested resource refused.") "Action on the requested resource refused.")
BAD_REQUEST = ( BAD_REQUEST: types.WSGIResponse = (
client.BAD_REQUEST, (("Content-Type", "text/plain"),), "Bad Request") client.BAD_REQUEST, (("Content-Type", "text/plain"),), "Bad Request")
NOT_FOUND = ( NOT_FOUND: types.WSGIResponse = (
client.NOT_FOUND, (("Content-Type", "text/plain"),), client.NOT_FOUND, (("Content-Type", "text/plain"),),
"The requested resource could not be found.") "The requested resource could not be found.")
CONFLICT = ( CONFLICT: types.WSGIResponse = (
client.CONFLICT, (("Content-Type", "text/plain"),), client.CONFLICT, (("Content-Type", "text/plain"),),
"Conflict in the request.") "Conflict in the request.")
METHOD_NOT_ALLOWED = ( METHOD_NOT_ALLOWED: types.WSGIResponse = (
client.METHOD_NOT_ALLOWED, (("Content-Type", "text/plain"),), client.METHOD_NOT_ALLOWED, (("Content-Type", "text/plain"),),
"The method is not allowed on the requested resource.") "The method is not allowed on the requested resource.")
PRECONDITION_FAILED = ( PRECONDITION_FAILED: types.WSGIResponse = (
client.PRECONDITION_FAILED, client.PRECONDITION_FAILED,
(("Content-Type", "text/plain"),), "Precondition failed.") (("Content-Type", "text/plain"),), "Precondition failed.")
REQUEST_TIMEOUT = ( REQUEST_TIMEOUT: types.WSGIResponse = (
client.REQUEST_TIMEOUT, (("Content-Type", "text/plain"),), client.REQUEST_TIMEOUT, (("Content-Type", "text/plain"),),
"Connection timed out.") "Connection timed out.")
REQUEST_ENTITY_TOO_LARGE = ( REQUEST_ENTITY_TOO_LARGE: types.WSGIResponse = (
client.REQUEST_ENTITY_TOO_LARGE, (("Content-Type", "text/plain"),), client.REQUEST_ENTITY_TOO_LARGE, (("Content-Type", "text/plain"),),
"Request body too large.") "Request body too large.")
REMOTE_DESTINATION = ( REMOTE_DESTINATION: types.WSGIResponse = (
client.BAD_GATEWAY, (("Content-Type", "text/plain"),), client.BAD_GATEWAY, (("Content-Type", "text/plain"),),
"Remote destination not supported.") "Remote destination not supported.")
DIRECTORY_LISTING = ( DIRECTORY_LISTING: types.WSGIResponse = (
client.FORBIDDEN, (("Content-Type", "text/plain"),), client.FORBIDDEN, (("Content-Type", "text/plain"),),
"Directory listings are not supported.") "Directory listings are not supported.")
INTERNAL_SERVER_ERROR = ( INTERNAL_SERVER_ERROR: types.WSGIResponse = (
client.INTERNAL_SERVER_ERROR, (("Content-Type", "text/plain"),), client.INTERNAL_SERVER_ERROR, (("Content-Type", "text/plain"),),
"A server error occurred. Please contact the administrator.") "A server error occurred. Please contact the administrator.")
DAV_HEADERS = "1, 2, 3, calendar-access, addressbook, extended-mkcol" DAV_HEADERS: str = "1, 2, 3, calendar-access, addressbook, extended-mkcol"
def decode_request(configuration, environ, text): def decode_request(configuration: "config.Configuration",
environ: types.WSGIEnviron, text: bytes) -> str:
"""Try to magically decode ``text`` according to given ``environ``.""" """Try to magically decode ``text`` according to given ``environ``."""
# List of charsets to try # List of charsets to try
charsets = [] charsets: List[str] = []
# First append content charset given in the request # First append content charset given in the request
content_type = environ.get("CONTENT_TYPE") content_type = environ.get("CONTENT_TYPE")
@ -76,7 +80,7 @@ def decode_request(configuration, environ, text):
charsets.append( charsets.append(
content_type.split("charset=")[1].split(";")[0].strip()) content_type.split("charset=")[1].split(";")[0].strip())
# Then append default Radicale charset # Then append default Radicale charset
charsets.append(configuration.get("encoding", "request")) charsets.append(cast(str, configuration.get("encoding", "request")))
# Then append various fallbacks # Then append various fallbacks
charsets.append("utf-8") charsets.append("utf-8")
charsets.append("iso8859-1") charsets.append("iso8859-1")
@ -87,15 +91,14 @@ def decode_request(configuration, environ, text):
# Try to decode # Try to decode
for charset in charsets: for charset in charsets:
try: with contextlib.suppress(UnicodeDecodeError):
return text.decode(charset) return text.decode(charset)
except UnicodeDecodeError:
pass
raise UnicodeDecodeError("decode_request", text, 0, len(text), raise UnicodeDecodeError("decode_request", text, 0, len(text),
"all codecs failed [%s]" % ", ".join(charsets)) "all codecs failed [%s]" % ", ".join(charsets))
def read_raw_request_body(configuration, environ): def read_raw_request_body(configuration: "config.Configuration",
environ: types.WSGIEnviron) -> bytes:
content_length = int(environ.get("CONTENT_LENGTH") or 0) content_length = int(environ.get("CONTENT_LENGTH") or 0)
if not content_length: if not content_length:
return b"" return b""
@ -105,8 +108,9 @@ def read_raw_request_body(configuration, environ):
return content return content
def read_request_body(configuration, environ): def read_request_body(configuration: "config.Configuration",
content = decode_request( environ: types.WSGIEnviron) -> str:
configuration, environ, read_raw_request_body(configuration, environ)) content = decode_request(configuration, environ,
read_raw_request_body(configuration, environ))
logger.debug("Request content:\n%s", content) logger.debug("Request content:\n%s", content)
return content return content

View File

@ -27,27 +27,35 @@ import binascii
import math import math
import os import os
import sys import sys
from datetime import timedelta from datetime import datetime, timedelta
from hashlib import sha256 from hashlib import sha256
from typing import (Any, Callable, List, MutableMapping, Optional, Sequence,
Tuple)
import vobject import vobject
from radicale import storage # noqa:F401
from radicale import pathutils from radicale import pathutils
from radicale.item import filter as radicale_filter from radicale.item import filter as radicale_filter
from radicale.log import logger from radicale.log import logger
def predict_tag_of_parent_collection(vobject_items): def predict_tag_of_parent_collection(
vobject_items: Sequence[vobject.base.Component]) -> Optional[str]:
"""Returns the predicted tag or `None`"""
if len(vobject_items) != 1: if len(vobject_items) != 1:
return "" return None
if vobject_items[0].name == "VCALENDAR": if vobject_items[0].name == "VCALENDAR":
return "VCALENDAR" return "VCALENDAR"
if vobject_items[0].name in ("VCARD", "VLIST"): if vobject_items[0].name in ("VCARD", "VLIST"):
return "VADDRESSBOOK" return "VADDRESSBOOK"
return "" return None
def predict_tag_of_whole_collection(vobject_items, fallback_tag=None): def predict_tag_of_whole_collection(
vobject_items: Sequence[vobject.base.Component],
fallback_tag: Optional[str] = None) -> Optional[str]:
"""Returns the predicted tag or `fallback_tag`"""
if vobject_items and vobject_items[0].name == "VCALENDAR": if vobject_items and vobject_items[0].name == "VCALENDAR":
return "VCALENDAR" return "VCALENDAR"
if vobject_items and vobject_items[0].name in ("VCARD", "VLIST"): if vobject_items and vobject_items[0].name in ("VCARD", "VLIST"):
@ -58,9 +66,13 @@ def predict_tag_of_whole_collection(vobject_items, fallback_tag=None):
return fallback_tag return fallback_tag
def check_and_sanitize_items(vobject_items, is_collection=False, tag=None): def check_and_sanitize_items(
vobject_items: List[vobject.base.Component],
is_collection: bool = False, tag: str = "") -> None:
"""Check vobject items for common errors and add missing UIDs. """Check vobject items for common errors and add missing UIDs.
Modifies the list `vobject_items`.
``is_collection`` indicates that vobject_item contains unrelated ``is_collection`` indicates that vobject_item contains unrelated
components. components.
@ -169,9 +181,14 @@ def check_and_sanitize_items(vobject_items, is_collection=False, tag=None):
(i.name, repr(tag) if tag else "generic")) (i.name, repr(tag) if tag else "generic"))
def check_and_sanitize_props(props): def check_and_sanitize_props(props: MutableMapping[Any, Any]
"""Check collection properties for common errors.""" ) -> MutableMapping[str, str]:
for k, v in props.copy().items(): # Make copy to be able to delete items """Check collection properties for common errors.
Modifies the dict `props`.
"""
for k, v in list(props.items()): # Make copy to be able to delete items
if not isinstance(k, str): if not isinstance(k, str):
raise ValueError("Key must be %r not %r: %r" % ( raise ValueError("Key must be %r not %r: %r" % (
str.__name__, type(k).__name__, k)) str.__name__, type(k).__name__, k))
@ -182,14 +199,13 @@ def check_and_sanitize_props(props):
raise ValueError("Value of %r must be %r not %r: %r" % ( raise ValueError("Value of %r must be %r not %r: %r" % (
k, str.__name__, type(v).__name__, v)) k, str.__name__, type(v).__name__, v))
if k == "tag": if k == "tag":
if not v: if v not in ("", "VCALENDAR", "VADDRESSBOOK"):
del props[k]
continue
if v not in ("VCALENDAR", "VADDRESSBOOK"):
raise ValueError("Unsupported collection tag: %r" % v) raise ValueError("Unsupported collection tag: %r" % v)
return props
def find_available_uid(exists_fn, suffix=""): def find_available_uid(exists_fn: Callable[[str], bool], suffix: str = ""
) -> str:
"""Generate a pseudo-random UID""" """Generate a pseudo-random UID"""
# Prevent infinite loop # Prevent infinite loop
for _ in range(1000): for _ in range(1000):
@ -202,7 +218,7 @@ def find_available_uid(exists_fn, suffix=""):
raise RuntimeError("No unique random sequence found") raise RuntimeError("No unique random sequence found")
def get_etag(text): def get_etag(text: str) -> str:
"""Etag from collection or item. """Etag from collection or item.
Encoded as quoted-string (see RFC 2616). Encoded as quoted-string (see RFC 2616).
@ -213,13 +229,13 @@ def get_etag(text):
return '"%s"' % etag.hexdigest() return '"%s"' % etag.hexdigest()
def get_uid(vobject_component): def get_uid(vobject_component: vobject.base.Component) -> str:
"""UID value of an item if defined.""" """UID value of an item if defined."""
return (vobject_component.uid.value return (vobject_component.uid.value or ""
if hasattr(vobject_component, "uid") else None) if hasattr(vobject_component, "uid") else "")
def get_uid_from_object(vobject_item): def get_uid_from_object(vobject_item: vobject.base.Component) -> str:
"""UID value of an calendar/addressbook object.""" """UID value of an calendar/addressbook object."""
if vobject_item.name == "VCALENDAR": if vobject_item.name == "VCALENDAR":
if hasattr(vobject_item, "vevent"): if hasattr(vobject_item, "vevent"):
@ -230,10 +246,10 @@ def get_uid_from_object(vobject_item):
return get_uid(vobject_item.vtodo) return get_uid(vobject_item.vtodo)
elif vobject_item.name == "VCARD": elif vobject_item.name == "VCARD":
return get_uid(vobject_item) return get_uid(vobject_item)
return None return ""
def find_tag(vobject_item): def find_tag(vobject_item: vobject.base.Component) -> str:
"""Find component name from ``vobject_item``.""" """Find component name from ``vobject_item``."""
if vobject_item.name == "VCALENDAR": if vobject_item.name == "VCALENDAR":
for component in vobject_item.components(): for component in vobject_item.components():
@ -242,22 +258,24 @@ def find_tag(vobject_item):
return "" return ""
def find_tag_and_time_range(vobject_item): def find_time_range(vobject_item: vobject.base.Component, tag: str
"""Find component name and enclosing time range from ``vobject item``. ) -> Tuple[int, int]:
"""Find enclosing time range from ``vobject item``.
Returns a tuple (``tag``, ``start``, ``end``) where ``tag`` is a string ``tag`` must be set to the return value of ``find_tag``.
and ``start`` and ``end`` are POSIX timestamps (as int).
Returns a tuple (``start``, ``end``) where ``start`` and ``end`` are
POSIX timestamps.
This is intened to be used for matching against simplified prefilters. This is intened to be used for matching against simplified prefilters.
""" """
tag = find_tag(vobject_item)
if not tag: if not tag:
return ( return radicale_filter.TIMESTAMP_MIN, radicale_filter.TIMESTAMP_MAX
tag, radicale_filter.TIMESTAMP_MIN, radicale_filter.TIMESTAMP_MAX)
start = end = None start = end = None
def range_fn(range_start, range_end, is_recurrence): def range_fn(range_start: datetime, range_end: datetime,
is_recurrence: bool) -> bool:
nonlocal start, end nonlocal start, end
if start is None or range_start < start: if start is None or range_start < start:
start = range_start start = range_start
@ -265,7 +283,7 @@ def find_tag_and_time_range(vobject_item):
end = range_end end = range_end
return False return False
def infinity_fn(range_start): def infinity_fn(range_start: datetime) -> bool:
nonlocal start, end nonlocal start, end
if start is None or range_start < start: if start is None or range_start < start:
start = range_start start = range_start
@ -278,7 +296,7 @@ def find_tag_and_time_range(vobject_item):
if end is None: if end is None:
end = radicale_filter.DATETIME_MAX end = radicale_filter.DATETIME_MAX
try: try:
return tag, math.floor(start.timestamp()), math.ceil(end.timestamp()) return math.floor(start.timestamp()), math.ceil(end.timestamp())
except ValueError as e: except ValueError as e:
if str(e) == ("offset must be a timedelta representing a whole " if str(e) == ("offset must be a timedelta representing a whole "
"number of minutes") and sys.version_info < (3, 6): "number of minutes") and sys.version_info < (3, 6):
@ -289,10 +307,31 @@ def find_tag_and_time_range(vobject_item):
class Item: class Item:
"""Class for address book and calendar entries.""" """Class for address book and calendar entries."""
def __init__(self, collection_path=None, collection=None, collection: Optional["storage.BaseCollection"]
vobject_item=None, href=None, last_modified=None, text=None, href: Optional[str]
etag=None, uid=None, name=None, component_name=None, last_modified: Optional[str]
time_range=None):
_collection_path: str
_text: Optional[str]
_vobject_item: Optional[vobject.base.Component]
_etag: Optional[str]
_uid: Optional[str]
_name: Optional[str]
_component_name: Optional[str]
_time_range: Optional[Tuple[int, int]]
def __init__(self,
collection_path: Optional[str] = None,
collection: Optional["storage.BaseCollection"] = None,
vobject_item: Optional[vobject.base.Component] = None,
href: Optional[str] = None,
last_modified: Optional[str] = None,
text: Optional[str] = None,
etag: Optional[str] = None,
uid: Optional[str] = None,
name: Optional[str] = None,
component_name: Optional[str] = None,
time_range: Optional[Tuple[int, int]] = None):
"""Initialize an item. """Initialize an item.
``collection_path`` the path of the parent collection (optional if ``collection_path`` the path of the parent collection (optional if
@ -318,8 +357,7 @@ class Item:
``component_name`` the name of the primary component (optional). ``component_name`` the name of the primary component (optional).
See ``find_tag``. See ``find_tag``.
``time_range`` the enclosing time range. ``time_range`` the enclosing time range. See ``find_time_range``.
See ``find_tag_and_time_range``.
""" """
if text is None and vobject_item is None: if text is None and vobject_item is None:
@ -344,7 +382,7 @@ class Item:
self._component_name = component_name self._component_name = component_name
self._time_range = time_range self._time_range = time_range
def serialize(self): def serialize(self) -> str:
if self._text is None: if self._text is None:
try: try:
self._text = self.vobject_item.serialize() self._text = self.vobject_item.serialize()
@ -366,38 +404,38 @@ class Item:
return self._vobject_item return self._vobject_item
@property @property
def etag(self): def etag(self) -> str:
"""Encoded as quoted-string (see RFC 2616).""" """Encoded as quoted-string (see RFC 2616)."""
if self._etag is None: if self._etag is None:
self._etag = get_etag(self.serialize()) self._etag = get_etag(self.serialize())
return self._etag return self._etag
@property @property
def uid(self): def uid(self) -> str:
if self._uid is None: if self._uid is None:
self._uid = get_uid_from_object(self.vobject_item) self._uid = get_uid_from_object(self.vobject_item)
return self._uid return self._uid
@property @property
def name(self): def name(self) -> str:
if self._name is None: if self._name is None:
self._name = self.vobject_item.name or "" self._name = self.vobject_item.name or ""
return self._name return self._name
@property @property
def component_name(self): def component_name(self) -> str:
if self._component_name is not None: if self._component_name is None:
return self._component_name self._component_name = find_tag(self.vobject_item)
return find_tag(self.vobject_item) return self._component_name
@property @property
def time_range(self): def time_range(self) -> Tuple[int, int]:
if self._time_range is None: if self._time_range is None:
self._component_name, *self._time_range = ( self._time_range = find_time_range(
find_tag_and_time_range(self.vobject_item)) self.vobject_item, self.component_name)
return self._time_range return self._time_range
def prepare(self): def prepare(self) -> None:
"""Fill cache with values.""" """Fill cache with values."""
orig_vobject_item = self._vobject_item orig_vobject_item = self._vobject_item
self.serialize() self.serialize()

View File

@ -19,35 +19,40 @@
import math import math
import xml.etree.ElementTree as ET
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
from itertools import chain from itertools import chain
from typing import (Callable, Iterable, Iterator, List, Optional, Sequence,
Tuple)
from radicale import xmlutils import vobject
from radicale import item, xmlutils
from radicale.log import logger from radicale.log import logger
DAY = timedelta(days=1) DAY: timedelta = timedelta(days=1)
SECOND = timedelta(seconds=1) SECOND: timedelta = timedelta(seconds=1)
DATETIME_MIN = datetime.min.replace(tzinfo=timezone.utc) DATETIME_MIN: datetime = datetime.min.replace(tzinfo=timezone.utc)
DATETIME_MAX = datetime.max.replace(tzinfo=timezone.utc) DATETIME_MAX: datetime = datetime.max.replace(tzinfo=timezone.utc)
TIMESTAMP_MIN = math.floor(DATETIME_MIN.timestamp()) TIMESTAMP_MIN: int = math.floor(DATETIME_MIN.timestamp())
TIMESTAMP_MAX = math.ceil(DATETIME_MAX.timestamp()) TIMESTAMP_MAX: int = math.ceil(DATETIME_MAX.timestamp())
def date_to_datetime(date_): def date_to_datetime(d: date) -> datetime:
"""Transform a date to a UTC datetime. """Transform any date to a UTC datetime.
If date_ is a datetime without timezone, return as UTC datetime. If date_ If ``d`` is a datetime without timezone, return as UTC datetime. If ``d``
is already a datetime with timezone, return as is. is already a datetime with timezone, return as is.
""" """
if not isinstance(date_, datetime): if not isinstance(d, datetime):
date_ = datetime.combine(date_, datetime.min.time()) d = datetime.combine(d, datetime.min.time())
if not date_.tzinfo: if not d.tzinfo:
date_ = date_.replace(tzinfo=timezone.utc) d = d.replace(tzinfo=timezone.utc)
return date_ return d
def comp_match(item, filter_, level=0): def comp_match(item: "item.Item", filter_: ET.Element, level: int = 0) -> bool:
"""Check whether the ``item`` matches the comp ``filter_``. """Check whether the ``item`` matches the comp ``filter_``.
If ``level`` is ``0``, the filter is applied on the If ``level`` is ``0``, the filter is applied on the
@ -70,7 +75,7 @@ def comp_match(item, filter_, level=0):
return True return True
if not tag: if not tag:
return False return False
name = filter_.get("name").upper() name = filter_.get("name", "").upper()
if len(filter_) == 0: if len(filter_) == 0:
# Point #1 of rfc4791-9.7.1 # Point #1 of rfc4791-9.7.1
return name == tag return name == tag
@ -104,13 +109,14 @@ def comp_match(item, filter_, level=0):
return True return True
def prop_match(vobject_item, filter_, ns): def prop_match(vobject_item: vobject.base.Component,
filter_: ET.Element, ns: str) -> bool:
"""Check whether the ``item`` matches the prop ``filter_``. """Check whether the ``item`` matches the prop ``filter_``.
See rfc4791-9.7.2 and rfc6352-10.5.1. See rfc4791-9.7.2 and rfc6352-10.5.1.
""" """
name = filter_.get("name").lower() name = filter_.get("name", "").lower()
if len(filter_) == 0: if len(filter_) == 0:
# Point #1 of rfc4791-9.7.2 # Point #1 of rfc4791-9.7.2
return name in vobject_item.contents return name in vobject_item.contents
@ -136,20 +142,21 @@ def prop_match(vobject_item, filter_, ns):
return True return True
def time_range_match(vobject_item, filter_, child_name): def time_range_match(vobject_item: vobject.base.Component,
filter_: ET.Element, child_name: str) -> bool:
"""Check whether the component/property ``child_name`` of """Check whether the component/property ``child_name`` of
``vobject_item`` matches the time-range ``filter_``.""" ``vobject_item`` matches the time-range ``filter_``."""
start = filter_.get("start") start_text = filter_.get("start")
end = filter_.get("end") end_text = filter_.get("end")
if not start and not end: if not start_text and not end_text:
return False return False
if start: if start_text:
start = datetime.strptime(start, "%Y%m%dT%H%M%SZ") start = datetime.strptime(start_text, "%Y%m%dT%H%M%SZ")
else: else:
start = datetime.min start = datetime.min
if end: if end_text:
end = datetime.strptime(end, "%Y%m%dT%H%M%SZ") end = datetime.strptime(end_text, "%Y%m%dT%H%M%SZ")
else: else:
end = datetime.max end = datetime.max
start = start.replace(tzinfo=timezone.utc) start = start.replace(tzinfo=timezone.utc)
@ -157,7 +164,8 @@ def time_range_match(vobject_item, filter_, child_name):
matched = False matched = False
def range_fn(range_start, range_end, is_recurrence): def range_fn(range_start: datetime, range_end: datetime,
is_recurrence: bool) -> bool:
nonlocal matched nonlocal matched
if start < range_end and range_start < end: if start < range_end and range_start < end:
matched = True matched = True
@ -166,14 +174,16 @@ def time_range_match(vobject_item, filter_, child_name):
return True return True
return False return False
def infinity_fn(start): def infinity_fn(start: datetime) -> bool:
return False return False
visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn) visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn)
return matched return matched
def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn): def visit_time_ranges(vobject_item: vobject.base.Component, child_name: str,
range_fn: Callable[[datetime, datetime, bool], bool],
infinity_fn: Callable[[datetime], bool]) -> None:
"""Visit all time ranges in the component/property ``child_name`` of """Visit all time ranges in the component/property ``child_name`` of
`vobject_item`` with visitors ``range_fn`` and ``infinity_fn``. `vobject_item`` with visitors ``range_fn`` and ``infinity_fn``.
@ -194,7 +204,8 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
# recurrences too. This is not respected and client don't seem to bother # recurrences too. This is not respected and client don't seem to bother
# either. # either.
def getrruleset(child, ignore=()): def getrruleset(child: vobject.base.Component, ignore: Sequence[date]
) -> Tuple[Iterable[date], bool]:
if (hasattr(child, "rrule") and if (hasattr(child, "rrule") and
";UNTIL=" not in child.rrule.value.upper() and ";UNTIL=" not in child.rrule.value.upper() and
";COUNT=" not in child.rrule.value.upper()): ";COUNT=" not in child.rrule.value.upper()):
@ -207,7 +218,8 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
return filter(lambda dtstart: dtstart not in ignore, return filter(lambda dtstart: dtstart not in ignore,
child.getrruleset(addRDate=True)), False child.getrruleset(addRDate=True)), False
def get_children(components): def get_children(components: Iterable[vobject.base.Component]) -> Iterator[
Tuple[vobject.base.Component, bool, List[date]]]:
main = None main = None
recurrences = [] recurrences = []
for comp in components: for comp in components:
@ -216,7 +228,7 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
if comp.rruleset: if comp.rruleset:
# Prevent possible infinite loop # Prevent possible infinite loop
raise ValueError("Overwritten recurrence with RRULESET") raise ValueError("Overwritten recurrence with RRULESET")
yield comp, True, () yield comp, True, []
else: else:
if main is not None: if main is not None:
raise ValueError("Multiple main components") raise ValueError("Multiple main components")
@ -418,7 +430,9 @@ def visit_time_ranges(vobject_item, child_name, range_fn, infinity_fn):
range_fn(child, child + DAY, False) range_fn(child, child + DAY, False)
def text_match(vobject_item, filter_, child_name, ns, attrib_name=None): def text_match(vobject_item: vobject.base.Component,
filter_: ET.Element, child_name: str, ns: str,
attrib_name: Optional[str] = None) -> bool:
"""Check whether the ``item`` matches the text-match ``filter_``. """Check whether the ``item`` matches the text-match ``filter_``.
See rfc4791-9.7.5. See rfc4791-9.7.5.
@ -432,7 +446,7 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
if ns == "CR": if ns == "CR":
match_type = filter_.get("match-type", match_type) match_type = filter_.get("match-type", match_type)
def match(value): def match(value: str) -> bool:
value = value.lower() value = value.lower()
if match_type == "equals": if match_type == "equals":
return value == text return value == text
@ -445,7 +459,7 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
raise ValueError("Unexpected text-match match-type: %r" % match_type) raise ValueError("Unexpected text-match match-type: %r" % match_type)
children = getattr(vobject_item, "%s_list" % child_name, []) children = getattr(vobject_item, "%s_list" % child_name, [])
if attrib_name: if attrib_name is not None:
condition = any( condition = any(
match(attrib) for child in children match(attrib) for child in children
for attrib in child.params.get(attrib_name, [])) for attrib in child.params.get(attrib_name, []))
@ -456,13 +470,14 @@ def text_match(vobject_item, filter_, child_name, ns, attrib_name=None):
return condition return condition
def param_filter_match(vobject_item, filter_, parent_name, ns): def param_filter_match(vobject_item: vobject.base.Component,
filter_: ET.Element, parent_name: str, ns: str) -> bool:
"""Check whether the ``item`` matches the param-filter ``filter_``. """Check whether the ``item`` matches the param-filter ``filter_``.
See rfc4791-9.7.3. See rfc4791-9.7.3.
""" """
name = filter_.get("name").upper() name = filter_.get("name", "").upper()
children = getattr(vobject_item, "%s_list" % parent_name, []) children = getattr(vobject_item, "%s_list" % parent_name, [])
condition = any(name in child.params for child in children) condition = any(name in child.params for child in children)
if len(filter_) > 0: if len(filter_) > 0:
@ -474,7 +489,8 @@ def param_filter_match(vobject_item, filter_, parent_name, ns):
return condition return condition
def simplify_prefilters(filters, collection_tag="VCALENDAR"): def simplify_prefilters(filters: Iterable[ET.Element], collection_tag: str
) -> Tuple[Optional[str], int, int, bool]:
"""Creates a simplified condition from ``filters``. """Creates a simplified condition from ``filters``.
Returns a tuple (``tag``, ``start``, ``end``, ``simple``) where ``tag`` is Returns a tuple (``tag``, ``start``, ``end``, ``simple``) where ``tag`` is
@ -483,14 +499,14 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
and the simplified condition are identical. and the simplified condition are identical.
""" """
flat_filters = tuple(chain.from_iterable(filters)) flat_filters = list(chain.from_iterable(filters))
simple = len(flat_filters) <= 1 simple = len(flat_filters) <= 1
for col_filter in flat_filters: for col_filter in flat_filters:
if collection_tag != "VCALENDAR": if collection_tag != "VCALENDAR":
simple = False simple = False
break break
if (col_filter.tag != xmlutils.make_clark("C:comp-filter") or if (col_filter.tag != xmlutils.make_clark("C:comp-filter") or
col_filter.get("name").upper() != "VCALENDAR"): col_filter.get("name", "").upper() != "VCALENDAR"):
simple = False simple = False
continue continue
simple &= len(col_filter) <= 1 simple &= len(col_filter) <= 1
@ -498,7 +514,7 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
if comp_filter.tag != xmlutils.make_clark("C:comp-filter"): if comp_filter.tag != xmlutils.make_clark("C:comp-filter"):
simple = False simple = False
continue continue
tag = comp_filter.get("name").upper() tag = comp_filter.get("name", "").upper()
if comp_filter.find( if comp_filter.find(
xmlutils.make_clark("C:is-not-defined")) is not None: xmlutils.make_clark("C:is-not-defined")) is not None:
simple = False simple = False
@ -511,17 +527,17 @@ def simplify_prefilters(filters, collection_tag="VCALENDAR"):
if time_filter.tag != xmlutils.make_clark("C:time-range"): if time_filter.tag != xmlutils.make_clark("C:time-range"):
simple = False simple = False
continue continue
start = time_filter.get("start") start_text = time_filter.get("start")
end = time_filter.get("end") end_text = time_filter.get("end")
if start: if start_text:
start = math.floor(datetime.strptime( start = math.floor(datetime.strptime(
start, "%Y%m%dT%H%M%SZ").replace( start_text, "%Y%m%dT%H%M%SZ").replace(
tzinfo=timezone.utc).timestamp()) tzinfo=timezone.utc).timestamp())
else: else:
start = TIMESTAMP_MIN start = TIMESTAMP_MIN
if end: if end_text:
end = math.ceil(datetime.strptime( end = math.ceil(datetime.strptime(
end, "%Y%m%dT%H%M%SZ").replace( end_text, "%Y%m%dT%H%M%SZ").replace(
tzinfo=timezone.utc).timestamp()) tzinfo=timezone.utc).timestamp())
else: else:
end = TIMESTAMP_MAX end = TIMESTAMP_MAX

View File

@ -25,42 +25,46 @@ Log messages are sent to the first available target of:
""" """
import contextlib
import logging import logging
import os import os
import sys import sys
import threading import threading
from typing import Any, Callable, ClassVar, Dict, Iterator, Union
LOGGER_NAME = "radicale" from radicale import types
LOGGER_FORMAT = "[%(asctime)s] [%(ident)s] [%(levelname)s] %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S %z"
logger = logging.getLogger(LOGGER_NAME) LOGGER_NAME: str = "radicale"
LOGGER_FORMAT: str = "[%(asctime)s] [%(ident)s] [%(levelname)s] %(message)s"
DATE_FORMAT: str = "%Y-%m-%d %H:%M:%S %z"
logger: logging.Logger = logging.getLogger(LOGGER_NAME)
class RemoveTracebackFilter(logging.Filter): class RemoveTracebackFilter(logging.Filter):
def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
record.exc_info = None record.exc_info = None
return True return True
REMOVE_TRACEBACK_FILTER = RemoveTracebackFilter() REMOVE_TRACEBACK_FILTER: logging.Filter = RemoveTracebackFilter()
class IdentLogRecordFactory: class IdentLogRecordFactory:
"""LogRecordFactory that adds ``ident`` attribute.""" """LogRecordFactory that adds ``ident`` attribute."""
def __init__(self, upstream_factory): def __init__(self, upstream_factory: Callable[..., logging.LogRecord]
self.upstream_factory = upstream_factory ) -> None:
self._upstream_factory = upstream_factory
def __call__(self, *args, **kwargs): def __call__(self, *args: Any, **kwargs: Any) -> logging.LogRecord:
record = self.upstream_factory(*args, **kwargs) record = self._upstream_factory(*args, **kwargs)
ident = "%d" % os.getpid() ident = "%d" % os.getpid()
main_thread = threading.main_thread() main_thread = threading.main_thread()
current_thread = threading.current_thread() current_thread = threading.current_thread()
if current_thread.name and main_thread != current_thread: if current_thread.name and main_thread != current_thread:
ident += "/%s" % current_thread.name ident += "/%s" % current_thread.name
record.ident = ident record.ident = ident # type:ignore[attr-defined]
return record return record
@ -68,13 +72,15 @@ class ThreadedStreamHandler(logging.Handler):
"""Sends logging output to the stream registered for the current thread or """Sends logging output to the stream registered for the current thread or
``sys.stderr`` when no stream was registered.""" ``sys.stderr`` when no stream was registered."""
terminator = "\n" terminator: ClassVar[str] = "\n"
def __init__(self): _streams: Dict[int, types.ErrorStream]
def __init__(self) -> None:
super().__init__() super().__init__()
self._streams = {} self._streams = {}
def emit(self, record): def emit(self, record: logging.LogRecord) -> None:
try: try:
stream = self._streams.get(threading.get_ident(), sys.stderr) stream = self._streams.get(threading.get_ident(), sys.stderr)
msg = self.format(record) msg = self.format(record)
@ -85,8 +91,8 @@ class ThreadedStreamHandler(logging.Handler):
except Exception: except Exception:
self.handleError(record) self.handleError(record)
@contextlib.contextmanager @types.contextmanager
def register_stream(self, stream): def register_stream(self, stream: types.ErrorStream) -> Iterator[None]:
"""Register stream for logging output of the current thread.""" """Register stream for logging output of the current thread."""
key = threading.get_ident() key = threading.get_ident()
self._streams[key] = stream self._streams[key] = stream
@ -96,13 +102,13 @@ class ThreadedStreamHandler(logging.Handler):
del self._streams[key] del self._streams[key]
@contextlib.contextmanager @types.contextmanager
def register_stream(stream): def register_stream(stream: types.ErrorStream) -> Iterator[None]:
"""Register stream for logging output of the current thread.""" """Register stream for logging output of the current thread."""
yield yield
def setup(): def setup() -> None:
"""Set global logging up.""" """Set global logging up."""
global register_stream global register_stream
handler = ThreadedStreamHandler() handler = ThreadedStreamHandler()
@ -114,12 +120,12 @@ def setup():
set_level(logging.WARNING) set_level(logging.WARNING)
def set_level(level): def set_level(level: Union[int, str]) -> None:
"""Set logging level for global logger.""" """Set logging level for global logger."""
if isinstance(level, str): if isinstance(level, str):
level = getattr(logging, level.upper()) level = getattr(logging, level.upper())
assert isinstance(level, int)
logger.setLevel(level) logger.setLevel(level)
if level == logging.DEBUG: logger.removeFilter(REMOVE_TRACEBACK_FILTER)
logger.removeFilter(REMOVE_TRACEBACK_FILTER) if level > logging.DEBUG:
else:
logger.addFilter(REMOVE_TRACEBACK_FILTER) logger.addFilter(REMOVE_TRACEBACK_FILTER)

View File

@ -21,20 +21,21 @@ Helper functions for working with the file system.
""" """
import contextlib
import os import os
import posixpath import posixpath
import sys import sys
import threading import threading
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Type, Union from typing import Iterator, Type, Union
if os.name == "nt": from radicale import storage, types
if sys.platform == "win32":
import ctypes import ctypes
import ctypes.wintypes import ctypes.wintypes
import msvcrt import msvcrt
LOCKFILE_EXCLUSIVE_LOCK = 2 LOCKFILE_EXCLUSIVE_LOCK: int = 2
ULONG_PTR: Union[Type[ctypes.c_uint32], Type[ctypes.c_uint64]] ULONG_PTR: Union[Type[ctypes.c_uint32], Type[ctypes.c_uint64]]
if ctypes.sizeof(ctypes.c_void_p) == 4: if ctypes.sizeof(ctypes.c_void_p) == 4:
ULONG_PTR = ctypes.c_uint32 ULONG_PTR = ctypes.c_uint32
@ -49,8 +50,7 @@ if os.name == "nt":
("offset_high", ctypes.wintypes.DWORD), ("offset_high", ctypes.wintypes.DWORD),
("h_event", ctypes.wintypes.HANDLE)] ("h_event", ctypes.wintypes.HANDLE)]
kernel32 = ctypes.WinDLL( # type: ignore[attr-defined] kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
"kernel32", use_last_error=True)
lock_file_ex = kernel32.LockFileEx lock_file_ex = kernel32.LockFileEx
lock_file_ex.argtypes = [ lock_file_ex.argtypes = [
ctypes.wintypes.HANDLE, ctypes.wintypes.HANDLE,
@ -71,13 +71,13 @@ if os.name == "nt":
elif os.name == "posix": elif os.name == "posix":
import fcntl import fcntl
HAVE_RENAMEAT2 = False HAVE_RENAMEAT2: bool = False
if sys.platform == "linux": if sys.platform == "linux":
import ctypes import ctypes
RENAME_EXCHANGE = 2 RENAME_EXCHANGE: int = 2
try: try:
renameat2 = ctypes.CDLL(None, use_errno=True).renameat2 renameat2 = ctypes.CDLL("", use_errno=True).renameat2
except AttributeError: except AttributeError:
pass pass
else: else:
@ -92,14 +92,19 @@ if sys.platform == "linux":
class RwLock: class RwLock:
"""A readers-Writer lock that locks a file.""" """A readers-Writer lock that locks a file."""
def __init__(self, path): _path: str
_readers: int
_writer: bool
_lock: threading.Lock
def __init__(self, path: str) -> None:
self._path = path self._path = path
self._readers = 0 self._readers = 0
self._writer = False self._writer = False
self._lock = threading.Lock() self._lock = threading.Lock()
@property @property
def locked(self): def locked(self) -> str:
with self._lock: with self._lock:
if self._readers > 0: if self._readers > 0:
return "r" return "r"
@ -107,12 +112,12 @@ class RwLock:
return "w" return "w"
return "" return ""
@contextlib.contextmanager @types.contextmanager
def acquire(self, mode): def acquire(self, mode: str) -> Iterator[None]:
if mode not in "rw": if mode not in "rw":
raise ValueError("Invalid mode: %r" % mode) raise ValueError("Invalid mode: %r" % mode)
with open(self._path, "w+") as lock_file: with open(self._path, "w+") as lock_file:
if os.name == "nt": if sys.platform == "win32":
handle = msvcrt.get_osfhandle(lock_file.fileno()) handle = msvcrt.get_osfhandle(lock_file.fileno())
flags = LOCKFILE_EXCLUSIVE_LOCK if mode == "w" else 0 flags = LOCKFILE_EXCLUSIVE_LOCK if mode == "w" else 0
overlapped = Overlapped() overlapped = Overlapped()
@ -120,15 +125,15 @@ class RwLock:
if not lock_file_ex(handle, flags, 0, 1, 0, overlapped): if not lock_file_ex(handle, flags, 0, 1, 0, overlapped):
raise ctypes.WinError() raise ctypes.WinError()
except OSError as e: except OSError as e:
raise RuntimeError("Locking the storage failed: %s" % raise RuntimeError("Locking the storage failed: %s" % e
e) from e ) from e
elif os.name == "posix": elif os.name == "posix":
_cmd = fcntl.LOCK_EX if mode == "w" else fcntl.LOCK_SH _cmd = fcntl.LOCK_EX if mode == "w" else fcntl.LOCK_SH
try: try:
fcntl.flock(lock_file.fileno(), _cmd) fcntl.flock(lock_file.fileno(), _cmd)
except OSError as e: except OSError as e:
raise RuntimeError("Locking the storage failed: %s" % raise RuntimeError("Locking the storage failed: %s" % e
e) from e ) from e
else: else:
raise RuntimeError("Locking the storage failed: " raise RuntimeError("Locking the storage failed: "
"Unsupported operating system") "Unsupported operating system")
@ -149,7 +154,7 @@ class RwLock:
self._writer = False self._writer = False
def rename_exchange(src, dst): def rename_exchange(src: str, dst: str) -> None:
"""Exchange the files or directories `src` and `dst`. """Exchange the files or directories `src` and `dst`.
Both `src` and `dst` must exist but may be of different types. Both `src` and `dst` must exist but may be of different types.
@ -181,26 +186,26 @@ def rename_exchange(src, dst):
finally: finally:
os.close(src_dir_fd) os.close(src_dir_fd)
else: else:
with TemporaryDirectory( with TemporaryDirectory(prefix=".Radicale.tmp-", dir=src_dir
prefix=".Radicale.tmp-", dir=src_dir) as tmp_dir: ) as tmp_dir:
os.rename(dst, os.path.join(tmp_dir, "interim")) os.rename(dst, os.path.join(tmp_dir, "interim"))
os.rename(src, dst) os.rename(src, dst)
os.rename(os.path.join(tmp_dir, "interim"), src) os.rename(os.path.join(tmp_dir, "interim"), src)
def fsync(fd): def fsync(fd: int) -> None:
if os.name == "posix" and hasattr(fcntl, "F_FULLFSYNC"): if os.name == "posix" and hasattr(fcntl, "F_FULLFSYNC"):
fcntl.fcntl(fd, fcntl.F_FULLFSYNC) fcntl.fcntl(fd, fcntl.F_FULLFSYNC)
else: else:
os.fsync(fd) os.fsync(fd)
def strip_path(path): def strip_path(path: str) -> str:
assert sanitize_path(path) == path assert sanitize_path(path) == path
return path.strip("/") return path.strip("/")
def unstrip_path(stripped_path, trailing_slash=False): def unstrip_path(stripped_path: str, trailing_slash: bool = False) -> str:
assert strip_path(sanitize_path(stripped_path)) == stripped_path assert strip_path(sanitize_path(stripped_path)) == stripped_path
assert stripped_path or trailing_slash assert stripped_path or trailing_slash
path = "/%s" % stripped_path path = "/%s" % stripped_path
@ -209,7 +214,7 @@ def unstrip_path(stripped_path, trailing_slash=False):
return path return path
def sanitize_path(path): def sanitize_path(path: str) -> str:
"""Make path absolute with leading slash to prevent access to other data. """Make path absolute with leading slash to prevent access to other data.
Preserve potential trailing slash. Preserve potential trailing slash.
@ -226,16 +231,16 @@ def sanitize_path(path):
return new_path + trailing_slash return new_path + trailing_slash
def is_safe_path_component(path): def is_safe_path_component(path: str) -> bool:
"""Check if path is a single component of a path. """Check if path is a single component of a path.
Check that the path is safe to join too. Check that the path is safe to join too.
""" """
return path and "/" not in path and path not in (".", "..") return bool(path) and "/" not in path and path not in (".", "..")
def is_safe_filesystem_path_component(path): def is_safe_filesystem_path_component(path: str) -> bool:
"""Check if path is a single component of a local and posix filesystem """Check if path is a single component of a local and posix filesystem
path. path.
@ -243,13 +248,13 @@ def is_safe_filesystem_path_component(path):
""" """
return ( return (
path and not os.path.splitdrive(path)[0] and bool(path) and not os.path.splitdrive(path)[0] and
not os.path.split(path)[0] and path not in (os.curdir, os.pardir) and not os.path.split(path)[0] and path not in (os.curdir, os.pardir) and
not path.startswith(".") and not path.endswith("~") and not path.startswith(".") and not path.endswith("~") and
is_safe_path_component(path)) is_safe_path_component(path))
def path_to_filesystem(root, sane_path): def path_to_filesystem(root: str, sane_path: str) -> str:
"""Convert `sane_path` to a local filesystem path relative to `root`. """Convert `sane_path` to a local filesystem path relative to `root`.
`root` must be a secure filesystem path, it will be prepend to the path. `root` must be a secure filesystem path, it will be prepend to the path.
@ -271,25 +276,25 @@ def path_to_filesystem(root, sane_path):
# Check for conflicting files (e.g. case-insensitive file systems # Check for conflicting files (e.g. case-insensitive file systems
# or short names on Windows file systems) # or short names on Windows file systems)
if (os.path.lexists(safe_path) and if (os.path.lexists(safe_path) and
part not in (e.name for e in part not in (e.name for e in os.scandir(safe_path_parent))):
os.scandir(safe_path_parent))):
raise CollidingPathError(part) raise CollidingPathError(part)
return safe_path return safe_path
class UnsafePathError(ValueError): class UnsafePathError(ValueError):
def __init__(self, path):
message = "Can't translate name safely to filesystem: %r" % path def __init__(self, path: str) -> None:
super().__init__(message) super().__init__("Can't translate name safely to filesystem: %r" %
path)
class CollidingPathError(ValueError): class CollidingPathError(ValueError):
def __init__(self, path):
message = "File name collision: %r" % path def __init__(self, path: str) -> None:
super().__init__(message) super().__init__("File name collision: %r" % path)
def name_from_path(path, collection): def name_from_path(path: str, collection: "storage.BaseCollection") -> str:
"""Return Radicale item name from ``path``.""" """Return Radicale item name from ``path``."""
assert sanitize_path(path) == path assert sanitize_path(path) == path
start = unstrip_path(collection.path, True) start = unstrip_path(collection.path, True)

View File

@ -32,17 +32,21 @@ Take a look at the class ``BaseRights`` if you want to implement your own.
""" """
from radicale import utils from typing import Sequence
INTERNAL_TYPES = ("authenticated", "owner_write", "owner_only", "from_file") from radicale import config, utils
INTERNAL_TYPES: Sequence[str] = ("authenticated", "owner_write", "owner_only",
"from_file")
def load(configuration): def load(configuration: "config.Configuration") -> "BaseRights":
"""Load the rights module chosen in configuration.""" """Load the rights module chosen in configuration."""
return utils.load_plugin(INTERNAL_TYPES, "rights", "Rights", configuration) return utils.load_plugin(INTERNAL_TYPES, "rights", "Rights", BaseRights,
configuration)
def intersect(a, b): def intersect(a: str, b: str) -> str:
"""Intersect two lists of rights. """Intersect two lists of rights.
Returns all rights that are both in ``a`` and ``b``. Returns all rights that are both in ``a`` and ``b``.
@ -52,7 +56,8 @@ def intersect(a, b):
class BaseRights: class BaseRights:
def __init__(self, configuration):
def __init__(self, configuration: "config.Configuration") -> None:
"""Initialize BaseRights. """Initialize BaseRights.
``configuration`` see ``radicale.config`` module. ``configuration`` see ``radicale.config`` module.
@ -62,7 +67,7 @@ class BaseRights:
""" """
self.configuration = configuration self.configuration = configuration
def authorization(self, user, path): def authorization(self, user: str, path: str) -> str:
"""Get granted rights of ``user`` for the collection ``path``. """Get granted rights of ``user`` for the collection ``path``.
If ``user`` is empty, check for anonymous rights. If ``user`` is empty, check for anonymous rights.

View File

@ -21,15 +21,16 @@ calendars and address books.
""" """
from radicale import pathutils, rights from radicale import config, pathutils, rights
class Rights(rights.BaseRights): class Rights(rights.BaseRights):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) def __init__(self, configuration: config.Configuration) -> None:
super().__init__(configuration)
self._verify_user = self.configuration.get("auth", "type") != "none" self._verify_user = self.configuration.get("auth", "type") != "none"
def authorization(self, user, path): def authorization(self, user: str, path: str) -> str:
if self._verify_user and not user: if self._verify_user and not user:
return "" return ""
sane_path = pathutils.strip_path(path) sane_path = pathutils.strip_path(path)

View File

@ -37,16 +37,19 @@ Leading or ending slashes are trimmed from collection's path.
import configparser import configparser
import re import re
from radicale import pathutils, rights from radicale import config, pathutils, rights
from radicale.log import logger from radicale.log import logger
class Rights(rights.BaseRights): class Rights(rights.BaseRights):
def __init__(self, configuration):
_filename: str
def __init__(self, configuration: config.Configuration) -> None:
super().__init__(configuration) super().__init__(configuration)
self._filename = configuration.get("rights", "file") self._filename = configuration.get("rights", "file")
def authorization(self, user, path): def authorization(self, user: str, path: str) -> str:
user = user or "" user = user or ""
sane_path = pathutils.strip_path(path) sane_path = pathutils.strip_path(path)
# Prevent "regex injection" # Prevent "regex injection"
@ -54,8 +57,7 @@ class Rights(rights.BaseRights):
rights_config = configparser.ConfigParser() rights_config = configparser.ConfigParser()
try: try:
if not rights_config.read(self._filename): if not rights_config.read(self._filename):
raise RuntimeError("No such file: %r" % raise RuntimeError("No such file: %r" % self._filename)
self._filename)
except Exception as e: except Exception as e:
raise RuntimeError("Failed to load rights file %r: %s" % raise RuntimeError("Failed to load rights file %r: %s" %
(self._filename, e)) from e (self._filename, e)) from e
@ -67,7 +69,7 @@ class Rights(rights.BaseRights):
user_match = re.fullmatch(user_pattern.format(), user) user_match = re.fullmatch(user_pattern.format(), user)
collection_match = user_match and re.fullmatch( collection_match = user_match and re.fullmatch(
collection_pattern.format( collection_pattern.format(
*map(re.escape, user_match.groups()), *(re.escape(s) for s in user_match.groups()),
user=escaped_user), sane_path) user=escaped_user), sane_path)
except Exception as e: except Exception as e:
raise RuntimeError("Error in section %r of rights file %r: " raise RuntimeError("Error in section %r of rights file %r: "

View File

@ -26,7 +26,8 @@ from radicale import pathutils
class Rights(authenticated.Rights): class Rights(authenticated.Rights):
def authorization(self, user, path):
def authorization(self, user: str, path: str) -> str:
if self._verify_user and not user: if self._verify_user and not user:
return "" return ""
sane_path = pathutils.strip_path(path) sane_path = pathutils.strip_path(path)

View File

@ -26,7 +26,8 @@ from radicale import pathutils
class Rights(authenticated.Rights): class Rights(authenticated.Rights):
def authorization(self, user, path):
def authorization(self, user: str, path: str) -> str:
if self._verify_user and not user: if self._verify_user and not user:
return "" return ""
sane_path = pathutils.strip_path(path) sane_path = pathutils.strip_path(path)

View File

@ -23,14 +23,15 @@ Built-in WSGI server.
""" """
import errno import errno
import os import http
import select import select
import socket import socket
import socketserver import socketserver
import ssl import ssl
import sys import sys
import wsgiref.simple_server import wsgiref.simple_server
from typing import MutableMapping from typing import (Any, Callable, Dict, List, MutableMapping, Optional, Set,
Tuple, Union)
from urllib.parse import unquote from urllib.parse import unquote
from radicale import Application, config from radicale import Application, config
@ -38,7 +39,7 @@ from radicale.log import logger
COMPAT_EAI_ADDRFAMILY: int COMPAT_EAI_ADDRFAMILY: int
if hasattr(socket, "EAI_ADDRFAMILY"): if hasattr(socket, "EAI_ADDRFAMILY"):
COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY # type: ignore[attr-defined] COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY # type:ignore[attr-defined]
elif hasattr(socket, "EAI_NONAME"): elif hasattr(socket, "EAI_NONAME"):
# Windows and BSD don't have a special error code for this # Windows and BSD don't have a special error code for this
COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME
@ -51,57 +52,99 @@ elif hasattr(socket, "EAI_NONAME"):
COMPAT_IPPROTO_IPV6: int COMPAT_IPPROTO_IPV6: int
if hasattr(socket, "IPPROTO_IPV6"): if hasattr(socket, "IPPROTO_IPV6"):
COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6 COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6
elif os.name == "nt": elif sys.platform == "win32":
# Workaround: https://bugs.python.org/issue29515 # HACK: https://bugs.python.org/issue29515
COMPAT_IPPROTO_IPV6 = 41 COMPAT_IPPROTO_IPV6 = 41
def format_address(address): # IPv4 (host, port) and IPv6 (host, port, flowinfo, scopeid)
ADDRESS_TYPE = Union[Tuple[str, int], Tuple[str, int, int, int]]
def format_address(address: ADDRESS_TYPE) -> str:
return "[%s]:%d" % address[:2] return "[%s]:%d" % address[:2]
class ParallelHTTPServer(socketserver.ThreadingMixIn, class ParallelHTTPServer(socketserver.ThreadingMixIn,
wsgiref.simple_server.WSGIServer): wsgiref.simple_server.WSGIServer):
# We wait for child threads ourself configuration: config.Configuration
block_on_close = False worker_sockets: Set[socket.socket]
daemon_threads = True _timeout: float
def __init__(self, configuration, family, address, RequestHandlerClass): # We wait for child threads ourself (ThreadingMixIn)
block_on_close: bool = False
daemon_threads: bool = True
def __init__(self, configuration: config.Configuration, family: int,
address: Tuple[str, int], RequestHandlerClass:
Callable[..., http.server.BaseHTTPRequestHandler]) -> None:
self.configuration = configuration self.configuration = configuration
self.address_family = family self.address_family = family
super().__init__(address, RequestHandlerClass) super().__init__(address, RequestHandlerClass)
self.client_sockets = set() self.worker_sockets = set()
self._timeout = configuration.get("server", "timeout")
def server_bind(self): def server_bind(self) -> None:
if self.address_family == socket.AF_INET6: if self.address_family == socket.AF_INET6:
# Only allow IPv6 connections to the IPv6 socket # Only allow IPv6 connections to the IPv6 socket
self.socket.setsockopt(COMPAT_IPPROTO_IPV6, socket.IPV6_V6ONLY, 1) self.socket.setsockopt(COMPAT_IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
super().server_bind() super().server_bind()
def get_request(self): def get_request( # type:ignore[override]
self) -> Tuple[socket.socket, Tuple[ADDRESS_TYPE, socket.socket]]:
# Set timeout for client # Set timeout for client
request, client_address = super().get_request() request: socket.socket
timeout = self.configuration.get("server", "timeout") client_address: ADDRESS_TYPE
if timeout: request, client_address = super().get_request() # type:ignore[misc]
request.settimeout(timeout) if self._timeout > 0:
client_socket, client_socket_out = socket.socketpair() request.settimeout(self._timeout)
self.client_sockets.add(client_socket_out) worker_socket, worker_socket_out = socket.socketpair()
return request, (*client_address, client_socket) self.worker_sockets.add(worker_socket_out)
# HACK: Forward `worker_socket` via `client_address` return value
# to worker thread.
# The super class calls `verify_request`, `process_request` and
# `handle_error` with modified `client_address` value.
return request, (client_address, worker_socket)
def finish_request_locked(self, request, client_address): def verify_request( # type:ignore[override]
return super().finish_request(request, client_address) self, request: socket.socket, client_address_and_socket:
Tuple[ADDRESS_TYPE, socket.socket]) -> bool:
return True
def finish_request(self, request, client_address): def process_request( # type:ignore[override]
*client_address, client_socket = client_address self, request: socket.socket, client_address_and_socket:
client_address = tuple(client_address) Tuple[ADDRESS_TYPE, socket.socket]) -> None:
# HACK: Super class calls `finish_request` in new thread with
# `client_address_and_socket`
return super().process_request(
request, client_address_and_socket) # type:ignore[arg-type]
def finish_request( # type:ignore[override]
self, request: socket.socket, client_address_and_socket:
Tuple[ADDRESS_TYPE, socket.socket]) -> None:
# HACK: Unpack `client_address_and_socket` and call super class
# `finish_request` with original `client_address`
client_address, worker_socket = client_address_and_socket
try: try:
return self.finish_request_locked(request, client_address) return self.finish_request_locked(request, client_address)
finally: finally:
client_socket.close() worker_socket.close()
def handle_error(self, request, client_address): def finish_request_locked(self, request: socket.socket,
if issubclass(sys.exc_info()[0], socket.timeout): client_address: ADDRESS_TYPE) -> None:
return super().finish_request(
request, client_address) # type:ignore[arg-type]
def handle_error( # type:ignore[override]
self, request: socket.socket,
client_address_or_client_address_and_socket:
Union[ADDRESS_TYPE, Tuple[ADDRESS_TYPE, socket.socket]]) -> None:
# HACK: This method can be called with the modified
# `client_address_and_socket` or the original `client_address` value
e = sys.exc_info()[1]
assert e is not None
if isinstance(e, socket.timeout):
logger.info("Client timed out", exc_info=True) logger.info("Client timed out", exc_info=True)
else: else:
logger.error("An exception occurred during request: %s", logger.error("An exception occurred during request: %s",
@ -110,12 +153,12 @@ class ParallelHTTPServer(socketserver.ThreadingMixIn,
class ParallelHTTPSServer(ParallelHTTPServer): class ParallelHTTPSServer(ParallelHTTPServer):
def server_bind(self): def server_bind(self) -> None:
super().server_bind() super().server_bind()
# Wrap the TCP socket in an SSL socket # Wrap the TCP socket in an SSL socket
certfile = self.configuration.get("server", "certificate") certfile: str = self.configuration.get("server", "certificate")
keyfile = self.configuration.get("server", "key") keyfile: str = self.configuration.get("server", "key")
cafile = self.configuration.get("server", "certificate_authority") cafile: str = self.configuration.get("server", "certificate_authority")
# Test if the files can be read # Test if the files can be read
for name, filename in [("certificate", certfile), ("key", keyfile), for name, filename in [("certificate", certfile), ("key", keyfile),
("certificate_authority", cafile)]: ("certificate_authority", cafile)]:
@ -139,7 +182,9 @@ class ParallelHTTPSServer(ParallelHTTPServer):
self.socket = context.wrap_socket( self.socket = context.wrap_socket(
self.socket, server_side=True, do_handshake_on_connect=False) self.socket, server_side=True, do_handshake_on_connect=False)
def finish_request_locked(self, request, client_address): def finish_request_locked( # type:ignore[override]
self, request: ssl.SSLSocket, client_address: ADDRESS_TYPE
) -> None:
try: try:
try: try:
request.do_handshake() request.do_handshake()
@ -151,7 +196,7 @@ class ParallelHTTPSServer(ParallelHTTPServer):
try: try:
self.handle_error(request, client_address) self.handle_error(request, client_address)
finally: finally:
self.shutdown_request(request) self.shutdown_request(request) # type:ignore[attr-defined]
return return
return super().finish_request_locked(request, client_address) return super().finish_request_locked(request, client_address)
@ -161,30 +206,34 @@ class ServerHandler(wsgiref.simple_server.ServerHandler):
# Don't pollute WSGI environ with OS environment # Don't pollute WSGI environ with OS environment
os_environ: MutableMapping[str, str] = {} os_environ: MutableMapping[str, str] = {}
def log_exception(self, exc_info): def log_exception(self, exc_info: "wsgiref.handlers._exc_info") -> None:
logger.error("An exception occurred during request: %s", logger.error("An exception occurred during request: %s",
exc_info[1], exc_info=exc_info) exc_info[1], exc_info=exc_info) # type:ignore[arg-type]
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler): class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
"""HTTP requests handler.""" """HTTP requests handler."""
def log_request(self, code="-", size="-"): # HACK: Assigned in `socketserver.StreamRequestHandler`
connection: socket.socket
def log_request(self, code: Union[int, str] = "-",
size: Union[int, str] = "-") -> None:
pass # Disable request logging. pass # Disable request logging.
def log_error(self, format_, *args): def log_error(self, format_: str, *args: Any) -> None:
logger.error("An error occurred during request: %s", format_ % args) logger.error("An error occurred during request: %s", format_ % args)
def get_environ(self): def get_environ(self) -> Dict[str, Any]:
env = super().get_environ() env = super().get_environ()
if hasattr(self.connection, "getpeercert"): if isinstance(self.connection, ssl.SSLSocket):
# The certificate can be evaluated by the auth module # The certificate can be evaluated by the auth module
env["REMOTE_CERTIFICATE"] = self.connection.getpeercert() env["REMOTE_CERTIFICATE"] = self.connection.getpeercert()
# Parent class only tries latin1 encoding # Parent class only tries latin1 encoding
env["PATH_INFO"] = unquote(self.path.split("?", 1)[0]) env["PATH_INFO"] = unquote(self.path.split("?", 1)[0])
return env return env
def handle(self): def handle(self) -> None:
"""Copy of WSGIRequestHandler.handle with different ServerHandler""" """Copy of WSGIRequestHandler.handle with different ServerHandler"""
self.raw_requestline = self.rfile.readline(65537) self.raw_requestline = self.rfile.readline(65537)
@ -201,11 +250,13 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
handler = ServerHandler( handler = ServerHandler(
self.rfile, self.wfile, self.get_stderr(), self.get_environ() self.rfile, self.wfile, self.get_stderr(), self.get_environ()
) )
handler.request_handler = self handler.request_handler = self # type:ignore[attr-defined]
handler.run(self.server.get_app()) app = self.server.get_app() # type:ignore[attr-defined]
handler.run(app)
def serve(configuration, shutdown_socket=None): def serve(configuration: config.Configuration,
shutdown_socket: Optional[socket.socket] = None) -> None:
"""Serve radicale from configuration. """Serve radicale from configuration.
`shutdown_socket` can be used to gracefully shutdown the server. `shutdown_socket` can be used to gracefully shutdown the server.
@ -221,12 +272,13 @@ def serve(configuration, shutdown_socket=None):
configuration.update({"server": {"_internal_server": "True"}}, "server", configuration.update({"server": {"_internal_server": "True"}}, "server",
privileged=True) privileged=True)
use_ssl = configuration.get("server", "ssl") use_ssl: bool = configuration.get("server", "ssl")
server_class = ParallelHTTPSServer if use_ssl else ParallelHTTPServer server_class = ParallelHTTPSServer if use_ssl else ParallelHTTPServer
application = Application(configuration) application = Application(configuration)
servers = {} servers = {}
try: try:
for address in configuration.get("server", "hosts"): hosts: List[Tuple[str, int]] = configuration.get("server", "hosts")
for address in hosts:
# Try to bind sockets for IPv4 and IPv6 # Try to bind sockets for IPv4 and IPv6
possible_families = (socket.AF_INET, socket.AF_INET6) possible_families = (socket.AF_INET, socket.AF_INET6)
bind_ok = False bind_ok = False
@ -270,16 +322,16 @@ def serve(configuration, shutdown_socket=None):
# Mainloop # Mainloop
select_timeout = None select_timeout = None
if os.name == "nt": if sys.platform == "win32":
# Fallback to busy waiting. (select(...) blocks SIGINT on Windows.) # Fallback to busy waiting. (select(...) blocks SIGINT on Windows.)
select_timeout = 1.0 select_timeout = 1.0
max_connections = configuration.get("server", "max_connections") max_connections: int = configuration.get("server", "max_connections")
logger.info("Radicale server ready") logger.info("Radicale server ready")
while True: while True:
rlist = [] rlist: List[socket.socket] = []
# Wait for finished clients # Wait for finished clients
for server in servers.values(): for server in servers.values():
rlist.extend(server.client_sockets) rlist.extend(server.worker_sockets)
# Accept new connections if max_connections is not reached # Accept new connections if max_connections is not reached
if max_connections <= 0 or len(rlist) < max_connections: if max_connections <= 0 or len(rlist) < max_connections:
rlist.extend(servers) rlist.extend(servers)
@ -287,26 +339,26 @@ def serve(configuration, shutdown_socket=None):
if shutdown_socket is not None: if shutdown_socket is not None:
rlist.append(shutdown_socket) rlist.append(shutdown_socket)
rlist, _, _ = select.select(rlist, [], [], select_timeout) rlist, _, _ = select.select(rlist, [], [], select_timeout)
rlist = set(rlist) rset = set(rlist)
if shutdown_socket in rlist: if shutdown_socket in rset:
logger.info("Stopping Radicale") logger.info("Stopping Radicale")
break break
for server in servers.values(): for server in servers.values():
finished_sockets = server.client_sockets.intersection(rlist) finished_sockets = server.worker_sockets.intersection(rset)
for s in finished_sockets: for s in finished_sockets:
s.close() s.close()
server.client_sockets.remove(s) server.worker_sockets.remove(s)
rlist.remove(s) rset.remove(s)
if finished_sockets: if finished_sockets:
server.service_actions() server.service_actions()
if rlist: if rset:
server = servers.get(rlist.pop()) active_server = servers.get(rset.pop())
if server: if active_server:
server.handle_request() active_server.handle_request()
finally: finally:
# Wait for clients to finish and close servers # Wait for clients to finish and close servers
for server in servers.values(): for server in servers.values():
for s in server.client_sockets: for s in server.worker_sockets:
s.recv(1) s.recv(1)
s.close() s.close()
server.server_close() server.server_close()

View File

@ -23,37 +23,44 @@ Take a look at the class ``BaseCollection`` if you want to implement your own.
""" """
import contextlib
import json import json
import xml.etree.ElementTree as ET
from hashlib import sha256 from hashlib import sha256
from typing import (Iterable, Iterator, Mapping, Optional, Sequence, Set,
Tuple, Union, overload)
import pkg_resources import pkg_resources
import vobject import vobject
from radicale import utils from radicale import config
from radicale import item as radicale_item
from radicale import types, utils
from radicale.item import filter as radicale_filter from radicale.item import filter as radicale_filter
INTERNAL_TYPES = ("multifilesystem",) INTERNAL_TYPES: Sequence[str] = ("multifilesystem",)
CACHE_DEPS = ("radicale", "vobject", "python-dateutil",) CACHE_DEPS: Sequence[str] = ("radicale", "vobject", "python-dateutil",)
CACHE_VERSION = (";".join(pkg_resources.get_distribution(pkg).version CACHE_VERSION: bytes = "".join(
for pkg in CACHE_DEPS) + ";").encode() "%s=%s;" % (pkg, pkg_resources.get_distribution(pkg).version)
for pkg in CACHE_DEPS).encode()
def load(configuration): def load(configuration: "config.Configuration") -> "BaseStorage":
"""Load the storage module chosen in configuration.""" """Load the storage module chosen in configuration."""
return utils.load_plugin( return utils.load_plugin(INTERNAL_TYPES, "storage", "Storage", BaseStorage,
INTERNAL_TYPES, "storage", "Storage", configuration) configuration)
class ComponentExistsError(ValueError): class ComponentExistsError(ValueError):
def __init__(self, path):
def __init__(self, path: str) -> None:
message = "Component already exists: %r" % path message = "Component already exists: %r" % path
super().__init__(message) super().__init__(message)
class ComponentNotFoundError(ValueError): class ComponentNotFoundError(ValueError):
def __init__(self, path):
def __init__(self, path: str) -> None:
message = "Component doesn't exist: %r" % path message = "Component doesn't exist: %r" % path
super().__init__(message) super().__init__(message)
@ -61,47 +68,58 @@ class ComponentNotFoundError(ValueError):
class BaseCollection: class BaseCollection:
@property @property
def path(self): def path(self) -> str:
"""The sanitized path of the collection without leading or """The sanitized path of the collection without leading or
trailing ``/``.""" trailing ``/``."""
raise NotImplementedError raise NotImplementedError
@property @property
def owner(self): def owner(self) -> str:
"""The owner of the collection.""" """The owner of the collection."""
return self.path.split("/", maxsplit=1)[0] return self.path.split("/", maxsplit=1)[0]
@property @property
def is_principal(self): def is_principal(self) -> bool:
"""Collection is a principal.""" """Collection is a principal."""
return bool(self.path) and "/" not in self.path return bool(self.path) and "/" not in self.path
@property @property
def etag(self): def etag(self) -> str:
"""Encoded as quoted-string (see RFC 2616).""" """Encoded as quoted-string (see RFC 2616)."""
etag = sha256() etag = sha256()
for item in self.get_all(): for item in self.get_all():
assert item.href
etag.update((item.href + "/" + item.etag).encode()) etag.update((item.href + "/" + item.etag).encode())
etag.update(json.dumps(self.get_meta(), sort_keys=True).encode()) etag.update(json.dumps(self.get_meta(), sort_keys=True).encode())
return '"%s"' % etag.hexdigest() return '"%s"' % etag.hexdigest()
def sync(self, old_token=None): @property
def tag(self) -> str:
"""The tag of the collection."""
return self.get_meta("tag") or ""
def sync(self, old_token: str = "") -> Tuple[str, Iterable[str]]:
"""Get the current sync token and changed items for synchronization. """Get the current sync token and changed items for synchronization.
``old_token`` an old sync token which is used as the base of the ``old_token`` an old sync token which is used as the base of the
delta update. If sync token is missing, all items are returned. delta update. If sync token is empty, all items are returned.
ValueError is raised for invalid or old tokens. ValueError is raised for invalid or old tokens.
WARNING: This simple default implementation treats all sync-token as WARNING: This simple default implementation treats all sync-token as
invalid. invalid.
""" """
def hrefs_iter() -> Iterator[str]:
for item in self.get_all():
assert item.href
yield item.href
token = "http://radicale.org/ns/sync/%s" % self.etag.strip("\"") token = "http://radicale.org/ns/sync/%s" % self.etag.strip("\"")
if old_token: if old_token:
raise ValueError("Sync token are not supported") raise ValueError("Sync token are not supported")
return token, (item.href for item in self.get_all()) return token, hrefs_iter()
def get_multi(self, hrefs): def get_multi(self, hrefs: Iterable[str]
) -> Iterable[Tuple[str, Optional["radicale_item.Item"]]]:
"""Fetch multiple items. """Fetch multiple items.
It's not required to return the requested items in the correct order. It's not required to return the requested items in the correct order.
@ -113,11 +131,12 @@ class BaseCollection:
""" """
raise NotImplementedError raise NotImplementedError
def get_all(self): def get_all(self) -> Iterable["radicale_item.Item"]:
"""Fetch all items.""" """Fetch all items."""
raise NotImplementedError raise NotImplementedError
def get_filtered(self, filters): def get_filtered(self, filters: Iterable[ET.Element]
) -> Iterable[Tuple["radicale_item.Item", bool]]:
"""Fetch all items with optional filtering. """Fetch all items with optional filtering.
This can largely improve performance of reports depending on This can largely improve performance of reports depending on
@ -128,32 +147,31 @@ class BaseCollection:
matched. matched.
""" """
if not self.tag:
return
tag, start, end, simple = radicale_filter.simplify_prefilters( tag, start, end, simple = radicale_filter.simplify_prefilters(
filters, collection_tag=self.get_meta("tag")) filters, self.tag)
for item in self.get_all(): for item in self.get_all():
if tag: if tag is not None and tag != item.component_name:
if tag != item.component_name: continue
continue istart, iend = item.time_range
istart, iend = item.time_range if istart >= end or iend <= start:
if istart >= end or iend <= start: continue
continue yield item, simple and (start <= istart or iend <= end)
item_simple = simple and (start <= istart or iend <= end)
else:
item_simple = simple
yield item, item_simple
def has_uid(self, uid): def has_uid(self, uid: str) -> bool:
"""Check if a UID exists in the collection.""" """Check if a UID exists in the collection."""
for item in self.get_all(): for item in self.get_all():
if item.uid == uid: if item.uid == uid:
return True return True
return False return False
def upload(self, href, item): def upload(self, href: str, item: "radicale_item.Item") -> (
"radicale_item.Item"):
"""Upload a new or replace an existing item.""" """Upload a new or replace an existing item."""
raise NotImplementedError raise NotImplementedError
def delete(self, href=None): def delete(self, href: Optional[str] = None) -> None:
"""Delete an item. """Delete an item.
When ``href`` is ``None``, delete the collection. When ``href`` is ``None``, delete the collection.
@ -161,7 +179,14 @@ class BaseCollection:
""" """
raise NotImplementedError raise NotImplementedError
def get_meta(self, key=None): @overload
def get_meta(self, key: None = None) -> Mapping[str, str]: ...
@overload
def get_meta(self, key: str) -> Optional[str]: ...
def get_meta(self, key: Optional[str] = None
) -> Union[Mapping[str, str], Optional[str]]:
"""Get metadata value for collection. """Get metadata value for collection.
Return the value of the property ``key``. If ``key`` is ``None`` return Return the value of the property ``key``. If ``key`` is ``None`` return
@ -170,7 +195,7 @@ class BaseCollection:
""" """
raise NotImplementedError raise NotImplementedError
def set_meta(self, props): def set_meta(self, props: Mapping[str, str]) -> None:
"""Set metadata values for collection. """Set metadata values for collection.
``props`` a dict with values for properties. ``props`` a dict with values for properties.
@ -179,16 +204,16 @@ class BaseCollection:
raise NotImplementedError raise NotImplementedError
@property @property
def last_modified(self): def last_modified(self) -> str:
"""Get the HTTP-datetime of when the collection was modified.""" """Get the HTTP-datetime of when the collection was modified."""
raise NotImplementedError raise NotImplementedError
def serialize(self): def serialize(self) -> str:
"""Get the unicode string representing the whole collection.""" """Get the unicode string representing the whole collection."""
if self.get_meta("tag") == "VCALENDAR": if self.tag == "VCALENDAR":
in_vcalendar = False in_vcalendar = False
vtimezones = "" vtimezones = ""
included_tzids = set() included_tzids: Set[str] = set()
vtimezone = [] vtimezone = []
tzid = None tzid = None
components = "" components = ""
@ -216,6 +241,7 @@ class BaseCollection:
elif depth == 2 and line.startswith("END:"): elif depth == 2 and line.startswith("END:"):
if tzid is None or tzid not in included_tzids: if tzid is None or tzid not in included_tzids:
vtimezones += "".join(vtimezone) vtimezones += "".join(vtimezone)
if tzid is not None:
included_tzids.add(tzid) included_tzids.add(tzid)
vtimezone.clear() vtimezone.clear()
tzid = None tzid = None
@ -240,13 +266,14 @@ class BaseCollection:
return (template[:template_insert_pos] + return (template[:template_insert_pos] +
vtimezones + components + vtimezones + components +
template[template_insert_pos:]) template[template_insert_pos:])
if self.get_meta("tag") == "VADDRESSBOOK": if self.tag == "VADDRESSBOOK":
return "".join((item.serialize() for item in self.get_all())) return "".join((item.serialize() for item in self.get_all()))
return "" return ""
class BaseStorage: class BaseStorage:
def __init__(self, configuration):
def __init__(self, configuration: "config.Configuration") -> None:
"""Initialize BaseStorage. """Initialize BaseStorage.
``configuration`` see ``radicale.config`` module. ``configuration`` see ``radicale.config`` module.
@ -256,7 +283,8 @@ class BaseStorage:
""" """
self.configuration = configuration self.configuration = configuration
def discover(self, path, depth="0"): def discover(self, path: str, depth: str = "0") -> Iterable[
"types.CollectionOrItem"]:
"""Discover a list of collections under the given ``path``. """Discover a list of collections under the given ``path``.
``path`` is sanitized. ``path`` is sanitized.
@ -272,7 +300,8 @@ class BaseStorage:
""" """
raise NotImplementedError raise NotImplementedError
def move(self, item, to_collection, to_href): def move(self, item: "radicale_item.Item", to_collection: BaseCollection,
to_href: str) -> None:
"""Move an object. """Move an object.
``item`` is the item to move. ``item`` is the item to move.
@ -285,7 +314,10 @@ class BaseStorage:
""" """
raise NotImplementedError raise NotImplementedError
def create_collection(self, href, items=None, props=None): def create_collection(
self, href: str,
items: Optional[Iterable["radicale_item.Item"]] = None,
props: Optional[Mapping[str, str]] = None) -> BaseCollection:
"""Create a collection. """Create a collection.
``href`` is the sanitized path. ``href`` is the sanitized path.
@ -298,15 +330,14 @@ class BaseStorage:
``props`` are metadata values for the collection. ``props`` are metadata values for the collection.
``props["tag"]`` is the type of collection (VCALENDAR or ``props["tag"]`` is the type of collection (VCALENDAR or VADDRESSBOOK).
VADDRESSBOOK). If the key ``tag`` is missing, it is guessed from the If the key ``tag`` is missing, ``items`` is ignored.
collection.
""" """
raise NotImplementedError raise NotImplementedError
@contextlib.contextmanager @types.contextmanager
def acquire_lock(self, mode, user=None): def acquire_lock(self, mode: str, user: str = "") -> Iterator[None]:
"""Set a context manager to lock the whole storage. """Set a context manager to lock the whole storage.
``mode`` must either be "r" for shared access or "w" for exclusive ``mode`` must either be "r" for shared access or "w" for exclusive
@ -317,6 +348,6 @@ class BaseStorage:
""" """
raise NotImplementedError raise NotImplementedError
def verify(self): def verify(self) -> bool:
"""Check the storage for errors.""" """Check the storage for errors."""
raise NotImplementedError raise NotImplementedError

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import os import os
import pickle import pickle
import time import time
@ -72,14 +73,10 @@ class CollectionCacheMixin:
"item") "item")
content = self._item_cache_content(item, cache_hash) content = self._item_cache_content(item, cache_hash)
self._storage._makedirs_synced(cache_folder) self._storage._makedirs_synced(cache_folder)
try: # Race: Other processes might have created and locked the file.
# Race: Other processes might have created and locked the with contextlib.suppress(PermissionError), self._atomic_write(
# file. os.path.join(cache_folder, href), "wb") as f:
with self._atomic_write(os.path.join(cache_folder, href), pickle.dump(content, f)
"wb") as f:
pickle.dump(content, f)
except PermissionError:
pass
return content return content
def _load_item_cache(self, href, input_hash): def _load_item_cache(self, href, input_hash):

View File

@ -17,11 +17,12 @@
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
import os import os
import sys
import time import time
import vobject import vobject
from radicale import item as radicale_item import radicale.item as radicale_item
from radicale import pathutils from radicale import pathutils
from radicale.log import logger from radicale.log import logger
@ -63,7 +64,7 @@ class CollectionGetMixin:
return None return None
except PermissionError: except PermissionError:
# Windows raises ``PermissionError`` when ``path`` is a directory # Windows raises ``PermissionError`` when ``path`` is a directory
if (os.name == "nt" and if (sys.platform == "win32" and
os.path.isdir(path) and os.access(path, os.R_OK)): os.path.isdir(path) and os.access(path, os.R_OK)):
return None return None
raise raise
@ -83,10 +84,10 @@ class CollectionGetMixin:
self._load_item_cache(href, input_hash) self._load_item_cache(href, input_hash)
if input_hash != cache_hash: if input_hash != cache_hash:
try: try:
vobject_items = tuple(vobject.readComponents( vobject_items = list(vobject.readComponents(
raw_text.decode(self._encoding))) raw_text.decode(self._encoding)))
radicale_item.check_and_sanitize_items( radicale_item.check_and_sanitize_items(
vobject_items, tag=self.get_meta("tag")) vobject_items, tag=self.tag)
vobject_item, = vobject_items vobject_item, = vobject_items
temp_item = radicale_item.Item( temp_item = radicale_item.Item(
collection=self, vobject_item=vobject_item) collection=self, vobject_item=vobject_item)

View File

@ -17,10 +17,11 @@
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
import binascii import binascii
import contextlib
import os import os
import pickle import pickle
from radicale import item as radicale_item import radicale.item as radicale_item
from radicale import pathutils from radicale import pathutils
from radicale.log import logger from radicale.log import logger
@ -53,13 +54,10 @@ class CollectionHistoryMixin:
self._storage._makedirs_synced(history_folder) self._storage._makedirs_synced(history_folder)
history_etag = radicale_item.get_etag( history_etag = radicale_item.get_etag(
history_etag + "/" + etag).strip("\"") history_etag + "/" + etag).strip("\"")
try: # Race: Other processes might have created and locked the file.
# Race: Other processes might have created and locked the file. with contextlib.suppress(PermissionError), self._atomic_write(
with self._atomic_write(os.path.join(history_folder, href), os.path.join(history_folder, href), "wb") as f:
"wb") as f: pickle.dump([etag, history_etag], f)
pickle.dump([etag, history_etag], f)
except PermissionError:
pass
return history_etag return history_etag
def _get_deleted_history_hrefs(self): def _get_deleted_history_hrefs(self):
@ -67,7 +65,7 @@ class CollectionHistoryMixin:
history cache.""" history cache."""
history_folder = os.path.join(self._filesystem_path, history_folder = os.path.join(self._filesystem_path,
".Radicale.cache", "history") ".Radicale.cache", "history")
try: with contextlib.suppress(FileNotFoundError):
for entry in os.scandir(history_folder): for entry in os.scandir(history_folder):
href = entry.name href = entry.name
if not pathutils.is_safe_filesystem_path_component(href): if not pathutils.is_safe_filesystem_path_component(href):
@ -75,8 +73,6 @@ class CollectionHistoryMixin:
if os.path.isfile(os.path.join(self._filesystem_path, href)): if os.path.isfile(os.path.join(self._filesystem_path, href)):
continue continue
yield href yield href
except FileNotFoundError:
pass
def _clean_history(self): def _clean_history(self):
# Delete all expired history entries of deleted items. # Delete all expired history entries of deleted items.

View File

@ -22,6 +22,7 @@ import os
import shlex import shlex
import signal import signal
import subprocess import subprocess
import sys
from radicale import pathutils from radicale import pathutils
from radicale.log import logger from radicale.log import logger
@ -48,7 +49,7 @@ class StorageLockMixin:
self._lock = pathutils.RwLock(lock_path) self._lock = pathutils.RwLock(lock_path)
@contextlib.contextmanager @contextlib.contextmanager
def acquire_lock(self, mode, user=None): def acquire_lock(self, mode, user=""):
with self._lock.acquire(mode): with self._lock.acquire(mode):
yield yield
# execute hook # execute hook
@ -66,7 +67,7 @@ class StorageLockMixin:
if os.name == "posix": if os.name == "posix":
# Process group is also used to identify child processes # Process group is also used to identify child processes
popen_kwargs["preexec_fn"] = os.setpgrp popen_kwargs["preexec_fn"] = os.setpgrp
elif os.name == "nt": elif sys.platform == "win32":
popen_kwargs["creationflags"] = ( popen_kwargs["creationflags"] = (
subprocess.CREATE_NEW_PROCESS_GROUP) subprocess.CREATE_NEW_PROCESS_GROUP)
command = hook % {"user": shlex.quote(user or "Anonymous")} command = hook % {"user": shlex.quote(user or "Anonymous")}

View File

@ -19,7 +19,7 @@
import json import json
import os import os
from radicale import item as radicale_item import radicale.item as radicale_item
class CollectionMetaMixin: class CollectionMetaMixin:
@ -35,14 +35,15 @@ class CollectionMetaMixin:
try: try:
try: try:
with open(self._props_path, encoding=self._encoding) as f: with open(self._props_path, encoding=self._encoding) as f:
self._meta_cache = json.load(f) temp_meta = json.load(f)
except FileNotFoundError: except FileNotFoundError:
self._meta_cache = {} temp_meta = {}
radicale_item.check_and_sanitize_props(self._meta_cache) self._meta_cache = radicale_item.check_and_sanitize_props(
temp_meta)
except ValueError as e: except ValueError as e:
raise RuntimeError("Failed to load properties of collection " raise RuntimeError("Failed to load properties of collection "
"%r: %s" % (self.path, e)) from e "%r: %s" % (self.path, e)) from e
return self._meta_cache.get(key) if key else self._meta_cache return self._meta_cache if key is None else self._meta_cache.get(key)
def set_meta(self, props): def set_meta(self, props):
with self._atomic_write(self._props_path, "w") as f: with self._atomic_write(self._props_path, "w") as f:

View File

@ -16,6 +16,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import itertools import itertools
import os import os
import pickle import pickle
@ -25,7 +26,7 @@ from radicale.log import logger
class CollectionSyncMixin: class CollectionSyncMixin:
def sync(self, old_token=None): def sync(self, old_token=""):
# The sync token has the form http://radicale.org/ns/sync/TOKEN_NAME # The sync token has the form http://radicale.org/ns/sync/TOKEN_NAME
# where TOKEN_NAME is the sha256 hash of all history etags of present # where TOKEN_NAME is the sha256 hash of all history etags of present
# and past items of the collection. # and past items of the collection.
@ -37,7 +38,7 @@ class CollectionSyncMixin:
return False return False
return True return True
old_token_name = None old_token_name = ""
if old_token: if old_token:
# Extract the token name from the sync token # Extract the token name from the sync token
if not old_token.startswith("http://radicale.org/ns/sync/"): if not old_token.startswith("http://radicale.org/ns/sync/"):
@ -78,10 +79,9 @@ class CollectionSyncMixin:
"Failed to load stored sync token %r in %r: %s", "Failed to load stored sync token %r in %r: %s",
old_token_name, self.path, e, exc_info=True) old_token_name, self.path, e, exc_info=True)
# Delete the damaged file # Delete the damaged file
try: with contextlib.suppress(FileNotFoundError,
PermissionError):
os.remove(old_token_path) os.remove(old_token_path)
except (FileNotFoundError, PermissionError):
pass
raise ValueError("Token not found: %r" % old_token) raise ValueError("Token not found: %r" % old_token)
# write the new token state or update the modification time of # write the new token state or update the modification time of
# existing token state # existing token state
@ -101,11 +101,9 @@ class CollectionSyncMixin:
self._clean_history() self._clean_history()
else: else:
# Try to update the modification time # Try to update the modification time
try: with contextlib.suppress(FileNotFoundError):
# Race: Another process might have deleted the file. # Race: Another process might have deleted the file.
os.utime(token_path) os.utime(token_path)
except FileNotFoundError:
pass
changes = [] changes = []
# Find all new, changed and deleted (that are still in the item cache) # Find all new, changed and deleted (that are still in the item cache)
# items # items

View File

@ -18,8 +18,9 @@
import os import os
import pickle import pickle
import sys
from radicale import item as radicale_item import radicale.item as radicale_item
from radicale import pathutils from radicale import pathutils
@ -63,7 +64,7 @@ class CollectionUploadMixin:
"Failed to store item %r in temporary collection %r: %s" % "Failed to store item %r in temporary collection %r: %s" %
(uid, self.path, e)) from e (uid, self.path, e)) from e
href_candidate_funtions = [] href_candidate_funtions = []
if os.name in ("nt", "posix"): if os.name == "posix" or sys.platform == "win32":
href_candidate_funtions.append( href_candidate_funtions.append(
lambda: uid if uid.lower().endswith(suffix.lower()) lambda: uid if uid.lower().endswith(suffix.lower())
else uid + suffix) else uid + suffix)
@ -88,7 +89,7 @@ class CollectionUploadMixin:
except OSError as e: except OSError as e:
if href_candidate_funtions and ( if href_candidate_funtions and (
os.name == "posix" and e.errno == 22 or os.name == "posix" and e.errno == 22 or
os.name == "nt" and e.errno == 123): sys.platform == "win32" and e.errno == 123):
continue continue
raise raise
with f: with f:

View File

@ -67,8 +67,8 @@ class StorageVerifyMixin:
item.href, sane_path) item.href, sane_path)
if item_errors == saved_item_errors: if item_errors == saved_item_errors:
collection.sync() collection.sync()
if has_child_collections and collection.get_meta("tag"): if has_child_collections and collection.tag:
logger.error("Invalid collection %r: %r must not have " logger.error("Invalid collection %r: %r must not have "
"child collections", sane_path, "child collections", sane_path,
collection.get_meta("tag")) collection.tag)
return item_errors == 0 and collection_errors == 0 return item_errors == 0 and collection_errors == 0

View File

@ -119,7 +119,7 @@ class BaseTest:
if not self._check_status(status, 207, check): if not self._check_status(status, 207, check):
return status, None return status, None
responses = self.parse_responses(answer) responses = self.parse_responses(answer)
if args.get("HTTP_DEPTH", 0) == 0: if args.get("HTTP_DEPTH", "0") == "0":
assert len(responses) == 1 and path in responses assert len(responses) == 1 and path in responses
return status, responses return status, responses

View File

@ -23,6 +23,7 @@ Radicale tests with simple requests and authentication.
import os import os
import shutil import shutil
import sys
import tempfile import tempfile
import pytest import pytest
@ -114,7 +115,7 @@ class TestBaseAuthRequests(BaseTest):
def test_htpasswd_multi(self): def test_htpasswd_multi(self):
self._test_htpasswd("plain", "ign:ign\ntmp:bepo") self._test_htpasswd("plain", "ign:ign\ntmp:bepo")
@pytest.mark.skipif(os.name == "nt", reason="leading and trailing " @pytest.mark.skipif(sys.platform == "win32", reason="leading and trailing "
"whitespaces not allowed in file names") "whitespaces not allowed in file names")
def test_htpasswd_whitespace_user(self): def test_htpasswd_whitespace_user(self):
for user in (" tmp", "tmp ", " tmp "): for user in (" tmp", "tmp ", " tmp "):

View File

@ -391,10 +391,10 @@ class BaseRequestsMixIn:
event = get_file_content("event1.ics") event = get_file_content("event1.ics")
event_path = posixpath.join(calendar_path, "event.ics") event_path = posixpath.join(calendar_path, "event.ics")
self.put(event_path, event) self.put(event_path, event)
_, responses = self.propfind("/", HTTP_DEPTH=1) _, responses = self.propfind("/", HTTP_DEPTH="1")
assert len(responses) == 2 assert len(responses) == 2
assert "/" in responses and calendar_path in responses assert "/" in responses and calendar_path in responses
_, responses = self.propfind(calendar_path, HTTP_DEPTH=1) _, responses = self.propfind(calendar_path, HTTP_DEPTH="1")
assert len(responses) == 2 assert len(responses) == 2
assert calendar_path in responses and event_path in responses assert calendar_path in responses and event_path in responses
@ -1653,8 +1653,8 @@ class TestMultiFileSystem(BaseFileSystemTest, BaseRequestsMixIn):
assert answer1 == answer2 assert answer1 == answer2
assert os.path.exists(os.path.join(cache_folder, "event1.ics")) assert os.path.exists(os.path.join(cache_folder, "event1.ics"))
@pytest.mark.skipif(os.name not in ("nt", "posix"), @pytest.mark.skipif(os.name != "posix" and sys.platform != "win32",
reason="Only supported on 'nt' and 'posix'") reason="Only supported on 'posix' and 'win32'")
def test_put_whole_calendar_uids_used_as_file_names(self): def test_put_whole_calendar_uids_used_as_file_names(self):
"""Test if UIDs are used as file names.""" """Test if UIDs are used as file names."""
BaseRequestsMixIn.test_put_whole_calendar(self) BaseRequestsMixIn.test_put_whole_calendar(self)
@ -1662,8 +1662,8 @@ class TestMultiFileSystem(BaseFileSystemTest, BaseRequestsMixIn):
_, answer = self.get("/calendar.ics/%s.ics" % uid) _, answer = self.get("/calendar.ics/%s.ics" % uid)
assert "\r\nUID:%s\r\n" % uid in answer assert "\r\nUID:%s\r\n" % uid in answer
@pytest.mark.skipif(os.name not in ("nt", "posix"), @pytest.mark.skipif(os.name != "posix" and sys.platform != "win32",
reason="Only supported on 'nt' and 'posix'") reason="Only supported on 'posix' and 'win32'")
def test_put_whole_calendar_random_uids_used_as_file_names(self): def test_put_whole_calendar_random_uids_used_as_file_names(self):
"""Test if UIDs are used as file names.""" """Test if UIDs are used as file names."""
BaseRequestsMixIn.test_put_whole_calendar_without_uids(self) BaseRequestsMixIn.test_put_whole_calendar_without_uids(self)
@ -1676,8 +1676,8 @@ class TestMultiFileSystem(BaseFileSystemTest, BaseRequestsMixIn):
_, answer = self.get("/calendar.ics/%s.ics" % uid) _, answer = self.get("/calendar.ics/%s.ics" % uid)
assert "\r\nUID:%s\r\n" % uid in answer assert "\r\nUID:%s\r\n" % uid in answer
@pytest.mark.skipif(os.name not in ("nt", "posix"), @pytest.mark.skipif(os.name != "posix" and sys.platform != "win32",
reason="Only supported on 'nt' and 'posix'") reason="Only supported on 'posix' and 'win32'")
def test_put_whole_addressbook_uids_used_as_file_names(self): def test_put_whole_addressbook_uids_used_as_file_names(self):
"""Test if UIDs are used as file names.""" """Test if UIDs are used as file names."""
BaseRequestsMixIn.test_put_whole_addressbook(self) BaseRequestsMixIn.test_put_whole_addressbook(self)
@ -1685,8 +1685,8 @@ class TestMultiFileSystem(BaseFileSystemTest, BaseRequestsMixIn):
_, answer = self.get("/contacts.vcf/%s.vcf" % uid) _, answer = self.get("/contacts.vcf/%s.vcf" % uid)
assert "\r\nUID:%s\r\n" % uid in answer assert "\r\nUID:%s\r\n" % uid in answer
@pytest.mark.skipif(os.name not in ("nt", "posix"), @pytest.mark.skipif(os.name != "posix" and sys.platform != "win32",
reason="Only supported on 'nt' and 'posix'") reason="Only supported on 'posix' and 'win32'")
def test_put_whole_addressbook_random_uids_used_as_file_names(self): def test_put_whole_addressbook_random_uids_used_as_file_names(self):
"""Test if UIDs are used as file names.""" """Test if UIDs are used as file names."""
BaseRequestsMixIn.test_put_whole_addressbook_without_uids(self) BaseRequestsMixIn.test_put_whole_addressbook_without_uids(self)

61
radicale/types.py Normal file
View File

@ -0,0 +1,61 @@
# This file is part of Radicale Server - Calendar Server
# Copyright © 2020 Unrud <unrud@outlook.com>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import sys
from typing import (Any, Callable, ContextManager, Iterator, List, Mapping,
MutableMapping, Sequence, Tuple, TypeVar, Union)
WSGIResponseHeaders = Union[Mapping[str, str], Sequence[Tuple[str, str]]]
WSGIResponse = Tuple[int, WSGIResponseHeaders, Union[None, str, bytes]]
WSGIEnviron = Mapping[str, Any]
WSGIStartResponse = Callable[[str, List[Tuple[str, str]]], Any]
CONFIG = Mapping[str, Mapping[str, Any]]
MUTABLE_CONFIG = MutableMapping[str, MutableMapping[str, Any]]
CONFIG_SCHEMA = Mapping[str, Mapping[str, Any]]
_T = TypeVar("_T")
def contextmanager(func: Callable[..., Iterator[_T]]
) -> Callable[..., ContextManager[_T]]:
"""Compatibility wrapper for `contextlib.contextmanager` with
`typeguard`"""
result = contextlib.contextmanager(func)
result.__annotations__ = {**func.__annotations__,
"return": ContextManager[_T]}
return result
if sys.version_info >= (3, 8):
from typing import Protocol, runtime_checkable
@runtime_checkable
class InputStream(Protocol):
def read(self, size: int = ...) -> bytes: ...
@runtime_checkable
class ErrorStream(Protocol):
def flush(self) -> None: ...
def write(self, s: str) -> None: ...
else:
ErrorStream = Any
InputStream = Any
from radicale import item, storage # noqa:E402 isort:skip
CollectionOrItem = Union[item.Item, storage.BaseCollection]

View File

@ -17,12 +17,18 @@
# along with Radicale. If not, see <http://www.gnu.org/licenses/>. # along with Radicale. If not, see <http://www.gnu.org/licenses/>.
from importlib import import_module from importlib import import_module
from typing import Callable, Sequence, Type, TypeVar, Union
from radicale import config
from radicale.log import logger from radicale.log import logger
_T_co = TypeVar("_T_co", covariant=True)
def load_plugin(internal_types, module_name, class_name, configuration):
type_ = configuration.get(module_name, "type") def load_plugin(internal_types: Sequence[str], module_name: str,
class_name: str, base_class: Type[_T_co],
configuration: "config.Configuration") -> _T_co:
type_: Union[str, Callable] = configuration.get(module_name, "type")
if callable(type_): if callable(type_):
logger.info("%s type is %r", module_name, type_) logger.info("%s type is %r", module_name, type_)
return type_(configuration) return type_(configuration)

View File

@ -21,18 +21,24 @@ Take a look at the class ``BaseWeb`` if you want to implement your own.
""" """
from radicale import httputils, utils from typing import Sequence
INTERNAL_TYPES = ("none", "internal") from radicale import config, httputils, types, utils
INTERNAL_TYPES: Sequence[str] = ("none", "internal")
def load(configuration): def load(configuration: "config.Configuration") -> "BaseWeb":
"""Load the web module chosen in configuration.""" """Load the web module chosen in configuration."""
return utils.load_plugin(INTERNAL_TYPES, "web", "Web", configuration) return utils.load_plugin(INTERNAL_TYPES, "web", "Web", BaseWeb,
configuration)
class BaseWeb: class BaseWeb:
def __init__(self, configuration):
configuration: "config.Configuration"
def __init__(self, configuration: "config.Configuration") -> None:
"""Initialize BaseWeb. """Initialize BaseWeb.
``configuration`` see ``radicale.config`` module. ``configuration`` see ``radicale.config`` module.
@ -42,7 +48,8 @@ class BaseWeb:
""" """
self.configuration = configuration self.configuration = configuration
def get(self, environ, base_prefix, path, user): def get(self, environ: types.WSGIEnviron, base_prefix: str, path: str,
user: str) -> types.WSGIResponse:
"""GET request. """GET request.
``base_prefix`` is sanitized and never ends with "/". ``base_prefix`` is sanitized and never ends with "/".
@ -54,7 +61,8 @@ class BaseWeb:
""" """
return httputils.METHOD_NOT_ALLOWED return httputils.METHOD_NOT_ALLOWED
def post(self, environ, base_prefix, path, user): def post(self, environ: types.WSGIEnviron, base_prefix: str, path: str,
user: str) -> types.WSGIResponse:
"""POST request. """POST request.
``base_prefix`` is sanitized and never ends with "/". ``base_prefix`` is sanitized and never ends with "/".

View File

@ -30,13 +30,14 @@ import os
import posixpath import posixpath
import time import time
from http import client from http import client
from typing import Mapping
import pkg_resources import pkg_resources
from radicale import httputils, pathutils, web from radicale import config, httputils, pathutils, types, web
from radicale.log import logger from radicale.log import logger
MIMETYPES = { MIMETYPES: Mapping[str, str] = {
".css": "text/css", ".css": "text/css",
".eot": "application/vnd.ms-fontobject", ".eot": "application/vnd.ms-fontobject",
".gif": "image/gif", ".gif": "image/gif",
@ -50,16 +51,20 @@ MIMETYPES = {
".woff": "application/font-woff", ".woff": "application/font-woff",
".woff2": "font/woff2", ".woff2": "font/woff2",
".xml": "text/xml"} ".xml": "text/xml"}
FALLBACK_MIMETYPE = "application/octet-stream" FALLBACK_MIMETYPE: str = "application/octet-stream"
class Web(web.BaseWeb): class Web(web.BaseWeb):
def __init__(self, configuration):
folder: str
def __init__(self, configuration: config.Configuration) -> None:
super().__init__(configuration) super().__init__(configuration)
self.folder = pkg_resources.resource_filename(__name__, self.folder = pkg_resources.resource_filename(__name__,
"internal_data") "internal_data")
def get(self, environ, base_prefix, path, user): def get(self, environ: types.WSGIEnviron, base_prefix: str, path: str,
user: str) -> types.WSGIResponse:
assert path == "/.web" or path.startswith("/.web/") assert path == "/.web" or path.startswith("/.web/")
assert pathutils.sanitize_path(path) == path assert pathutils.sanitize_path(path) == path
try: try:

View File

@ -21,11 +21,13 @@ A dummy web backend that shows a simple message.
from http import client from http import client
from radicale import httputils, pathutils, web from radicale import httputils, pathutils, types, web
class Web(web.BaseWeb): class Web(web.BaseWeb):
def get(self, environ, base_prefix, path, user):
def get(self, environ: types.WSGIEnviron, base_prefix: str, path: str,
user: str) -> types.WSGIResponse:
assert path == "/.web" or path.startswith("/.web/") assert path == "/.web" or path.startswith("/.web/")
assert pathutils.sanitize_path(path) == path assert pathutils.sanitize_path(path) == path
if path != "/.web": if path != "/.web":

View File

@ -26,20 +26,21 @@ import copy
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from collections import OrderedDict from collections import OrderedDict
from http import client from http import client
from typing import Dict, Mapping, Optional
from urllib.parse import quote from urllib.parse import quote
from radicale import pathutils from radicale import item, pathutils
MIMETYPES = { MIMETYPES: Mapping[str, str] = {
"VADDRESSBOOK": "text/vcard", "VADDRESSBOOK": "text/vcard",
"VCALENDAR": "text/calendar"} "VCALENDAR": "text/calendar"}
OBJECT_MIMETYPES = { OBJECT_MIMETYPES: Mapping[str, str] = {
"VCARD": "text/vcard", "VCARD": "text/vcard",
"VLIST": "text/x-vlist", "VLIST": "text/x-vlist",
"VCALENDAR": "text/calendar"} "VCALENDAR": "text/calendar"}
NAMESPACES = { NAMESPACES: Mapping[str, str] = {
"C": "urn:ietf:params:xml:ns:caldav", "C": "urn:ietf:params:xml:ns:caldav",
"CR": "urn:ietf:params:xml:ns:carddav", "CR": "urn:ietf:params:xml:ns:carddav",
"D": "DAV:", "D": "DAV:",
@ -48,15 +49,15 @@ NAMESPACES = {
"ME": "http://me.com/_namespace/", "ME": "http://me.com/_namespace/",
"RADICALE": "http://radicale.org/ns/"} "RADICALE": "http://radicale.org/ns/"}
NAMESPACES_REV = {} NAMESPACES_REV: Mapping[str, str] = {v: k for k, v in NAMESPACES.items()}
for short, url in NAMESPACES.items(): for short, url in NAMESPACES.items():
NAMESPACES_REV[url] = short
ET.register_namespace("" if short == "D" else short, url) ET.register_namespace("" if short == "D" else short, url)
def pretty_xml(element): def pretty_xml(element: ET.Element) -> str:
"""Indent an ElementTree ``element`` and its children.""" """Indent an ElementTree ``element`` and its children."""
def pretty_xml_recursive(element, level): def pretty_xml_recursive(element: ET.Element, level: int) -> None:
indent = "\n" + level * " " indent = "\n" + level * " "
if len(element) > 0: if len(element) > 0:
if not (element.text or "").strip(): if not (element.text or "").strip():
@ -74,7 +75,7 @@ def pretty_xml(element):
return '<?xml version="1.0"?>\n%s' % ET.tostring(element, "unicode") return '<?xml version="1.0"?>\n%s' % ET.tostring(element, "unicode")
def make_clark(human_tag): def make_clark(human_tag: str) -> str:
"""Get XML Clark notation from human tag ``human_tag``. """Get XML Clark notation from human tag ``human_tag``.
If ``human_tag`` is already in XML Clark notation it is returned as-is. If ``human_tag`` is already in XML Clark notation it is returned as-is.
@ -88,13 +89,13 @@ def make_clark(human_tag):
ns_prefix, tag = human_tag.split(":", maxsplit=1) ns_prefix, tag = human_tag.split(":", maxsplit=1)
if not ns_prefix or not tag: if not ns_prefix or not tag:
raise ValueError("Invalid XML tag: %r" % human_tag) raise ValueError("Invalid XML tag: %r" % human_tag)
ns = NAMESPACES.get(ns_prefix) ns = NAMESPACES.get(ns_prefix, "")
if not ns: if not ns:
raise ValueError("Unknown XML namespace prefix: %r" % human_tag) raise ValueError("Unknown XML namespace prefix: %r" % human_tag)
return "{%s}%s" % (ns, tag) return "{%s}%s" % (ns, tag)
def make_human_tag(clark_tag): def make_human_tag(clark_tag: str) -> str:
"""Replace known namespaces in XML Clark notation ``clark_tag`` with """Replace known namespaces in XML Clark notation ``clark_tag`` with
prefix. prefix.
@ -111,31 +112,31 @@ def make_human_tag(clark_tag):
ns, tag = clark_tag[len("{"):].split("}", maxsplit=1) ns, tag = clark_tag[len("{"):].split("}", maxsplit=1)
if not ns or not tag: if not ns or not tag:
raise ValueError("Invalid XML tag: %r" % clark_tag) raise ValueError("Invalid XML tag: %r" % clark_tag)
ns_prefix = NAMESPACES_REV.get(ns) ns_prefix = NAMESPACES_REV.get(ns, "")
if ns_prefix: if ns_prefix:
return "%s:%s" % (ns_prefix, tag) return "%s:%s" % (ns_prefix, tag)
return clark_tag return clark_tag
def make_response(code): def make_response(code: int) -> str:
"""Return full W3C names from HTTP status codes.""" """Return full W3C names from HTTP status codes."""
return "HTTP/1.1 %i %s" % (code, client.responses[code]) return "HTTP/1.1 %i %s" % (code, client.responses[code])
def make_href(base_prefix, href): def make_href(base_prefix: str, href: str) -> str:
"""Return prefixed href.""" """Return prefixed href."""
assert href == pathutils.sanitize_path(href) assert href == pathutils.sanitize_path(href)
return quote("%s%s" % (base_prefix, href)) return quote("%s%s" % (base_prefix, href))
def webdav_error(human_tag): def webdav_error(human_tag: str) -> ET.Element:
"""Generate XML error message.""" """Generate XML error message."""
root = ET.Element(make_clark("D:error")) root = ET.Element(make_clark("D:error"))
root.append(ET.Element(make_clark(human_tag))) root.append(ET.Element(make_clark(human_tag)))
return root return root
def get_content_type(item, encoding): def get_content_type(item: "item.Item", encoding: str) -> str:
"""Get the content-type of an item with charset and component parameters. """Get the content-type of an item with charset and component parameters.
""" """
mimetype = OBJECT_MIMETYPES[item.name] mimetype = OBJECT_MIMETYPES[item.name]
@ -146,13 +147,14 @@ def get_content_type(item, encoding):
return content_type return content_type
def props_from_request(xml_request): def props_from_request(xml_request: Optional[ET.Element]
) -> Dict[str, Optional[str]]:
"""Return a list of properties as a dictionary. """Return a list of properties as a dictionary.
Properties that should be removed are set to `None`. Properties that should be removed are set to `None`.
""" """
result = OrderedDict() result: OrderedDict = OrderedDict()
if xml_request is None: if xml_request is None:
return result return result