diff --git a/PFERD/diva.py b/PFERD/diva.py index bfd0299..636ffeb 100644 --- a/PFERD/diva.py +++ b/PFERD/diva.py @@ -127,7 +127,7 @@ class DivaDownloader: with self._session.get(info.url, stream=True) as response: if response.status_code == 200: tmp_file = self._tmp_dir.new_path() - stream_to_path(response, tmp_file) + stream_to_path(response, tmp_file, info.path.name) self._organizer.accept_file(tmp_file, info.path) else: PRETTY.warning(f"Could not download file, got response {response.status_code}") diff --git a/PFERD/downloaders.py b/PFERD/downloaders.py index 48a82ee..94b8b9f 100644 --- a/PFERD/downloaders.py +++ b/PFERD/downloaders.py @@ -49,7 +49,6 @@ class HttpDownloader: ) return session - def download_all(self, infos: List[HttpDownloadInfo]) -> None: """ Download multiple files one after the other. @@ -58,7 +57,6 @@ class HttpDownloader: for info in infos: self.download(info) - def download(self, info: HttpDownloadInfo) -> None: """ Download a single file. @@ -67,7 +65,7 @@ class HttpDownloader: with self._session.get(info.url, params=info.parameters, stream=True) as response: if response.status_code == 200: tmp_file = self._tmp_dir.new_path() - stream_to_path(response, tmp_file) + stream_to_path(response, tmp_file, info.path.name) self._organizer.accept_file(tmp_file, info.path) else: # TODO use proper exception diff --git a/PFERD/ilias/downloader.py b/PFERD/ilias/downloader.py index 98ad388..d2efafb 100644 --- a/PFERD/ilias/downloader.py +++ b/PFERD/ilias/downloader.py @@ -124,7 +124,7 @@ class IliasDownloader: return False # Yay, we got the file :) - stream_to_path(response, target) + stream_to_path(response, target, info.path.name) return True @staticmethod diff --git a/PFERD/progress.py b/PFERD/progress.py new file mode 100644 index 0000000..eff86ce --- /dev/null +++ b/PFERD/progress.py @@ -0,0 +1,121 @@ +""" +A small progress bar implementation. +""" +import sys +from dataclasses import dataclass +from types import TracebackType +from typing import Optional, Type + +import requests +from rich.console import Console, ConsoleOptions, Control, RenderResult +from rich.live_render import LiveRender +from rich.progress import (BarColumn, DownloadColumn, Progress, TaskID, + TextColumn, TimeRemainingColumn, + TransferSpeedColumn) + +_progress: Progress = Progress( + TextColumn("[bold blue]{task.fields[name]}", justify="right"), + BarColumn(bar_width=None), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + DownloadColumn(), + "•", + TransferSpeedColumn(), + "•", + TimeRemainingColumn(), + console=Console(file=sys.stdout) +) + + +def size_from_headers(response: requests.Response) -> Optional[int]: + """ + Return the size of the download based on the response headers. + + Arguments: + response {requests.Response} -- the response + + Returns: + Optional[int] -- the size + """ + if "Content-Length" in response.headers: + return int(response.headers["Content-Length"]) + return None + + +@dataclass +class ProgressSettings: + """ + Settings you can pass to customize the progress bar. + """ + name: str + max_size: int + + +def progress_for(settings: Optional[ProgressSettings]) -> 'ProgressContextManager': + """ + Returns a context manager that displays progress + + Returns: + ProgressContextManager -- the progress manager + """ + return ProgressContextManager(settings) + + +class ProgressContextManager: + """ + A context manager used for displaying progress. + """ + + def __init__(self, settings: Optional[ProgressSettings]): + self._settings = settings + self._task_id: Optional[TaskID] = None + + def __enter__(self) -> 'ProgressContextManager': + """Context manager entry function.""" + if not self._settings: + return self + + _progress.start() + self._task_id = _progress.add_task( + self._settings.name, + total=self._settings.max_size, + name=self._settings.name + ) + return self + + # pylint: disable=useless-return + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: + """Context manager exit function. Removes the task.""" + if self._task_id is not None: + _progress.remove_task(self._task_id) + + if len(_progress.task_ids) == 0: + _progress.stop() + _progress.refresh() + + class _OneLineUp(LiveRender): + """ + Render a control code for moving one line upwards. + """ + + def __init__(self) -> None: + super().__init__("not rendered") + + def __console__(self, console: Console, options: ConsoleOptions) -> RenderResult: + yield Control(f"\r\x1b[1A") + + Console(file=sys.stdout).print(_OneLineUp()) + + return None + + def advance(self, amount: float) -> None: + """ + Advances the progress bar. + """ + if self._task_id is not None: + _progress.advance(self._task_id, amount) diff --git a/PFERD/utils.py b/PFERD/utils.py index 320b0a5..56c101a 100644 --- a/PFERD/utils.py +++ b/PFERD/utils.py @@ -9,6 +9,8 @@ from typing import Optional, Tuple, Union import bs4 import requests +from .progress import ProgressSettings, progress_for, size_from_headers + PathLike = Union[PurePath, str, Tuple[str, ...]] @@ -41,17 +43,33 @@ def soupify(response: requests.Response) -> bs4.BeautifulSoup: return bs4.BeautifulSoup(response.text, "html.parser") -def stream_to_path(response: requests.Response, target: Path, chunk_size: int = 1024 ** 2) -> None: +def stream_to_path( + response: requests.Response, + target: Path, + progress_name: Optional[str] = None, + chunk_size: int = 1024 ** 2 +) -> None: """ Download a requests response content to a file by streaming it. This function avoids excessive memory usage when downloading large files. The chunk_size is in bytes. + + If progress_name is None, no progress bar will be shown. Otherwise a progress + bar will appear, if the download is bigger than an internal threshold. """ with response: + length = size_from_headers(response) + if progress_name and length and int(length) > 1024 * 1024 * 10: # 10 MiB + settings: Optional[ProgressSettings] = ProgressSettings(progress_name, length) + else: + settings = None + with open(target, 'wb') as file_descriptor: - for chunk in response.iter_content(chunk_size=chunk_size): - file_descriptor.write(chunk) + with progress_for(settings) as progress: + for chunk in response.iter_content(chunk_size=chunk_size): + file_descriptor.write(chunk) + progress.advance(len(chunk)) def prompt_yes_no(question: str, default: Optional[bool] = None) -> bool: