From a7c025fd866132a7c5fd87684c2e56b951b1460e Mon Sep 17 00:00:00 2001 From: Joscha Date: Wed, 19 May 2021 17:16:23 +0200 Subject: [PATCH] Implement reusable FileSinkToken for OutputDirectory --- PFERD/output_dir.py | 102 +++++++++++++++++++++++++++++--------------- PFERD/utils.py | 45 ++++++++++++++++++- 2 files changed, 112 insertions(+), 35 deletions(-) diff --git a/PFERD/output_dir.py b/PFERD/output_dir.py index 417fa52..783d6bc 100644 --- a/PFERD/output_dir.py +++ b/PFERD/output_dir.py @@ -3,19 +3,19 @@ import os import random import shutil import string -from contextlib import asynccontextmanager, contextmanager +from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from enum import Enum from pathlib import Path, PurePath # TODO In Python 3.9 and above, AsyncContextManager is deprecated -from typing import AsyncContextManager, AsyncIterator, BinaryIO, Iterator, Optional +from typing import AsyncContextManager, BinaryIO, Iterator, Optional, Tuple from rich.markup import escape from .logging import log from .report import MarkConflictException, MarkDuplicateException, Report -from .utils import prompt_yes_no +from .utils import ReusableAsyncContextManager, prompt_yes_no SUFFIX_CHARS = string.ascii_lowercase + string.digits SUFFIX_LENGTH = 6 @@ -87,6 +87,49 @@ class DownloadInfo: success: bool = False +class FileSinkToken(ReusableAsyncContextManager[FileSink]): + # Whenever this class is entered, it creates a new temporary file and + # returns a corresponding FileSink. + # + # When it is exited again, the file is closed and information about the + # download handed back to the OutputDirectory. + + def __init__( + self, + output_dir: "OutputDirectory", + path: PurePath, + local_path: Path, + heuristics: Heuristics, + on_conflict: OnConflict, + ): + super().__init__() + + self._output_dir = output_dir + self._path = path + self._local_path = local_path + self._heuristics = heuristics + self._on_conflict = on_conflict + + async def _on_aenter(self) -> FileSink: + tmp_path, file = await self._output_dir._create_tmp_file(self._local_path) + sink = FileSink(file) + + async def after_download() -> None: + await self._output_dir._after_download(DownloadInfo( + self._path, + self._local_path, + tmp_path, + self._heuristics, + self._on_conflict, + sink.is_done(), + )) + + self._stack.push_async_callback(after_download) + self._stack.enter_context(file) + + return sink + + class OutputDirectory: def __init__( self, @@ -111,11 +154,9 @@ class OutputDirectory: try: self._report.mark(path) except MarkDuplicateException: - msg = "Another file has already been placed here." - raise OutputDirException(msg) + raise OutputDirException("Another file has already been placed here.") except MarkConflictException as e: - msg = f"Collides with other file: {e.collides_with}" - raise OutputDirException(msg) + raise OutputDirException(f"Collides with other file: {e.collides_with}") def resolve(self, path: PurePath) -> Path: """ @@ -123,8 +164,7 @@ class OutputDirectory: """ if ".." in path.parts: - msg = f"Path {path} contains forbidden '..'" - raise OutputDirException(msg) + raise OutputDirException(f"Path {path} contains forbidden '..'") return self._root / path def _should_download( @@ -137,6 +177,7 @@ class OutputDirectory: # since we know that the remote is different from the local files. This # includes the case where no local file exists. if not local_path.is_file(): + # TODO Don't download if on_conflict is LOCAL_FIRST or NO_DELETE return True if redownload == Redownload.NEVER: @@ -251,19 +292,24 @@ class OutputDirectory: name = f"{prefix}{base.name}.tmp.{suffix}" return base.parent / name - @asynccontextmanager - async def _sink_context_manager( + async def _create_tmp_file( self, - file: BinaryIO, - info: DownloadInfo, - ) -> AsyncIterator[FileSink]: - sink = FileSink(file) - try: - with file: - yield sink - finally: - info.success = sink.is_done() - await self._after_download(info) + local_path: Path, + ) -> Tuple[Path, BinaryIO]: + """ + May raise an OutputDirException. + """ + + # Create tmp file + for attempt in range(TRIES): + suffix_length = SUFFIX_LENGTH + 2 * attempt + tmp_path = self._tmp_path(local_path, suffix_length) + try: + return tmp_path, open(tmp_path, "xb") + except FileExistsError: + pass # Try again + + raise OutputDirException(f"Failed to create temporary file {tmp_path}") async def download( self, @@ -306,19 +352,7 @@ class OutputDirectory: # Ensure parent directory exists local_path.parent.mkdir(parents=True, exist_ok=True) - # Create tmp file - for attempt in range(TRIES): - suffix_length = SUFFIX_LENGTH + 2 * attempt - tmp_path = self._tmp_path(local_path, suffix_length) - info = DownloadInfo(path, local_path, tmp_path, - heuristics, on_conflict) - try: - file = open(tmp_path, "xb") - return self._sink_context_manager(file, info) - except FileExistsError: - pass # Try again - - return None + return FileSinkToken(self, path, local_path, heuristics, on_conflict) def _update_metadata(self, info: DownloadInfo) -> None: if mtime := info.heuristics.mtime: diff --git a/PFERD/utils.py b/PFERD/utils.py index 3022ab6..0b3d40d 100644 --- a/PFERD/utils.py +++ b/PFERD/utils.py @@ -2,7 +2,11 @@ import asyncio import contextvars import functools import getpass -from typing import Any, Callable, Optional, TypeVar +import sys +from abc import ABC, abstractmethod +from contextlib import AsyncExitStack +from types import TracebackType +from typing import Any, Callable, Generic, Optional, Type, TypeVar import bs4 @@ -56,3 +60,42 @@ async def prompt_yes_no(query: str, default: Optional[bool]) -> bool: return default print("Please answer with 'y' or 'n'.") + + +class ReusableAsyncContextManager(ABC, Generic[T]): + def __init__(self) -> None: + self._active = False + self._stack = AsyncExitStack() + + @abstractmethod + async def _on_aenter(self) -> T: + pass + + async def __aenter__(self) -> T: + if self._active: + raise RuntimeError("Nested or otherwise concurrent usage is not allowed") + + self._active = True + await self._stack.__aenter__() + + # See https://stackoverflow.com/a/13075071 + try: + result: T = await self._on_aenter() + except: # noqa: E722 do not use bare 'except' + if not await self.__aexit__(*sys.exc_info()): + raise + + return result + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: + if not self._active: + raise RuntimeError("__aexit__ called too many times") + + result = await self._stack.__aexit__(exc_type, exc_value, traceback) + self._active = False + return result