Run async input and password getters in daemon thread

Previously, it ran in the event loop's default executor, which would block until
all its workers were done working.

If Ctrl+C was pressed while input or a password were being read, the
asyncio.run() call in the main thread would be interrupted however, not the
input thread. This meant that multiple key presses (either enter or a second
Ctrl+C) were necessary to stop a running PFERD in some circumstances.

This change instead runs the input functions in daemon threads so they exit as
soon as the main thread exits.
This commit is contained in:
Joscha 2021-05-22 18:37:53 +02:00
parent dfde0e2310
commit 552cd82802

View File

@ -1,8 +1,7 @@
import asyncio import asyncio
import contextvars
import functools
import getpass import getpass
import sys import sys
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from types import TracebackType from types import TracebackType
@ -14,21 +13,25 @@ import bs4
T = TypeVar("T") T = TypeVar("T")
# TODO When switching to 3.9, use asyncio.to_thread instead of this async def in_daemon_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
async def to_thread(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: loop = asyncio.get_running_loop()
# https://github.com/python/cpython/blob/8d47f92d46a92a5931b8f3dcb4a484df672fc4de/Lib/asyncio/threads.py future: asyncio.Future[T] = asyncio.Future()
loop = asyncio.get_event_loop()
ctx = contextvars.copy_context() def thread_func() -> None:
func_call = functools.partial(ctx.run, func, *args, **kwargs) result = func()
return await loop.run_in_executor(None, func_call) # type: ignore loop.call_soon_threadsafe(future.set_result, result)
threading.Thread(target=thread_func, daemon=True).start()
return await future
async def ainput(prompt: str) -> str: async def ainput(prompt: str) -> str:
return await to_thread(lambda: input(prompt)) return await in_daemon_thread(lambda: input(prompt))
async def agetpass(prompt: str) -> str: async def agetpass(prompt: str) -> str:
return await to_thread(lambda: getpass.getpass(prompt)) return await in_daemon_thread(lambda: getpass.getpass(prompt))
def soupify(data: bytes) -> bs4.BeautifulSoup: def soupify(data: bytes) -> bs4.BeautifulSoup: