Use conductor via context manager

This commit is contained in:
Joscha 2021-04-29 14:23:28 +02:00
parent 6431a3fb3d
commit 2e85d26b6b
2 changed files with 26 additions and 22 deletions

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from typing import AsyncIterator, Iterator, List, Optional from types import TracebackType
from typing import AsyncIterator, Iterator, List, Optional, Type
import rich import rich
from rich.progress import Progress, TaskID from rich.progress import Progress, TaskID
@ -22,24 +23,30 @@ class TerminalConductor:
self._progress = Progress() self._progress = Progress()
self._lines: List[str] = [] self._lines: List[str] = []
def _start(self) -> None: async def _start(self) -> None:
for line in self._lines:
rich.print(line)
self._lines = []
self._progress.start()
def _stop(self) -> None:
self._progress.stop()
self._stopped = True
async def start(self) -> None:
async with self._lock: async with self._lock:
self._start() for line in self._lines:
rich.print(line)
self._lines = []
async def stop(self) -> None: self._progress.start()
async def _stop(self) -> None:
async with self._lock: async with self._lock:
self._stop() self._progress.stop()
self._stopped = True
async def __aenter__(self) -> None:
await self._start()
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
await self._stop()
return None
def print(self, line: str) -> None: def print(self, line: str) -> None:
if self._stopped: if self._stopped:
@ -50,11 +57,11 @@ class TerminalConductor:
@asynccontextmanager @asynccontextmanager
async def exclusive_output(self) -> AsyncIterator[None]: async def exclusive_output(self) -> AsyncIterator[None]:
async with self._lock: async with self._lock:
self.stop() self._stop()
try: try:
yield yield
finally: finally:
self.start() self._start()
@contextmanager @contextmanager
def progress_bar( def progress_bar(

View File

@ -63,11 +63,8 @@ class Crawler(ABC):
return self.progress_bar(desc, total=size) return self.progress_bar(desc, total=size)
async def run(self) -> None: async def run(self) -> None:
await self._conductor.start() async with self._conductor:
try:
await self.crawl() await self.crawl()
finally:
await self._conductor.stop()
@abstractmethod @abstractmethod
async def crawl(self) -> None: async def crawl(self) -> None: