mirror of
https://github.com/Garmelon/PFERD.git
synced 2023-12-21 10:23:01 +01:00
Implement reusable FileSinkToken for OutputDirectory
This commit is contained in:
parent
b7a999bc2e
commit
a7c025fd86
@ -3,19 +3,19 @@ import os
|
|||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
import string
|
import string
|
||||||
from contextlib import asynccontextmanager, contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path, PurePath
|
from pathlib import Path, PurePath
|
||||||
# TODO In Python 3.9 and above, AsyncContextManager is deprecated
|
# TODO In Python 3.9 and above, AsyncContextManager is deprecated
|
||||||
from typing import AsyncContextManager, AsyncIterator, BinaryIO, Iterator, Optional
|
from typing import AsyncContextManager, BinaryIO, Iterator, Optional, Tuple
|
||||||
|
|
||||||
from rich.markup import escape
|
from rich.markup import escape
|
||||||
|
|
||||||
from .logging import log
|
from .logging import log
|
||||||
from .report import MarkConflictException, MarkDuplicateException, Report
|
from .report import MarkConflictException, MarkDuplicateException, Report
|
||||||
from .utils import prompt_yes_no
|
from .utils import ReusableAsyncContextManager, prompt_yes_no
|
||||||
|
|
||||||
SUFFIX_CHARS = string.ascii_lowercase + string.digits
|
SUFFIX_CHARS = string.ascii_lowercase + string.digits
|
||||||
SUFFIX_LENGTH = 6
|
SUFFIX_LENGTH = 6
|
||||||
@ -87,6 +87,49 @@ class DownloadInfo:
|
|||||||
success: bool = False
|
success: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class FileSinkToken(ReusableAsyncContextManager[FileSink]):
    """
    Reusable token that hands out a FileSink for a single download target.

    Whenever this token is entered, a new temporary file is created and a
    FileSink wrapping it is returned. When the token is exited again, the file
    is closed and information about the download is handed back to the
    OutputDirectory.
    """

    def __init__(
            self,
            output_dir: "OutputDirectory",
            path: PurePath,
            local_path: Path,
            heuristics: Heuristics,
            on_conflict: OnConflict,
    ):
        """
        Remember everything needed to create the temporary file and to report
        the finished download back to output_dir later on.
        """

        super().__init__()

        self._on_conflict = on_conflict
        self._heuristics = heuristics
        self._local_path = local_path
        self._path = path
        self._output_dir = output_dir

    async def _on_aenter(self) -> FileSink:
        """
        Create the temporary file, register the cleanup steps and return the
        sink that writes into the file.
        """

        tmp_path, handle = await self._output_dir._create_tmp_file(self._local_path)
        sink = FileSink(handle)

        async def report_download() -> None:
            # sink.is_done() is evaluated here, i.e. only once the token is
            # being exited, not when the callback is registered.
            info = DownloadInfo(
                self._path,
                self._local_path,
                tmp_path,
                self._heuristics,
                self._on_conflict,
                sink.is_done(),
            )
            await self._output_dir._after_download(info)

        # The exit stack unwinds in reverse registration order: the file is
        # closed first, then the download is reported.
        self._stack.push_async_callback(report_download)
        self._stack.enter_context(handle)

        return sink
