from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path, PurePath
# TODO In Python 3.9 and above, AsyncContextManager is deprecated
from typing import Any, AsyncContextManager, AsyncIterator, Awaitable, Callable, Dict, Optional, TypeVar

import aiohttp
from rich.markup import escape

from .authenticator import Authenticator
from .config import Config, Section
from .limiter import Limiter
from .logging import ProgressBar, log
from .output_dir import FileSink, OnConflict, OutputDirectory, Redownload
from .transformer import RuleParseException, Transformer
from .version import __version__


class CrawlerLoadException(Exception):
    pass


Wrapped = TypeVar("Wrapped", bound=Callable[..., None])


def noncritical(f: Wrapped) -> Wrapped:
    """
    Warning: Must only be applied to member functions of the Crawler class!

    Catches all exceptions occurring during the function call. If an exception
    occurs, the crawler's error_free variable is set to False.
    """

    def wrapper(self: "Crawler", *args: Any, **kwargs: Any) -> None:
        try:
            f(self, *args, **kwargs)
        except Exception as e:
            log.print(f"[red]Something went wrong: {escape(str(e))}")
            self.error_free = False

    return wrapper  # type: ignore


def repeat(attempts: int) -> Callable[[Wrapped], Wrapped]:
    """
    Warning: Must only be applied to member functions of the Crawler class!

    If an exception occurs during the function call, retries the function call
    a set number of times. Exceptions that occur during the last attempt are
    not caught and instead passed on upwards.
    """

    def decorator(f: Wrapped) -> Wrapped:
        def wrapper(self: "Crawler", *args: Any, **kwargs: Any) -> None:
            for _ in range(attempts - 1):
                try:
                    f(self, *args, **kwargs)
                    return
                except Exception:
                    pass
            f(self, *args, **kwargs)

        return wrapper  # type: ignore

    return decorator


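# Illustrative sketch (not part of this module): the two decorators above are
# meant to be stacked on Crawler methods, with @noncritical on the outside so
# that the exception escaping @repeat's final attempt is still caught and
# merely flips error_free. The method name "refresh_index" is made up.
#
#     class MyCrawler(Crawler):
#         @noncritical
#         @repeat(3)
#         def refresh_index(self) -> None:
#             ...  # called up to 3 times; only the last failure propagates

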
AWrapped = TypeVar("AWrapped", bound=Callable[..., Awaitable[None]])


def anoncritical(f: AWrapped) -> AWrapped:
    """
    An async version of @noncritical.

    Warning: Must only be applied to member functions of the Crawler class!

    Catches all exceptions occurring during the function call. If an exception
    occurs, the crawler's error_free variable is set to False.
    """

    async def wrapper(self: "Crawler", *args: Any, **kwargs: Any) -> None:
        try:
            await f(self, *args, **kwargs)
        except Exception as e:
            log.print(f"[red]Something went wrong: {escape(str(e))}")
            self.error_free = False

    return wrapper  # type: ignore


def arepeat(attempts: int) -> Callable[[AWrapped], AWrapped]:
    """
    An async version of @repeat.

    Warning: Must only be applied to member functions of the Crawler class!

    If an exception occurs during the function call, retries the function call
    a set number of times. Exceptions that occur during the last attempt are
    not caught and instead passed on upwards.
    """

    def decorator(f: AWrapped) -> AWrapped:
        async def wrapper(self: "Crawler", *args: Any, **kwargs: Any) -> None:
            for _ in range(attempts - 1):
                try:
                    await f(self, *args, **kwargs)
                    return
                except Exception:
                    pass
            await f(self, *args, **kwargs)

        return wrapper  # type: ignore

    return decorator


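# The async variants compose the same way (sketch; "refresh_index" is again a
# made-up coroutine name):
#
#     class MyCrawler(Crawler):
#         @anoncritical
#         @arepeat(3)
#         async def refresh_index(self) -> None:
#             ...  # awaited up to 3 times before the last error escapes

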
class CrawlerSection(Section):
    def output_dir(self, name: str) -> Path:
        # TODO Use removeprefix() after switching to 3.9
        if name.startswith("crawl:"):
            name = name[len("crawl:"):]
        return Path(self.s.get("output_dir", name)).expanduser()

    def redownload(self) -> Redownload:
        value = self.s.get("redownload", "never-smart")
        try:
            return Redownload.from_string(value)
        except ValueError as e:
            self.invalid_value(
                "redownload",
                value,
                str(e).capitalize(),
            )

    def on_conflict(self) -> OnConflict:
        value = self.s.get("on_conflict", "prompt")
        try:
            return OnConflict.from_string(value)
        except ValueError as e:
            self.invalid_value(
                "on_conflict",
                value,
                str(e).capitalize(),
            )

    def transform(self) -> str:
        return self.s.get("transform", "")

    def max_concurrent_tasks(self) -> int:
        value = self.s.getint("max_concurrent_tasks", fallback=1)
        if value <= 0:
            self.invalid_value("max_concurrent_tasks", value,
                               "Must be greater than 0")
        return value

    def max_concurrent_downloads(self) -> int:
        tasks = self.max_concurrent_tasks()
        value = self.s.getint("max_concurrent_downloads", fallback=None)
        if value is None:
            return tasks
        if value <= 0:
            self.invalid_value("max_concurrent_downloads", value,
                               "Must be greater than 0")
        if value > tasks:
            self.invalid_value("max_concurrent_downloads", value,
                               "Must not be greater than max_concurrent_tasks")
        return value

    def delay_between_tasks(self) -> float:
        value = self.s.getfloat("delay_between_tasks", fallback=0.0)
        if value < 0:
            self.invalid_value("delay_between_tasks", value,
                               "Must not be negative")
        return value

    def auth(self, authenticators: Dict[str, Authenticator]) -> Authenticator:
        value = self.s.get("auth")
        if value is None:
            self.missing_value("auth")
        auth = authenticators.get(value)
        if auth is None:
            self.invalid_value("auth", value, "No such auth section exists")
        return auth


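# For orientation, an INI section that these accessors would read might look
# like this. The key names come from the methods above and the values shown
# are this module's defaults where one exists; the section name and the
# "auth:example" reference are made up.
#
#     [crawl:example]
#     output_dir = example
#     redownload = never-smart
#     on_conflict = prompt
#     transform =
#     max_concurrent_tasks = 1
#     delay_between_tasks = 0.0
#     auth = auth:example

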
class Crawler(ABC):
    def __init__(
            self,
            name: str,
            section: CrawlerSection,
            config: Config,
    ) -> None:
        """
        Initialize a crawler from its name and its section in the config file.

        If you are writing your own constructor for your own crawler, make sure
        to call this constructor first (via super().__init__).

        May raise a CrawlerLoadException.
        """

        self.name = name
        self.error_free = True

        self._limiter = Limiter(
            task_limit=section.max_concurrent_tasks(),
            download_limit=section.max_concurrent_downloads(),
            task_delay=section.delay_between_tasks(),
        )

        try:
            self._transformer = Transformer(section.transform())
        except RuleParseException as e:
            e.pretty_print()
            raise CrawlerLoadException()

        self._output_dir = OutputDirectory(
            config.working_dir / section.output_dir(name),
            section.redownload(),
            section.on_conflict(),
        )

    @asynccontextmanager
    async def crawl_bar(
            self,
            path: PurePath,
            total: Optional[int] = None,
    ) -> AsyncIterator[ProgressBar]:
        desc = f"[bold bright_cyan]Crawling[/] {escape(str(path))}"
        async with self._limiter.limit_crawl():
            with log.crawl_bar(desc, total=total) as bar:
                yield bar

    @asynccontextmanager
    async def download_bar(
            self,
            path: PurePath,
            total: Optional[int] = None,
    ) -> AsyncIterator[ProgressBar]:
        desc = f"[bold bright_cyan]Downloading[/] {escape(str(path))}"
        async with self._limiter.limit_download():
            with log.download_bar(desc, total=total) as bar:
                yield bar

    def should_crawl(self, path: PurePath) -> bool:
        return self._transformer.transform(path) is not None

    async def download(
            self,
            path: PurePath,
            mtime: Optional[datetime] = None,
            redownload: Optional[Redownload] = None,
            on_conflict: Optional[OnConflict] = None,
    ) -> Optional[AsyncContextManager[FileSink]]:
        transformed_path = self._transformer.transform(path)
        if transformed_path is None:
            return None

        return await self._output_dir.download(
            transformed_path, mtime, redownload, on_conflict)

    async def cleanup(self) -> None:
        await self._output_dir.cleanup()

    async def run(self) -> None:
        """
        Start the crawling process. Call this function if you want to use a
        crawler.
        """

        with log.show_progress():
            await self.crawl()

    @abstractmethod
    async def crawl(self) -> None:
        """
        Override this function if you are writing a crawler.

        This function must not return before all crawling is complete. To crawl
        multiple things concurrently, asyncio.gather can be used.
        """

        pass


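# A minimal subclass might look like this (sketch only; "ExampleCrawler" and
# "_crawl_file" are made up, an "import asyncio" is assumed, and the exact
# FileSink interface is left out). It shows the intended flow: gather
# concurrent tasks in crawl(), then download() each path, which returns None
# for paths the transformer filtered out.
#
#     class ExampleCrawler(Crawler):
#         async def crawl(self) -> None:
#             paths = [PurePath("a.txt"), PurePath("b.txt")]
#             await asyncio.gather(*(self._crawl_file(p) for p in paths))
#
#         @anoncritical
#         async def _crawl_file(self, path: PurePath) -> None:
#             maybe_dl = await self.download(path)
#             if maybe_dl is None:
#                 return  # excluded by the transform rules
#             async with self.download_bar(path):
#                 async with maybe_dl as sink:
#                     ...  # write the fetched bytes through the FileSink

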
class HttpCrawler(Crawler):
    COOKIE_FILE = PurePath(".cookies")

    def __init__(
            self,
            name: str,
            section: CrawlerSection,
            config: Config,
    ) -> None:
        super().__init__(name, section, config)

        self._cookie_jar_path = self._output_dir.resolve(self.COOKIE_FILE)
        self._output_dir.register_reserved(self.COOKIE_FILE)

    async def run(self) -> None:
        cookie_jar = aiohttp.CookieJar()

        try:
            cookie_jar.load(self._cookie_jar_path)
        except Exception:
            pass

        async with aiohttp.ClientSession(
                headers={"User-Agent": f"pferd/{__version__}"},
                cookie_jar=cookie_jar,
        ) as session:
            self.session = session
            try:
                await super().run()
            finally:
                del self.session

        try:
            cookie_jar.save(self._cookie_jar_path)
        except Exception:
            log.print(f"[bold red]Warning:[/] Failed to save cookies to {escape(str(self.COOKIE_FILE))}")
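

# Subclasses can rely on self.session being available for the duration of
# run() and therefore inside crawl() (sketch; "ExampleHttpCrawler" and the
# URL are made up):
#
#     class ExampleHttpCrawler(HttpCrawler):
#         async def crawl(self) -> None:
#             async with self.session.get("https://example.com") as response:
#                 html = await response.text()
#             ...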