pferd/PFERD/utils.py

102 lines
2.8 KiB
Python
Raw Normal View History

2021-04-29 15:26:10 +02:00
import asyncio
2021-04-29 15:47:52 +02:00
import contextvars
import functools
2021-04-29 15:26:10 +02:00
import getpass
import sys
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from types import TracebackType
from typing import Any, Callable, Generic, Optional, Type, TypeVar
2020-04-20 17:15:47 +02:00
import bs4
2021-04-29 15:26:10 +02:00
# Generic type variable: the return type of the callable passed to to_thread and
# the value type produced by ReusableAsyncContextManager._on_aenter.
T = TypeVar("T")
2021-04-29 15:26:10 +02:00
# TODO When switching to 3.9, use asyncio.to_thread instead of this
async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """
    Runs a blocking function in the event loop's default executor and awaits its result.

    The function is invoked as func(*args, **kwargs) inside a copy of the caller's
    contextvars context, so context variables set by the caller are visible to it.

    Must be awaited from within a running event loop. Exceptions raised by func
    propagate to the awaiting caller.
    """

    # https://github.com/python/cpython/blob/8d47f92d46a92a5931b8f3dcb4a484df672fc4de/Lib/asyncio/threads.py
    # This is a coroutine, so a loop is guaranteed to be running. get_running_loop()
    # is what the upstream implementation uses and never implicitly creates a loop.
    loop = asyncio.get_running_loop()
    ctx = contextvars.copy_context()
    func_call = functools.partial(ctx.run, func, *args, **kwargs)
    return await loop.run_in_executor(None, func_call)  # type: ignore
2021-04-29 15:26:10 +02:00
2021-04-29 15:47:52 +02:00
async def ainput(prompt: str) -> str:
    """
    Awaitable counterpart of input(): the blocking call is handed to to_thread
    and the entered line is returned.
    """

    entered = await to_thread(input, prompt)
    return entered
2021-04-29 15:47:52 +02:00
async def agetpass(prompt: str) -> str:
    """
    Awaitable counterpart of getpass.getpass(): the blocking call is handed to
    to_thread and the entered secret is returned.
    """

    secret = await to_thread(getpass.getpass, prompt)
    return secret
2021-05-16 14:32:53 +02:00
def soupify(data: bytes) -> bs4.BeautifulSoup:
    """
    Builds a BeautifulSoup document from HTML data using the "html.parser" parser.
    """

    parser_name = "html.parser"
    return bs4.BeautifulSoup(data, features=parser_name)
2021-04-29 15:26:10 +02:00
2021-05-16 14:32:53 +02:00
2021-04-29 15:26:10 +02:00
async def prompt_yes_no(query: str, default: Optional[bool]) -> bool:
    """
    Keeps asking the user a yes/no question until a usable answer is given.

    The query is suffixed with " [Y/n] " if the default is True, " [y/N] " if it
    is False and " [y/n] " otherwise. Answers are stripped and lowercased before
    being compared: "y" yields True and "n" yields False. An empty answer yields
    the default, unless the default is None, in which case the user is asked
    again.
    """

    # Identity checks on purpose: only the actual bool singletons pick a
    # capitalized hint.
    hint = " [Y/n] " if default is True else " [y/N] " if default is False else " [y/n] "
    full_query = query + hint

    choices = {"y": True, "n": False}

    while True:
        raw = await ainput(full_query)
        answer = raw.strip().lower()

        if answer in choices:
            return choices[answer]

        if not answer and default is not None:
            return default

        print("Please answer with 'y' or 'n'.")
class ReusableAsyncContextManager(ABC, Generic[T]):
    """
    Base class for async context managers that can be entered multiple times,
    but only once at a time.

    Subclasses implement _on_aenter, which produces the value bound by
    "async with ... as value". While entered, self._stack (an AsyncExitStack) is
    open; everything registered on it is cleaned up when the manager exits.

    Entering while already entered raises RuntimeError, as does calling
    __aexit__ while not entered.
    """

    def __init__(self) -> None:
        # True between a successful start of __aenter__ and the end of __aexit__.
        self._active = False
        # Collects cleanup callbacks/context managers for the current usage.
        self._stack = AsyncExitStack()

    @abstractmethod
    async def _on_aenter(self) -> T:
        """
        Produces the value returned from __aenter__. Called with self._stack
        already entered.
        """
        pass

    async def __aenter__(self) -> T:
        if self._active:
            raise RuntimeError("Nested or otherwise concurrent usage is not allowed")

        self._active = True
        await self._stack.__aenter__()

        # See https://stackoverflow.com/a/13075071
        # "async with" does not call __aexit__ if __aenter__ raises, so the
        # cleanup has to be triggered manually here.
        try:
            return await self._on_aenter()
        except BaseException:
            # Run the cleanup, then always propagate the original error: even if
            # the exit stack would suppress it, there is no value to return.
            await self.__aexit__(*sys.exc_info())
            raise

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        if not self._active:
            raise RuntimeError("__aexit__ called too many times")

        try:
            return await self._stack.__aexit__(exc_type, exc_value, traceback)
        finally:
            # Reset even if a cleanup callback raised, otherwise the manager
            # could never be entered again.
            self._active = False