|
||||||
|
|
||||||
|
|
||||||
class OutputDirectory:
|
class OutputDirectory:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -111,11 +154,9 @@ class OutputDirectory:
|
|||||||
try:
|
try:
|
||||||
self._report.mark(path)
|
self._report.mark(path)
|
||||||
except MarkDuplicateException:
|
except MarkDuplicateException:
|
||||||
msg = "Another file has already been placed here."
|
raise OutputDirException("Another file has already been placed here.")
|
||||||
raise OutputDirException(msg)
|
|
||||||
except MarkConflictException as e:
|
except MarkConflictException as e:
|
||||||
msg = f"Collides with other file: {e.collides_with}"
|
raise OutputDirException(f"Collides with other file: {e.collides_with}")
|
||||||
raise OutputDirException(msg)
|
|
||||||
|
|
||||||
def resolve(self, path: PurePath) -> Path:
|
def resolve(self, path: PurePath) -> Path:
|
||||||
"""
|
"""
|
||||||
@ -123,8 +164,7 @@ class OutputDirectory:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if ".." in path.parts:
|
if ".." in path.parts:
|
||||||
msg = f"Path {path} contains forbidden '..'"
|
raise OutputDirException(f"Path {path} contains forbidden '..'")
|
||||||
raise OutputDirException(msg)
|
|
||||||
return self._root / path
|
return self._root / path
|
||||||
|
|
||||||
def _should_download(
|
def _should_download(
|
||||||
@ -137,6 +177,7 @@ class OutputDirectory:
|
|||||||
# since we know that the remote is different from the local files. This
|
# since we know that the remote is different from the local files. This
|
||||||
# includes the case where no local file exists.
|
# includes the case where no local file exists.
|
||||||
if not local_path.is_file():
|
if not local_path.is_file():
|
||||||
|
# TODO Don't download if on_conflict is LOCAL_FIRST or NO_DELETE
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if redownload == Redownload.NEVER:
|
if redownload == Redownload.NEVER:
|
||||||
@ -251,19 +292,24 @@ class OutputDirectory:
|
|||||||
name = f"{prefix}{base.name}.tmp.{suffix}"
|
name = f"{prefix}{base.name}.tmp.{suffix}"
|
||||||
return base.parent / name
|
return base.parent / name
|
||||||
|
|
||||||
@asynccontextmanager
|
async def _create_tmp_file(
|
||||||
async def _sink_context_manager(
|
|
||||||
self,
|
self,
|
||||||
file: BinaryIO,
|
local_path: Path,
|
||||||
info: DownloadInfo,
|
) -> Tuple[Path, BinaryIO]:
|
||||||
) -> AsyncIterator[FileSink]:
|
"""
|
||||||
sink = FileSink(file)
|
May raise an OutputDirException.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create tmp file
|
||||||
|
for attempt in range(TRIES):
|
||||||
|
suffix_length = SUFFIX_LENGTH + 2 * attempt
|
||||||
|
tmp_path = self._tmp_path(local_path, suffix_length)
|
||||||
try:
|
try:
|
||||||
with file:
|
return tmp_path, open(tmp_path, "xb")
|
||||||
yield sink
|
except FileExistsError:
|
||||||
finally:
|
pass # Try again
|
||||||
info.success = sink.is_done()
|
|
||||||
await self._after_download(info)
|
raise OutputDirException(f"Failed to create temporary file {tmp_path}")
|
||||||
|
|
||||||
async def download(
|
async def download(
|
||||||
self,
|
self,
|
||||||
@ -306,19 +352,7 @@ class OutputDirectory:
|
|||||||
# Ensure parent directory exists
|
# Ensure parent directory exists
|
||||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Create tmp file
|
return FileSinkToken(self, path, local_path, heuristics, on_conflict)
|
||||||
for attempt in range(TRIES):
|
|
||||||
suffix_length = SUFFIX_LENGTH + 2 * attempt
|
|
||||||
tmp_path = self._tmp_path(local_path, suffix_length)
|
|
||||||
info = DownloadInfo(path, local_path, tmp_path,
|
|
||||||
heuristics, on_conflict)
|
|
||||||
try:
|
|
||||||
file = open(tmp_path, "xb")
|
|
||||||
return self._sink_context_manager(file, info)
|
|
||||||
except FileExistsError:
|
|
||||||
pass # Try again
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _update_metadata(self, info: DownloadInfo) -> None:
|
def _update_metadata(self, info: DownloadInfo) -> None:
|
||||||
if mtime := info.heuristics.mtime:
|
if mtime := info.heuristics.mtime:
|
||||||
|
@ -2,7 +2,11 @@ import asyncio
|
|||||||
import contextvars
|
import contextvars
|
||||||
import functools
|
import functools
|
||||||
import getpass
|
import getpass
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
import sys
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any, Callable, Generic, Optional, Type, TypeVar
|
||||||
|
|
||||||
import bs4
|
import bs4
|
||||||
|
|
||||||
@ -56,3 +60,42 @@ async def prompt_yes_no(query: str, default: Optional[bool]) -> bool:
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
print("Please answer with 'y' or 'n'.")
|
print("Please answer with 'y' or 'n'.")
|
||||||
|
|
||||||
|
|
||||||
|
class ReusableAsyncContextManager(ABC, Generic[T]):
    """
    Base class for async context managers that may be entered multiple times,
    one after another, but never nested or concurrently.

    Subclasses implement _on_aenter(), which produces the value bound by
    "async with" and registers any cleanup on self._stack. The stack is
    unwound whenever the manager is exited, after which the manager can be
    entered again.
    """

    def __init__(self) -> None:
        # True between a call to __aenter__ and the matching __aexit__
        self._active = False
        # Collects the cleanup registered by _on_aenter()
        self._stack = AsyncExitStack()

    @abstractmethod
    async def _on_aenter(self) -> T:
        """
        Set up the managed resource and return the value for "async with".

        Cleanup should be registered on self._stack. It is run when the
        manager is exited and also if this method itself raises.
        """

        pass

    async def __aenter__(self) -> T:
        """
        Enter the manager and return the result of _on_aenter().

        Raises a RuntimeError if the manager is already active. If
        _on_aenter() raises, everything registered on the stack so far is
        cleaned up and the exception is re-raised.
        """

        if self._active:
            raise RuntimeError("Nested or otherwise concurrent usage is not allowed")

        self._active = True
        await self._stack.__aenter__()

        # See https://stackoverflow.com/a/13075071
        try:
            return await self._on_aenter()
        except BaseException:
            # Run the cleanup registered so far. The exception is re-raised
            # even if a registered exit handler asks for it to be suppressed:
            # there is no value that could be returned to the caller, and
            # falling through would only fail with an UnboundLocalError that
            # hides the real cause.
            await self.__aexit__(*sys.exc_info())
            raise

    async def __aexit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_value: Optional[BaseException],
            traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        """
        Unwind the exit stack and mark the manager as inactive again.

        Returns whatever the exit stack returns, i.e. whether the exception
        passed in was suppressed. Raises a RuntimeError if the manager is not
        currently active.
        """

        if not self._active:
            raise RuntimeError("__aexit__ called too many times")

        try:
            return await self._stack.__aexit__(exc_type, exc_value, traceback)
        finally:
            # Reset the flag even if a cleanup callback raised. Otherwise the
            # manager would be stuck in the active state and could never be
            # entered again.
            self._active = False
|
||||||
|
Loading…
Reference in New Issue
Block a user