Implement reusable FileSinkToken for OutputDirectory

This commit is contained in:
Joscha 2021-05-19 17:16:23 +02:00
parent b7a999bc2e
commit a7c025fd86
2 changed files with 112 additions and 35 deletions

View File

@ -3,19 +3,19 @@ import os
import random
import shutil
import string
from contextlib import asynccontextmanager, contextmanager
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path, PurePath
# TODO In Python 3.9 and above, AsyncContextManager is deprecated
from typing import AsyncContextManager, AsyncIterator, BinaryIO, Iterator, Optional
from typing import AsyncContextManager, BinaryIO, Iterator, Optional, Tuple
from rich.markup import escape
from .logging import log
from .report import MarkConflictException, MarkDuplicateException, Report
from .utils import prompt_yes_no
from .utils import ReusableAsyncContextManager, prompt_yes_no
SUFFIX_CHARS = string.ascii_lowercase + string.digits
SUFFIX_LENGTH = 6
@ -87,6 +87,49 @@ class DownloadInfo:
success: bool = False
class FileSinkToken(ReusableAsyncContextManager[FileSink]):
# Whenever this class is entered, it creates a new temporary file and
# returns a corresponding FileSink.
#
# When it is exited again, the file is closed and information about the
# download handed back to the OutputDirectory.
def __init__(
self,
output_dir: "OutputDirectory",
path: PurePath,
local_path: Path,
heuristics: Heuristics,
on_conflict: OnConflict,
):
super().__init__()
self._output_dir = output_dir
self._path = path
self._local_path = local_path
self._heuristics = heuristics
self._on_conflict = on_conflict
async def _on_aenter(self) -> FileSink:
tmp_path, file = await self._output_dir._create_tmp_file(self._local_path)
sink = FileSink(file)
async def after_download() -> None:
await self._output_dir._after_download(DownloadInfo(
self._path,
self._local_path,
tmp_path,
self._heuristics,
self._on_conflict,
sink.is_done(),
))
self._stack.push_async_callback(after_download)
self._stack.enter_context(file)
return sink
class OutputDirectory:
def __init__(
self,
@ -111,11 +154,9 @@ class OutputDirectory:
try:
self._report.mark(path)
except MarkDuplicateException:
msg = "Another file has already been placed here."
raise OutputDirException(msg)
raise OutputDirException("Another file has already been placed here.")
except MarkConflictException as e:
msg = f"Collides with other file: {e.collides_with}"
raise OutputDirException(msg)
raise OutputDirException(f"Collides with other file: {e.collides_with}")
def resolve(self, path: PurePath) -> Path:
"""
@ -123,8 +164,7 @@ class OutputDirectory:
"""
if ".." in path.parts:
msg = f"Path {path} contains forbidden '..'"
raise OutputDirException(msg)
raise OutputDirException(f"Path {path} contains forbidden '..'")
return self._root / path
def _should_download(
@ -137,6 +177,7 @@ class OutputDirectory:
# since we know that the remote is different from the local files. This
# includes the case where no local file exists.
if not local_path.is_file():
# TODO Don't download if on_conflict is LOCAL_FIRST or NO_DELETE
return True
if redownload == Redownload.NEVER:
@ -251,19 +292,24 @@ class OutputDirectory:
name = f"{prefix}{base.name}.tmp.{suffix}"
return base.parent / name
@asynccontextmanager
async def _sink_context_manager(
async def _create_tmp_file(
self,
file: BinaryIO,
info: DownloadInfo,
) -> AsyncIterator[FileSink]:
sink = FileSink(file)
local_path: Path,
) -> Tuple[Path, BinaryIO]:
"""
May raise an OutputDirException.
"""
# Create tmp file
for attempt in range(TRIES):
suffix_length = SUFFIX_LENGTH + 2 * attempt
tmp_path = self._tmp_path(local_path, suffix_length)
try:
with file:
yield sink
finally:
info.success = sink.is_done()
await self._after_download(info)
return tmp_path, open(tmp_path, "xb")
except FileExistsError:
pass # Try again
raise OutputDirException(f"Failed to create temporary file {tmp_path}")
async def download(
self,
@ -306,19 +352,7 @@ class OutputDirectory:
# Ensure parent directory exists
local_path.parent.mkdir(parents=True, exist_ok=True)
# Create tmp file
for attempt in range(TRIES):
suffix_length = SUFFIX_LENGTH + 2 * attempt
tmp_path = self._tmp_path(local_path, suffix_length)
info = DownloadInfo(path, local_path, tmp_path,
heuristics, on_conflict)
try:
file = open(tmp_path, "xb")
return self._sink_context_manager(file, info)
except FileExistsError:
pass # Try again
return None
return FileSinkToken(self, path, local_path, heuristics, on_conflict)
def _update_metadata(self, info: DownloadInfo) -> None:
if mtime := info.heuristics.mtime:

View File

@ -2,7 +2,11 @@ import asyncio
import contextvars
import functools
import getpass
from typing import Any, Callable, Optional, TypeVar
import sys
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, Callable, Generic, Optional, Type, TypeVar
import bs4
@ -56,3 +60,42 @@ async def prompt_yes_no(query: str, default: Optional[bool]) -> bool:
return default
print("Please answer with 'y' or 'n'.")
class ReusableAsyncContextManager(ABC, Generic[T]):
def __init__(self) -> None:
self._active = False
self._stack = AsyncExitStack()
@abstractmethod
async def _on_aenter(self) -> T:
pass
async def __aenter__(self) -> T:
if self._active:
raise RuntimeError("Nested or otherwise concurrent usage is not allowed")
self._active = True
await self._stack.__aenter__()
# See https://stackoverflow.com/a/13075071
try:
result: T = await self._on_aenter()
except: # noqa: E722 do not use bare 'except'
if not await self.__aexit__(*sys.exc_info()):
raise
return result
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
if not self._active:
raise RuntimeError("__aexit__ called too many times")
result = await self._stack.__aexit__(exc_type, exc_value, traceback)
self._active = False
return result