mirror of
https://github.com/Garmelon/PFERD.git
synced 2023-12-21 10:23:01 +01:00
366 lines
11 KiB
Python
366 lines
11 KiB
Python
|
import filecmp
|
||
|
import os
|
||
|
import random
|
||
|
import shutil
|
||
|
import string
|
||
|
from contextlib import asynccontextmanager
|
||
|
from dataclasses import dataclass
|
||
|
from datetime import datetime
|
||
|
from enum import Enum
|
||
|
from pathlib import Path, PurePath
|
||
|
# TODO In Python 3.9 and above, AsyncContextManager is deprecated
|
||
|
from typing import AsyncContextManager, AsyncIterator, BinaryIO, Optional
|
||
|
|
||
|
from .conductor import TerminalConductor
|
||
|
from .report import MarkConflictException, MarkDuplicateException, Report
|
||
|
from .utils import prompt_yes_no
|
||
|
|
||
|
# Alphabet from which the random suffix of a temporary file name is drawn.
SUFFIX_CHARS = "".join((string.ascii_lowercase, string.digits))

# Suffix length used for the first temporary-file attempt; every further
# attempt in `OutputDirectory.download` uses two more characters.
SUFFIX_LENGTH = 6

# How many differently named temporary files `OutputDirectory.download` tries
# to create before it gives up.
TRIES = 5
|
||
|
|
||
|
|
||
|
class OutputDirException(Exception):
    """
    Raised when a path cannot be used inside the output directory, for
    example because it contains '..' or collides with a path that has
    already been marked.
    """
|
||
|
|
||
|
|
||
|
class Redownload(Enum):
    """
    Policy deciding whether a file that already exists locally is downloaded
    again.

    The "smart" variants consult the remote modification time when one is
    known; they differ in what happens when it is unknown.
    """

    # Never download a file that already exists locally.
    NEVER = "never"
    # Like NEVER when the remote modification time is unknown.
    NEVER_SMART = "never-smart"
    # Always download, even if the file already exists locally.
    ALWAYS = "always"
    # Like ALWAYS when the remote modification time is unknown.
    ALWAYS_SMART = "always-smart"
|
||
|
|
||
|
|
||
|
class OnConflict(Enum):
    """
    Policy deciding what happens when local and remote state disagree, as
    interpreted by the `_conflict_*` helpers of `OutputDirectory`.
    """

    # Ask the user on the terminal; the default answer keeps the local state.
    PROMPT = "prompt"
    # Always keep the local files and directories.
    LOCAL_FIRST = "local-first"
    # Always let the remote state win, deleting local state if necessary.
    REMOTE_FIRST = "remote-first"
    # Replace a local file with a remote file, but never delete anything
    # else to make room for remote state.
    NO_DELETE = "no-delete"
|
||
|
|
||
|
|
||
|
@dataclass
class Heuristics:
    """
    Optional knowledge about a remote file that helps decide whether it has
    to be downloaded again.
    """

    # Modification time of the remote file, or None if it is not known.
    mtime: Optional[datetime]
|
||
|
|
||
|
|
||
|
class FileSink:
    """
    Pairs an open binary file with a flag recording whether the writer
    declared the download complete.

    The writer writes into `file` and calls `done()` once everything has been
    written; whoever handed out the sink can then query `is_done()`.
    """

    def __init__(self, file: BinaryIO):
        # A fresh sink is not complete until done() is called explicitly.
        self._done = False
        self._file = file

    @property
    def file(self) -> BinaryIO:
        """The binary file object the download should be written to."""
        return self._file

    def done(self) -> None:
        """Mark the download as complete."""
        self._done = True

    def is_done(self) -> bool:
        """Return True once done() has been called."""
        return self._done
|
||
|
|
||
|
|
||
|
@dataclass
class DownloadInfo:
    """
    Bookkeeping for a single download in progress, passed from
    `OutputDirectory.download` to the post-download step.
    """

    # Path of the file relative to the output directory root.
    path: PurePath
    # Final location of the file on disk.
    local_path: Path
    # Temporary file the download is written to first.
    tmp_path: Path
    # Remote metadata (currently only the modification time).
    heuristics: Heuristics
    # Conflict policy in effect for this download.
    on_conflict: OnConflict
    # Set once the sink has been closed; True if the writer called done().
    success: bool = False
