diff --git a/PFERD/conductor.py b/PFERD/conductor.py index fef5a0e..121ed9a 100644 --- a/PFERD/conductor.py +++ b/PFERD/conductor.py @@ -1,6 +1,7 @@ import asyncio from contextlib import asynccontextmanager, contextmanager -from typing import AsyncIterator, Iterator, List, Optional +from types import TracebackType +from typing import AsyncIterator, Iterator, List, Optional, Type import rich from rich.progress import Progress, TaskID @@ -22,24 +23,30 @@ class TerminalConductor: self._progress = Progress() self._lines: List[str] = [] - def _start(self) -> None: - for line in self._lines: - rich.print(line) - self._lines = [] - - self._progress.start() - - def _stop(self) -> None: - self._progress.stop() - self._stopped = True - - async def start(self) -> None: + async def _start(self) -> None: async with self._lock: - self._start() + for line in self._lines: + rich.print(line) + self._lines = [] - async def stop(self) -> None: + self._progress.start() + + async def _stop(self) -> None: async with self._lock: - self._stop() + self._progress.stop() + self._stopped = True + + async def __aenter__(self) -> None: + await self._start() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: + await self._stop() + return None def print(self, line: str) -> None: if self._stopped: @@ -50,11 +57,11 @@ class TerminalConductor: @asynccontextmanager async def exclusive_output(self) -> AsyncIterator[None]: async with self._lock: - self.stop() + self._stop() try: yield finally: - self.start() + self._start() @contextmanager def progress_bar( diff --git a/PFERD/crawler.py b/PFERD/crawler.py index 31aab5b..093ba91 100644 --- a/PFERD/crawler.py +++ b/PFERD/crawler.py @@ -63,11 +63,8 @@ class Crawler(ABC): return self.progress_bar(desc, total=size) async def run(self) -> None: - await self._conductor.start() - try: + async with self._conductor: await self.crawl() - finally: - await self._conductor.stop() @abstractmethod async def crawl(self) -> None: