2021-04-29 15:26:10 +02:00
|
|
|
import asyncio
|
|
|
|
import getpass
|
2021-05-19 17:16:23 +02:00
|
|
|
import sys
|
2021-05-22 18:37:53 +02:00
|
|
|
import threading
|
2021-05-19 17:16:23 +02:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from contextlib import AsyncExitStack
|
|
|
|
from types import TracebackType
|
2021-05-19 21:34:36 +02:00
|
|
|
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
|
|
|
|
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
|
2020-04-20 17:15:47 +02:00
|
|
|
|
2021-05-15 15:18:51 +02:00
|
|
|
import bs4
|
|
|
|
|
2021-04-29 15:26:10 +02:00
|
|
|
# Generic type variable used by the helpers and classes in this module
# (e.g. the result type of in_daemon_thread, the value a context manager yields).
T = TypeVar("T")
|
2018-11-24 09:27:33 +01:00
|
|
|
|
2021-04-29 15:26:10 +02:00
|
|
|
|
2021-05-22 18:37:53 +02:00
|
|
|
async def in_daemon_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """
    Run ``func(*args, **kwargs)`` in a new daemon thread and await its result.

    Because the thread is a daemon, a call that never returns (for example a
    blocking ``input()``) does not prevent the interpreter from exiting.

    :param func: the blocking callable to run
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :return: whatever ``func`` returns
    :raises: any exception raised by ``func`` is re-raised in the awaiting coroutine
    """

    loop = asyncio.get_running_loop()
    future: asyncio.Future[T] = loop.create_future()

    def resolve(value: T) -> None:
        # The awaiting task may have been cancelled in the meantime; resolving an
        # already-done future would raise InvalidStateError inside the loop.
        if not future.done():
            future.set_result(value)

    def fail(exc: BaseException) -> None:
        if not future.done():
            future.set_exception(exc)

    def deliver(callback: Callable[[Any], None], arg: Any) -> None:
        try:
            loop.call_soon_threadsafe(callback, arg)
        except RuntimeError:
            # The event loop was closed before func finished; nobody is left
            # to await the outcome, so there is nothing to deliver.
            pass

    def thread_func() -> None:
        try:
            result = func(*args, **kwargs)
        except BaseException as exc:  # forward everything so the awaiter never hangs
            deliver(fail, exc)
        else:
            deliver(resolve, result)

    threading.Thread(target=thread_func, daemon=True).start()

    return await future
|
2021-04-29 15:26:10 +02:00
|
|
|
|
|
|
|
|
2021-04-29 15:47:52 +02:00
|
|
|
async def ainput(prompt: str) -> str:
    """
    Asynchronous counterpart of ``input(prompt)``.

    The blocking read happens on a helper thread, so the event loop keeps
    running while waiting for the user.
    """

    def read_line() -> str:
        return input(prompt)

    line = await in_daemon_thread(read_line)
    return line
|
2021-04-29 15:26:10 +02:00
|
|
|
|
|
|
|
|
2021-04-29 15:47:52 +02:00
|
|
|
async def agetpass(prompt: str) -> str:
    """
    Asynchronous counterpart of ``getpass.getpass(prompt)``.

    The user's input is read without echo on a helper thread, so the event
    loop keeps running while waiting.
    """

    def read_secret() -> str:
        return getpass.getpass(prompt)

    secret = await in_daemon_thread(read_secret)
    return secret
|
2021-04-29 15:26:10 +02:00
|
|
|
|
2021-05-16 14:32:53 +02:00
|
|
|
|
2021-05-15 15:18:51 +02:00
|
|
|
def soupify(data: bytes) -> bs4.BeautifulSoup:
    """
    Turn raw HTML bytes into a BeautifulSoup document tree.

    Uses Python's built-in ``html.parser`` backend, so no extra parser
    library is required.
    """
    parser_backend = "html.parser"
    return bs4.BeautifulSoup(data, features=parser_backend)
