Implement reusable FileSinkToken for OutputDirectory

This commit is contained in:
Joscha 2021-05-19 17:16:23 +02:00
parent b7a999bc2e
commit a7c025fd86
2 changed files with 112 additions and 35 deletions

View File

@ -3,19 +3,19 @@ import os
import random import random
import shutil import shutil
import string import string
from contextlib import asynccontextmanager, contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from pathlib import Path, PurePath from pathlib import Path, PurePath
# TODO In Python 3.9 and above, AsyncContextManager is deprecated # TODO In Python 3.9 and above, AsyncContextManager is deprecated
from typing import AsyncContextManager, AsyncIterator, BinaryIO, Iterator, Optional from typing import AsyncContextManager, BinaryIO, Iterator, Optional, Tuple
from rich.markup import escape from rich.markup import escape
from .logging import log from .logging import log
from .report import MarkConflictException, MarkDuplicateException, Report from .report import MarkConflictException, MarkDuplicateException, Report
from .utils import prompt_yes_no from .utils import ReusableAsyncContextManager, prompt_yes_no
SUFFIX_CHARS = string.ascii_lowercase + string.digits SUFFIX_CHARS = string.ascii_lowercase + string.digits
SUFFIX_LENGTH = 6 SUFFIX_LENGTH = 6
@ -87,6 +87,49 @@ class DownloadInfo:
success: bool = False success: bool = False
class FileSinkToken(ReusableAsyncContextManager[FileSink]):
# Whenever this class is entered, it creates a new temporary file and
# returns a corresponding FileSink.
#
# When it is exited again, the file is closed and information about the
# download handed back to the OutputDirectory.
def __init__(
self,
output_dir: "OutputDirectory",
path: PurePath,
local_path: Path,
heuristics: Heuristics,
on_conflict: OnConflict,
):
super().__init__()
self._output_dir = output_dir
self._path = path
self._local_path = local_path
self._heuristics = heuristics
self._on_conflict = on_conflict
async def _on_aenter(self) -> FileSink:
tmp_path, file = await self._output_dir._create_tmp_file(self._local_path)
sink = FileSink(file)
async def after_download() -> None:
await self._output_dir._after_download(DownloadInfo(
self._path,
self._local_path,
tmp_path,
self._heuristics,
self._on_conflict,
sink.is_done(),
))
self._stack.push_async_callback(after_download)
self._stack.enter_context(file)
return sink
class OutputDirectory: class OutputDirectory:
def __init__( def __init__(
self, self,
@ -111,11 +154,9 @@ class OutputDirectory:
try: try:
self._report.mark(path) self._report.mark(path)
except MarkDuplicateException: except MarkDuplicateException:
msg = "Another file has already been placed here." raise OutputDirException("Another file has already been placed here.")
raise OutputDirException(msg)
except MarkConflictException as e: except MarkConflictException as e:
msg = f"Collides with other file: {e.collides_with}" raise OutputDirException(f"Collides with other file: {e.collides_with}")
raise OutputDirException(msg)
def resolve(self, path: PurePath) -> Path: def resolve(self, path: PurePath) -> Path:
""" """
@ -123,8 +164,7 @@ class OutputDirectory:
""" """
if ".." in path.parts: if ".." in path.parts:
msg = f"Path {path} contains forbidden '..'" raise OutputDirException(f"Path {path} contains forbidden '..'")
raise OutputDirException(msg)
return self._root / path return self._root / path
def _should_download( def _should_download(
@ -137,6 +177,7 @@ class OutputDirectory:
# since we know that the remote is different from the local files. This # since we know that the remote is different from the local files. This
# includes the case where no local file exists. # includes the case where no local file exists.
if not local_path.is_file(): if not local_path.is_file():
# TODO Don't download if on_conflict is LOCAL_FIRST or NO_DELETE
return True return True
if redownload == Redownload.NEVER: if redownload == Redownload.NEVER:
@ -251,19 +292,24 @@ class OutputDirectory:
name = f"{prefix}{base.name}.tmp.{suffix}" name = f"{prefix}{base.name}.tmp.{suffix}"
return base.parent / name return base.parent / name
@asynccontextmanager async def _create_tmp_file(
async def _sink_context_manager(
self, self,
file: BinaryIO, local_path: Path,
info: DownloadInfo, ) -> Tuple[Path, BinaryIO]:
) -> AsyncIterator[FileSink]: """
sink = FileSink(file) May raise an OutputDirException.
try: """
with file:
yield sink # Create tmp file
finally: for attempt in range(TRIES):
info.success = sink.is_done() suffix_length = SUFFIX_LENGTH + 2 * attempt
await self._after_download(info) tmp_path = self._tmp_path(local_path, suffix_length)
try:
return tmp_path, open(tmp_path, "xb")
except FileExistsError:
pass # Try again
raise OutputDirException(f"Failed to create temporary file {tmp_path}")
async def download( async def download(
self, self,
@ -306,19 +352,7 @@ class OutputDirectory:
# Ensure parent directory exists # Ensure parent directory exists
local_path.parent.mkdir(parents=True, exist_ok=True) local_path.parent.mkdir(parents=True, exist_ok=True)
# Create tmp file return FileSinkToken(self, path, local_path, heuristics, on_conflict)
for attempt in range(TRIES):
suffix_length = SUFFIX_LENGTH + 2 * attempt
tmp_path = self._tmp_path(local_path, suffix_length)
info = DownloadInfo(path, local_path, tmp_path,
heuristics, on_conflict)
try:
file = open(tmp_path, "xb")
return self._sink_context_manager(file, info)
except FileExistsError:
pass # Try again
return None
def _update_metadata(self, info: DownloadInfo) -> None: def _update_metadata(self, info: DownloadInfo) -> None:
if mtime := info.heuristics.mtime: if mtime := info.heuristics.mtime:

View File

@ -2,7 +2,11 @@ import asyncio
import contextvars import contextvars
import functools import functools
import getpass import getpass
from typing import Any, Callable, Optional, TypeVar import sys
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, Callable, Generic, Optional, Type, TypeVar
import bs4 import bs4
@ -56,3 +60,42 @@ async def prompt_yes_no(query: str, default: Optional[bool]) -> bool:
return default return default
print("Please answer with 'y' or 'n'.") print("Please answer with 'y' or 'n'.")
class ReusableAsyncContextManager(ABC, Generic[T]):
def __init__(self) -> None:
self._active = False
self._stack = AsyncExitStack()
@abstractmethod
async def _on_aenter(self) -> T:
pass
async def __aenter__(self) -> T:
if self._active:
raise RuntimeError("Nested or otherwise concurrent usage is not allowed")
self._active = True
await self._stack.__aenter__()
# See https://stackoverflow.com/a/13075071
try:
result: T = await self._on_aenter()
except: # noqa: E722 do not use bare 'except'
if not await self.__aexit__(*sys.exc_info()):
raise
return result
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
if not self._active:
raise RuntimeError("__aexit__ called too many times")
result = await self._stack.__aexit__(exc_type, exc_value, traceback)
self._active = False
return result