diff --git a/PFERD/crawler.py b/PFERD/crawler.py index f5286b8..42f66a3 100644 --- a/PFERD/crawler.py +++ b/PFERD/crawler.py @@ -1,10 +1,8 @@ import asyncio from abc import ABC, abstractmethod -from contextlib import asynccontextmanager from datetime import datetime from pathlib import Path, PurePath -# TODO In Python 3.9 and above, AsyncContextManager is deprecated -from typing import Any, AsyncContextManager, AsyncIterator, Awaitable, Callable, Dict, Optional, TypeVar +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, TypeVar import aiohttp from rich.markup import escape @@ -13,9 +11,10 @@ from .authenticator import Authenticator from .config import Config, Section from .limiter import Limiter from .logging import ProgressBar, log -from .output_dir import FileSink, OnConflict, OutputDirectory, OutputDirError, Redownload +from .output_dir import FileSink, FileSinkToken, OnConflict, OutputDirectory, OutputDirError, Redownload from .report import MarkConflictError, MarkDuplicateError from .transformer import Transformer +from .utils import ReusableAsyncContextManager from .version import NAME, VERSION @@ -88,6 +87,36 @@ def anoncritical(f: AWrapped) -> AWrapped: return wrapper # type: ignore +class CrawlToken(ReusableAsyncContextManager[ProgressBar]): + def __init__(self, limiter: Limiter, desc: str): + super().__init__() + + self._limiter = limiter + self._desc = desc + + async def _on_aenter(self) -> ProgressBar: + await self._stack.enter_async_context(self._limiter.limit_crawl()) + bar = self._stack.enter_context(log.crawl_bar(self._desc)) + + return bar + + +class DownloadToken(ReusableAsyncContextManager[Tuple[ProgressBar, FileSink]]): + def __init__(self, limiter: Limiter, fs_token: FileSinkToken, desc: str): + super().__init__() + + self._limiter = limiter + self._fs_token = fs_token + self._desc = desc + + async def _on_aenter(self) -> Tuple[ProgressBar, FileSink]: + await self._stack.enter_async_context(self._limiter.limit_crawl()) + sink = await self._stack.enter_async_context(self._fs_token) + bar = self._stack.enter_context(log.crawl_bar(self._desc)) + + return bar, sink + + class CrawlerSection(Section): def output_dir(self, name: str) -> Path: # TODO Use removeprefix() after switching to 3.9 @@ -190,30 +219,12 @@ class Crawler(ABC): section.on_conflict(), ) - @asynccontextmanager - async def crawl_bar( - self, - path: PurePath, - total: Optional[int] = None, - ) -> AsyncIterator[ProgressBar]: + async def crawl(self, path: PurePath) -> Optional[CrawlToken]: + if self._transformer.transform(path) is None: + return None + desc = f"[bold bright_cyan]Crawling[/] {escape(str(path))}" - async with self._limiter.limit_crawl(): - with log.crawl_bar(desc, total=total) as bar: - yield bar - - @asynccontextmanager - async def download_bar( - self, - path: PurePath, - total: Optional[int] = None, - ) -> AsyncIterator[ProgressBar]: - desc = f"[bold bright_cyan]Downloading[/] {escape(str(path))}" - async with self._limiter.limit_download(): - with log.download_bar(desc, total=total) as bar: - yield bar - - def should_crawl(self, path: PurePath) -> bool: - return self._transformer.transform(path) is not None + return CrawlToken(self._limiter, desc) async def download( self, @@ -221,13 +232,17 @@ class Crawler(ABC): mtime: Optional[datetime] = None, redownload: Optional[Redownload] = None, on_conflict: Optional[OnConflict] = None, - ) -> Optional[AsyncContextManager[FileSink]]: + ) -> Optional[DownloadToken]: transformed_path = self._transformer.transform(path) if transformed_path is None: return None - return await self._output_dir.download( - transformed_path, mtime, redownload, on_conflict) + fs_token = await self._output_dir.download(transformed_path, mtime, redownload, on_conflict) + if fs_token is None: + return None + + desc = f"[bold bright_cyan]Downloading[/] {escape(str(path))}" + return DownloadToken(self._limiter, fs_token, desc) async def cleanup(self) -> None: await self._output_dir.cleanup() @@ -239,10 +254,10 @@ class Crawler(ABC): """ with log.show_progress(): - await self.crawl() + await self._run() @abstractmethod - async def crawl(self) -> None: + async def _run(self) -> None: """ Overwrite this function if you are writing a crawler. diff --git a/PFERD/output_dir.py b/PFERD/output_dir.py index ee4910e..fef6914 100644 --- a/PFERD/output_dir.py +++ b/PFERD/output_dir.py @@ -8,8 +8,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum from pathlib import Path, PurePath -# TODO In Python 3.9 and above, AsyncContextManager is deprecated -from typing import AsyncContextManager, BinaryIO, Iterator, Optional, Tuple +from typing import BinaryIO, Iterator, Optional, Tuple from rich.markup import escape @@ -307,7 +306,7 @@ class OutputDirectory: mtime: Optional[datetime] = None, redownload: Optional[Redownload] = None, on_conflict: Optional[OnConflict] = None, - ) -> Optional[AsyncContextManager[FileSink]]: + ) -> Optional[FileSinkToken]: """ May throw an OutputDirError, a MarkDuplicateError or a MarkConflictError.