mirror of
				https://github.com/Garmelon/PFERD.git
				synced 2025-10-31 21:02:42 +01:00 
			
		
		
		
	Implement reusable FileSinkToken for OutputDirectory
This commit is contained in:
		| @@ -3,19 +3,19 @@ import os | ||||
| import random | ||||
| import shutil | ||||
| import string | ||||
| from contextlib import asynccontextmanager, contextmanager | ||||
| from contextlib import contextmanager | ||||
| from dataclasses import dataclass | ||||
| from datetime import datetime | ||||
| from enum import Enum | ||||
| from pathlib import Path, PurePath | ||||
| # TODO In Python 3.9 and above, AsyncContextManager is deprecated | ||||
| from typing import AsyncContextManager, AsyncIterator, BinaryIO, Iterator, Optional | ||||
| from typing import AsyncContextManager, BinaryIO, Iterator, Optional, Tuple | ||||
|  | ||||
| from rich.markup import escape | ||||
|  | ||||
| from .logging import log | ||||
| from .report import MarkConflictException, MarkDuplicateException, Report | ||||
| from .utils import prompt_yes_no | ||||
| from .utils import ReusableAsyncContextManager, prompt_yes_no | ||||
|  | ||||
| SUFFIX_CHARS = string.ascii_lowercase + string.digits | ||||
| SUFFIX_LENGTH = 6 | ||||
| @@ -87,6 +87,49 @@ class DownloadInfo: | ||||
|     success: bool = False | ||||
|  | ||||
|  | ||||
class FileSinkToken(ReusableAsyncContextManager[FileSink]):
    """
    Reusable token that yields a fresh FileSink every time it is entered.

    Entering creates a new temporary file via the owning OutputDirectory and
    wraps it in a FileSink. Exiting closes that file and then hands the
    information about the download back to the OutputDirectory.
    """

    def __init__(
            self,
            output_dir: "OutputDirectory",
            path: PurePath,
            local_path: Path,
            heuristics: Heuristics,
            on_conflict: OnConflict,
    ):
        super().__init__()

        # Where the download goes and who is responsible for it
        self._output_dir = output_dir
        self._path, self._local_path = path, local_path

        # Passed through unchanged to the DownloadInfo built on exit
        self._heuristics, self._on_conflict = heuristics, on_conflict

    async def _on_aenter(self) -> FileSink:
        created = await self._output_dir._create_tmp_file(self._local_path)
        temp_location, handle = created
        sink = FileSink(handle)

        async def report_back() -> None:
            # sink.is_done() is evaluated here, i.e. when the context is
            # exited, not when it is entered.
            info = DownloadInfo(
                self._path,
                self._local_path,
                temp_location,
                self._heuristics,
                self._on_conflict,
                sink.is_done(),
            )
            await self._output_dir._after_download(info)

        # The exit stack unwinds in reverse registration order, so the file
        # (registered last) is closed before report_back runs.
        self._stack.push_async_callback(report_back)
        self._stack.enter_context(handle)

        return sink
|  | ||||
|  | ||||
| class OutputDirectory: | ||||
|     def __init__( | ||||
|             self, | ||||
| @@ -111,11 +154,9 @@ class OutputDirectory: | ||||
|         try: | ||||
|             self._report.mark(path) | ||||
|         except MarkDuplicateException: | ||||
|             msg = "Another file has already been placed here." | ||||
|             raise OutputDirException(msg) | ||||
|             raise OutputDirException("Another file has already been placed here.") | ||||
|         except MarkConflictException as e: | ||||
|             msg = f"Collides with other file: {e.collides_with}" | ||||
|             raise OutputDirException(msg) | ||||
|             raise OutputDirException(f"Collides with other file: {e.collides_with}") | ||||
|  | ||||
|     def resolve(self, path: PurePath) -> Path: | ||||
|         """ | ||||
| @@ -123,8 +164,7 @@ class OutputDirectory: | ||||
|         """ | ||||
|  | ||||
|         if ".." in path.parts: | ||||
|             msg = f"Path {path} contains forbidden '..'" | ||||
|             raise OutputDirException(msg) | ||||
|             raise OutputDirException(f"Path {path} contains forbidden '..'") | ||||
|         return self._root / path | ||||
|  | ||||
|     def _should_download( | ||||
| @@ -137,6 +177,7 @@ class OutputDirectory: | ||||
|         # since we know that the remote is different from the local files. This | ||||
|         # includes the case where no local file exists. | ||||
|         if not local_path.is_file(): | ||||
|             # TODO Don't download if on_conflict is LOCAL_FIRST or NO_DELETE | ||||
|             return True | ||||
|  | ||||
|         if redownload == Redownload.NEVER: | ||||
| @@ -251,19 +292,24 @@ class OutputDirectory: | ||||
|         name = f"{prefix}{base.name}.tmp.{suffix}" | ||||
|         return base.parent / name | ||||
|  | ||||
|     @asynccontextmanager | ||||
|     async def _sink_context_manager( | ||||
|     async def _create_tmp_file( | ||||
|             self, | ||||
|             file: BinaryIO, | ||||
|             info: DownloadInfo, | ||||
|     ) -> AsyncIterator[FileSink]: | ||||
|         sink = FileSink(file) | ||||
|         try: | ||||
|             with file: | ||||
|                 yield sink | ||||
|         finally: | ||||
|             info.success = sink.is_done() | ||||
|             await self._after_download(info) | ||||
|             local_path: Path, | ||||
|     ) -> Tuple[Path, BinaryIO]: | ||||
|         """ | ||||
|         May raise an OutputDirException. | ||||
|         """ | ||||
|  | ||||
|         # Create tmp file | ||||
|         for attempt in range(TRIES): | ||||
|             suffix_length = SUFFIX_LENGTH + 2 * attempt | ||||
|             tmp_path = self._tmp_path(local_path, suffix_length) | ||||
|             try: | ||||
|                 return tmp_path, open(tmp_path, "xb") | ||||
|             except FileExistsError: | ||||
|                 pass  # Try again | ||||
|  | ||||
|         raise OutputDirException(f"Failed to create temporary file {tmp_path}") | ||||
|  | ||||
|     async def download( | ||||
|             self, | ||||
| @@ -306,19 +352,7 @@ class OutputDirectory: | ||||
|         # Ensure parent directory exists | ||||
|         local_path.parent.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|         # Create tmp file | ||||
|         for attempt in range(TRIES): | ||||
|             suffix_length = SUFFIX_LENGTH + 2 * attempt | ||||
|             tmp_path = self._tmp_path(local_path, suffix_length) | ||||
|             info = DownloadInfo(path, local_path, tmp_path, | ||||
|                                 heuristics, on_conflict) | ||||
|             try: | ||||
|                 file = open(tmp_path, "xb") | ||||
|                 return self._sink_context_manager(file, info) | ||||
|             except FileExistsError: | ||||
|                 pass  # Try again | ||||
|  | ||||
|         return None | ||||
|         return FileSinkToken(self, path, local_path, heuristics, on_conflict) | ||||
|  | ||||
|     def _update_metadata(self, info: DownloadInfo) -> None: | ||||
|         if mtime := info.heuristics.mtime: | ||||
|   | ||||
| @@ -2,7 +2,11 @@ import asyncio | ||||
| import contextvars | ||||
| import functools | ||||
| import getpass | ||||
| from typing import Any, Callable, Optional, TypeVar | ||||
| import sys | ||||
| from abc import ABC, abstractmethod | ||||
| from contextlib import AsyncExitStack | ||||
| from types import TracebackType | ||||
| from typing import Any, Callable, Generic, Optional, Type, TypeVar | ||||
|  | ||||
| import bs4 | ||||
|  | ||||
| @@ -56,3 +60,42 @@ async def prompt_yes_no(query: str, default: Optional[bool]) -> bool: | ||||
|             return default | ||||
|  | ||||
|         print("Please answer with 'y' or 'n'.") | ||||
|  | ||||
|  | ||||
class ReusableAsyncContextManager(ABC, Generic[T]):
    """
    Base class for async context managers that may be entered repeatedly.

    Subclasses implement `_on_aenter`, which returns the value bound by
    `async with ... as value` and may register cleanup work on `self._stack`.
    Everything registered there is unwound when the context is exited, after
    which the same instance can be entered again. Entering an instance that is
    already active raises a RuntimeError.
    """

    def __init__(self) -> None:
        # True from the start of __aenter__ until __aexit__ has finished
        self._active = False
        # Holds the cleanup callbacks registered during _on_aenter
        self._stack = AsyncExitStack()

    @abstractmethod
    async def _on_aenter(self) -> T:
        """
        Set up the context and return the value handed to the caller.

        Cleanup that must run on exit should be registered on `self._stack`.
        """

    async def __aenter__(self) -> T:
        """
        Enter the context and return the result of `_on_aenter`.

        Raises a RuntimeError if this instance is already active. If
        `_on_aenter` raises, everything registered on the stack so far is
        unwound and the exception is propagated to the caller.
        """

        if self._active:
            raise RuntimeError("Nested or otherwise concurrent usage is not allowed")

        self._active = True
        await self._stack.__aenter__()

        # See https://stackoverflow.com/a/13075071
        try:
            return await self._on_aenter()
        except BaseException:
            # Unwind whatever _on_aenter managed to register. Even if a
            # callback on the stack asks to suppress the exception, there is
            # no value that could be returned to the caller, so the original
            # exception is always re-raised.
            await self.__aexit__(*sys.exc_info())
            raise

    async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        """
        Unwind the exit stack and mark this instance as reusable again.

        Returns whatever the exit stack returns (truthy if the exception was
        suppressed). Raises a RuntimeError if the instance is not active.
        """

        if not self._active:
            raise RuntimeError("__aexit__ called too many times")

        try:
            return await self._stack.__aexit__(exc_type, exc_value, traceback)
        finally:
            # Reset even if a cleanup callback raised. Otherwise the instance
            # would reject every later __aenter__ as "nested" usage.
            self._active = False
|   | ||||
		Reference in New Issue
	
	Block a user
	 Joscha
					Joscha