Test and fix exclusive output

This commit is contained in:
Joscha 2021-04-29 15:26:10 +02:00
parent 2e85d26b6b
commit d96a361325
4 changed files with 55 additions and 19 deletions

View File

@ -3,7 +3,6 @@ from contextlib import asynccontextmanager, contextmanager
from types import TracebackType from types import TracebackType
from typing import AsyncIterator, Iterator, List, Optional, Type from typing import AsyncIterator, Iterator, List, Optional, Type
import rich
from rich.progress import Progress, TaskID from rich.progress import Progress, TaskID
@ -24,20 +23,26 @@ class TerminalConductor:
self._lines: List[str] = [] self._lines: List[str] = []
async def _start(self) -> None: async def _start(self) -> None:
async with self._lock: for task in self._progress.tasks:
for line in self._lines: task.visible = True
rich.print(line) self._progress.start()
self._lines = []
self._progress.start() self._stopped = False
for line in self._lines:
self.print(line)
self._lines = []
async def _stop(self) -> None: async def _stop(self) -> None:
async with self._lock: self._stopped = True
self._progress.stop()
self._stopped = True for task in self._progress.tasks:
task.visible = False
self._progress.stop()
async def __aenter__(self) -> None: async def __aenter__(self) -> None:
await self._start() async with self._lock:
await self._start()
async def __aexit__( async def __aexit__(
self, self,
@ -45,23 +50,24 @@ class TerminalConductor:
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
) -> Optional[bool]: ) -> Optional[bool]:
await self._stop() async with self._lock:
await self._stop()
return None return None
def print(self, line: str) -> None: def print(self, line: str) -> None:
if self._stopped: if self._stopped:
self._lines.append(line) self._lines.append(line)
else: else:
rich.print(line) self._progress.console.print(line)
@asynccontextmanager @asynccontextmanager
async def exclusive_output(self) -> AsyncIterator[None]: async def exclusive_output(self) -> AsyncIterator[None]:
async with self._lock: async with self._lock:
self._stop() await self._stop()
try: try:
yield yield
finally: finally:
self._start() await self._start()
@contextmanager @contextmanager
def progress_bar( def progress_bar(

View File

@ -38,6 +38,9 @@ class Crawler(ABC):
def print(self, text: str) -> None: def print(self, text: str) -> None:
self._conductor.print(text) self._conductor.print(text)
def exclusive_output(self):
return self._conductor.exclusive_output()
@asynccontextmanager @asynccontextmanager
async def progress_bar( async def progress_bar(
self, self,

View File

@ -6,6 +6,7 @@ from typing import Any
from rich.markup import escape from rich.markup import escape
from ..crawler import Crawler from ..crawler import Crawler
from ..utils import ainput
DUMMY_TREE = { DUMMY_TREE = {
"Blätter": { "Blätter": {
@ -17,7 +18,7 @@ DUMMY_TREE = {
"Lösungen": { "Lösungen": {
"Blatt_01_Lösung.pdf": (), "Blatt_01_Lösung.pdf": (),
"Blatt_02_Lösung.pdf": (), "Blatt_02_Lösung.pdf": (),
"Blatt_03_Lösung.pdf": (), "Blatt_03_Lösung.pdf": True,
"Blatt_04_Lösung.pdf": (), "Blatt_04_Lösung.pdf": (),
"Blatt_05_Lösung.pdf": (), "Blatt_05_Lösung.pdf": (),
}, },
@ -39,7 +40,10 @@ class DummyCrawler(Crawler):
await self._crawl_entry(Path(), DUMMY_TREE) await self._crawl_entry(Path(), DUMMY_TREE)
async def _crawl_entry(self, path: Path, value: Any) -> None: async def _crawl_entry(self, path: Path, value: Any) -> None:
if value == (): if value is True:
async with self.exclusive_output():
await ainput(f"File {path}, please press enter: ")
if value == () or value is True:
n = random.randint(5, 20) n = random.randint(5, 20)
async with self.download_bar(path, n) as bar: async with self.download_bar(path, n) as bar:
await asyncio.sleep(random.random() / 2) await asyncio.sleep(random.random() / 2)

View File

@ -1,7 +1,30 @@
from typing import Optional import functools
import contextvars
import asyncio
import getpass
from typing import Any, Callable, Optional, TypeVar
T = TypeVar("T")
def prompt_yes_no(query: str, default: Optional[bool]) -> bool: # TODO When switching to 3.9, use asyncio.to_thread instead of this
async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
# https://github.com/python/cpython/blob/8d47f92d46a92a5931b8f3dcb4a484df672fc4de/Lib/asyncio/threads.py
loop = asyncio.get_event_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(None, func_call)
async def ainput(prompt: Optional[str] = None) -> str:
return await to_thread(lambda: input(prompt))
async def agetpass(prompt: Optional[str] = None) -> str:
return await to_thread(lambda: getpass.getpass(prompt))
async def prompt_yes_no(query: str, default: Optional[bool]) -> bool:
""" """
Asks the user a yes/no question and returns their choice. Asks the user a yes/no question and returns their choice.
""" """
@ -14,7 +37,7 @@ def prompt_yes_no(query: str, default: Optional[bool]) -> bool:
query += " [y/n] " query += " [y/n] "
while True: while True:
response = input(query).strip().lower() response = (await ainput(query)).strip().lower()
if response == "y": if response == "y":
return True return True
elif response == "n": elif response == "n":