|
||
|
|
||
|
|
||
|
class OutputDirectory:
    """
    Manages the files placed below a single root directory.

    Downloads are written to a temporary file next to their final location
    and only moved into place once the writer has declared them complete.
    Every requested path is marked in a `Report`, which `cleanup()` later
    consults to find local files that were not requested.
    """

    def __init__(
            self,
            root: Path,
            redownload: Redownload,
            on_conflict: OnConflict,
            conductor: TerminalConductor,
    ):
        """
        root: directory below which all files are placed.
        redownload: default redownload policy for `download()`.
        on_conflict: default conflict policy for `download()` and the policy
            used by `cleanup()`.
        conductor: used to obtain exclusive terminal output while prompting.
        """

        self._root = root
        self._redownload = redownload
        self._on_conflict = on_conflict
        self._conductor = conductor

        self._report = Report()

    def _mark(self, path: PurePath) -> None:
        """
        Mark the path in the report.

        May throw an OutputDirException
        """

        try:
            self._report.mark(path)
        except MarkDuplicateException as e:
            msg = "Another file has already been placed here."
            raise OutputDirException(msg) from e
        except MarkConflictException as e:
            msg = f"Collides with other file: {e.collides_with}"
            raise OutputDirException(msg) from e

    def _resolve(self, path: PurePath) -> Path:
        """
        Turn a path relative to the output directory into a path below the
        root directory.

        May throw an OutputDirException.
        """

        if ".." in path.parts:
            msg = f"Path {path} contains forbidden '..'"
            raise OutputDirException(msg)
        return self._root / path

    def _should_download(
            self,
            local_path: Path,
            heuristics: Heuristics,
            redownload: Redownload,
    ) -> bool:
        """
        Decide whether the remote file has to be downloaded.

        Returns True if the download should happen and False if the existing
        local file should be kept untouched. Raises a ValueError if the
        redownload policy is not one of the known policies.
        """

        # If we don't have a *file* at the local path, we'll always redownload
        # since we know that the remote is different from the local files. This
        # includes the case where no local file exists.
        if not local_path.is_file():
            return True

        if redownload == Redownload.NEVER:
            return False
        elif redownload == Redownload.ALWAYS:
            return True

        stat = local_path.stat()

        # None means "unknown" because the remote did not tell us its mtime.
        remote_newer = None
        if mtime := heuristics.mtime:
            remote_newer = mtime.timestamp() > stat.st_mtime

        # Both smart policies download exactly when the remote is known to be
        # newer. They only differ in the fallback used when that is unknown.
        if redownload == Redownload.NEVER_SMART:
            return False if remote_newer is None else remote_newer
        elif redownload == Redownload.ALWAYS_SMART:
            return True if remote_newer is None else remote_newer

        # This should never be reached
        raise ValueError(f"{redownload!r} is not a valid redownload policy")

    # The following conflict resolution functions all return False if the local
    # file(s) should be kept and True if they should be replaced by the remote
    # files.

    async def _decide_conflict(
            self,
            on_conflict: OnConflict,
            prompt: str,
            no_delete: bool,
    ) -> bool:
        """
        Shared implementation of the conflict resolution functions.

        prompt: question shown to the user under the PROMPT policy.
        no_delete: answer to give under the NO_DELETE policy.

        Raises a ValueError if the conflict policy is not a known policy.
        """

        if on_conflict == OnConflict.PROMPT:
            async with self._conductor.exclusive_output():
                return await prompt_yes_no(prompt, default=False)
        elif on_conflict == OnConflict.LOCAL_FIRST:
            return False
        elif on_conflict == OnConflict.REMOTE_FIRST:
            return True
        elif on_conflict == OnConflict.NO_DELETE:
            return no_delete

        # This should never be reached
        raise ValueError(f"{on_conflict!r} is not a valid conflict policy")

    async def _conflict_lfrf(
            self,
            on_conflict: OnConflict,
            path: PurePath,
    ) -> bool:
        """Local file vs. remote file at the same path."""

        prompt = f"Replace {path} with remote file?"
        return await self._decide_conflict(on_conflict, prompt, True)

    async def _conflict_ldrf(
            self,
            on_conflict: OnConflict,
            path: PurePath,
    ) -> bool:
        """Local directory where the remote has a file."""

        prompt = f"Recursively delete {path} and replace with remote file?"
        return await self._decide_conflict(on_conflict, prompt, False)

    async def _conflict_lfrd(
            self,
            on_conflict: OnConflict,
            path: PurePath,
            parent: PurePath,
    ) -> bool:
        """Local file where the remote file needs a parent directory."""

        prompt = f"Delete {parent} so remote file {path} can be downloaded?"
        return await self._decide_conflict(on_conflict, prompt, False)

    async def _conflict_delete_lf(
            self,
            on_conflict: OnConflict,
            path: PurePath,
    ) -> bool:
        """Local file that has not been marked and could be deleted."""

        prompt = f"Delete {path}?"
        return await self._decide_conflict(on_conflict, prompt, False)

    def _tmp_path(self, base: Path, suffix_length: int) -> Path:
        """
        Build the path of a hidden temporary file next to `base`, ending in
        a random suffix of `suffix_length` characters.
        """

        prefix = "" if base.name.startswith(".") else "."
        # random.choices returns a list of characters, so it has to be joined
        # before being put into the file name.
        suffix = "".join(random.choices(SUFFIX_CHARS, k=suffix_length))
        name = f"{prefix}{base.name}.tmp.{suffix}"
        return base.parent / name

    @asynccontextmanager
    async def _sink_context_manager(
            self,
            file: BinaryIO,
            info: DownloadInfo,
    ) -> AsyncIterator[FileSink]:
        """
        Hand out a FileSink for the open temporary file. When the context is
        left, the file is closed and `_after_download` runs, whether or not
        the body raised.
        """

        sink = FileSink(file)
        try:
            with file:
                yield sink
        finally:
            info.success = sink.is_done()
            await self._after_download(info)

    async def download(
            self,
            path: PurePath,
            mtime: Optional[datetime] = None,
            redownload: Optional[Redownload] = None,
            on_conflict: Optional[OnConflict] = None,
    ) -> Optional[AsyncContextManager[FileSink]]:
        """
        Prepare the download of a file to `path` (relative to the root).

        `redownload` and `on_conflict` override the policies given to the
        constructor when they are not None.

        Returns None if nothing should be downloaded: the existing local
        file is kept, a conflict was decided in favor of the local state, or
        no temporary file could be created. Otherwise returns an async
        context manager yielding the FileSink to write the file into.

        May throw an OutputDirException.
        """

        heuristics = Heuristics(mtime)
        redownload = self._redownload if redownload is None else redownload
        on_conflict = self._on_conflict if on_conflict is None else on_conflict
        local_path = self._resolve(path)

        self._mark(path)

        if not self._should_download(local_path, heuristics, redownload):
            return None

        # Detect and solve local-dir-remote-file conflict
        if local_path.is_dir():
            if await self._conflict_ldrf(on_conflict, path):
                shutil.rmtree(local_path)
            else:
                return None

        # Detect and solve local-file-remote-dir conflict
        for parent in path.parents:
            local_parent = self._resolve(parent)
            if local_parent.exists() and not local_parent.is_dir():
                if await self._conflict_lfrd(on_conflict, path, parent):
                    local_parent.unlink()
                    break
                else:
                    return None

        # Ensure parent directory exists
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # Create tmp file
        for attempt in range(TRIES):
            suffix_length = SUFFIX_LENGTH + 2 * attempt
            tmp_path = self._tmp_path(local_path, suffix_length)
            try:
                # "x" makes the open fail if the file already exists
                file = open(tmp_path, "xb")
            except FileExistsError:
                continue  # Try again with a longer suffix

            info = DownloadInfo(path, local_path, tmp_path,
                                heuristics, on_conflict)
            return self._sink_context_manager(file, info)

        return None

    async def _after_download(self, info: DownloadInfo) -> None:
        """
        Finish a download: discard the temporary file if the download failed,
        is identical to the local file or loses the conflict resolution;
        otherwise move it into place and record it in the report.
        """

        changed = False

        if not info.success:
            info.tmp_path.unlink()
            return

        # Solve conflicts arising from existing local file
        if info.local_path.exists():
            changed = True
            if filecmp.cmp(info.local_path, info.tmp_path):
                info.tmp_path.unlink()
                return

            if not await self._conflict_lfrf(info.on_conflict, info.path):
                info.tmp_path.unlink()
                return

        # Modify metadata if necessary
        if mtime := info.heuristics.mtime:
            # TODO Pick an implementation
            # Rounding up to avoid inaccuracies in how the OS stores timestamps
            # mtimestamp = math.ceil(mtime.timestamp())
            mtimestamp = mtime.timestamp()
            os.utime(info.tmp_path, times=(mtimestamp, mtimestamp))

        info.tmp_path.replace(info.local_path)

        if changed:
            self._report.change_file(info.path)
        else:
            self._report.add_file(info.path)

    def cleanup(self) -> None:
        """
        Walk the root directory, delete unmarked files where the conflict
        policy allows it and remove directories that end up empty.
        """

        self._cleanup_dir(self._root, PurePath())

    def _cleanup(self, path: Path, pure: PurePath) -> None:
        """Dispatch to the directory or file cleanup; ignore anything else."""

        if path.is_dir():
            self._cleanup_dir(path, pure)
        elif path.is_file():
            self._cleanup_file(path, pure)

    def _cleanup_dir(self, path: Path, pure: PurePath) -> None:
        """Clean up all children, then remove the directory if possible."""

        for child in path.iterdir():
            pure_child = pure / child.name
            self._cleanup(child, pure_child)

        # Best effort: rmdir fails with an OSError if the directory still has
        # content, in which case it is simply kept.
        try:
            path.rmdir()
        except OSError:
            pass

    def _cleanup_file(self, path: Path, pure: PurePath) -> None:
        """Delete a single file unless it is marked or the policy forbids it."""

        if self._report.marked(pure):
            return

        # `_conflict_delete_lf` is a coroutine function: calling it without
        # awaiting it yields a coroutine object, which is always truthy and
        # would delete the file under every policy. Since cleanup is
        # synchronous, the policy is evaluated directly here instead. Only
        # REMOTE_FIRST answers "delete" without asking.
        # NOTE(review): PROMPT cannot ask the user from synchronous code, so
        # the file is kept, matching the prompt's default answer. Asking would
        # require making the cleanup functions async.
        if self._on_conflict != OnConflict.REMOTE_FIRST:
            return

        # Best effort: a file that cannot be deleted is kept and not reported.
        try:
            path.unlink()
            self._report.delete_file(pure)
        except OSError:
            pass
|