|
2021-04-29 15:26:10 +02:00
|
|
|
|
2021-05-16 14:32:53 +02:00
|
|
|
|
2021-05-19 21:34:36 +02:00
|
|
|
def url_set_query_param(url: str, param: str, value: str) -> str:
    """
    Set a query parameter in an url, overwriting existing ones with the same name.

    All other query parameters are kept, including ones with an empty value
    such as ``a=`` (a bare ``flag`` is re-emitted as ``flag=``). Parsing groups
    repeated occurrences of a name together, so ``a=1&b=2&a=3`` comes back as
    ``a=1&a=3&b=2``. Scheme, host, path and fragment are carried over.

    :param url: the url to modify
    :param param: name of the query parameter to set
    :param value: the raw (unencoded) value; it is url-encoded in the result
    :return: the url with the updated query string
    """
    scheme, netloc, path, query, fragment = urlsplit(url)
    # keep_blank_values=True: by default parse_qs silently drops parameters
    # without a value, which would make them vanish from the returned url.
    query_parameters = parse_qs(query, keep_blank_values=True)
    query_parameters[param] = [value]
    new_query_string = urlencode(query_parameters, doseq=True)

    return urlunsplit((scheme, netloc, path, new_query_string, fragment))
|
|
|
|
|
|
|
|
|
|
|
|
def url_set_query_params(url: str, params: Dict[str, str]) -> str:
    """
    Sets multiple query parameters in an url, overwriting existing ones.

    The query string is parsed and re-encoded exactly once, so parameters
    with empty values (either pre-existing ones or ones set here with ``""``)
    are preserved. When ``params`` is empty the url is returned unchanged.

    :param url: the url to modify
    :param params: mapping of parameter name to raw (unencoded) value
    :return: the url with the updated query string
    """
    if not params:
        # Nothing to set: avoid re-encoding (and thereby normalizing) the url.
        return url

    scheme, netloc, path, query, fragment = urlsplit(url)
    # keep_blank_values=True keeps "a=" style parameters instead of dropping them.
    query_parameters = parse_qs(query, keep_blank_values=True)
    for key, val in params.items():
        query_parameters[key] = [val]
    new_query_string = urlencode(query_parameters, doseq=True)

    return urlunsplit((scheme, netloc, path, new_query_string, fragment))
|
|
|
|
|
|
|
|
|
2021-04-29 15:26:10 +02:00
|
|
|
async def prompt_yes_no(query: str, default: Optional[bool]) -> bool:
    """
    Ask the user a yes/no question until they give a valid answer.

    The capitalized letter in the appended hint marks the default, which is
    returned when the user just presses enter. With ``default=None`` an
    explicit answer is required.
    """
    if default is True:
        hint = "[Y/n]"
    elif default is False:
        hint = "[y/N]"
    else:
        hint = "[y/n]"
    full_query = f"{query} {hint} "

    answers = {"y": True, "n": False}

    while True:
        reply = await ainput(full_query)
        reply = reply.strip().lower()

        if reply in answers:
            return answers[reply]
        if not reply and default is not None:
            return default

        print("Please answer with 'y' or 'n'.")
|
2021-05-19 17:16:23 +02:00
|
|
|
|
|
|
|
|
|
|
|
class ReusableAsyncContextManager(ABC, Generic[T]):
    """
    Base class for async context managers that may be entered again after
    being exited, but never nested or used concurrently.

    Subclasses implement ``_on_aenter``, which returns the value bound by
    ``async with ... as value`` and may register cleanup on ``self._stack``
    (an ``AsyncExitStack``). All registered cleanup runs when the
    ``async with`` block is left; the stack is then empty and ready for the
    next use.
    """

    def __init__(self) -> None:
        # True from the start of __aenter__ until the matching __aexit__ finishes.
        self._active = False
        # Holds the cleanup callbacks registered during the current use.
        self._stack = AsyncExitStack()

    @abstractmethod
    async def _on_aenter(self) -> T:
        """Set up the resources for one use and return the value to bind."""
        pass

    async def __aenter__(self) -> T:
        if self._active:
            raise RuntimeError("Nested or otherwise concurrent usage is not allowed")

        self._active = True
        await self._stack.__aenter__()

        # If setup fails, unwind whatever _on_aenter already registered on the
        # stack (see https://stackoverflow.com/a/13075071), then propagate the
        # error. It is re-raised even if an exit callback reports it as
        # handled, because there is no value that __aenter__ could return.
        try:
            return await self._on_aenter()
        except BaseException as exc:
            await self.__aexit__(type(exc), exc, exc.__traceback__)
            raise

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        if not self._active:
            raise RuntimeError("__aexit__ called too many times")

        try:
            return await self._stack.__aexit__(exc_type, exc_value, traceback)
        finally:
            # Reset even if a cleanup callback raised; otherwise every later
            # use would be rejected as "nested or concurrent".
            self._active = False
|