mirror of
https://github.com/Garmelon/PFERD.git
synced 2023-12-21 10:23:01 +01:00
Change limiter logic
Now download tasks are a subset of all tasks.
This commit is contained in:
@ -2,7 +2,7 @@ import asyncio
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncContextManager, AsyncIterator, Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -11,15 +11,27 @@ class Slot:
|
||||
last_left: Optional[float] = None
|
||||
|
||||
|
||||
class SlotPool:
|
||||
def __init__(self, limit: int, delay: float):
|
||||
if limit <= 0:
|
||||
raise ValueError("limit must be greater than 0")
|
||||
class Limiter:
|
||||
def __init__(
|
||||
self,
|
||||
task_limit: int,
|
||||
download_limit: int,
|
||||
task_delay: float
|
||||
):
|
||||
if task_limit <= 0:
|
||||
raise ValueError("task limit must be at least 1")
|
||||
if download_limit <= 0:
|
||||
raise ValueError("download limit must be at least 1")
|
||||
if download_limit > task_limit:
|
||||
raise ValueError("download limit can't be greater than task limit")
|
||||
if task_delay < 0:
|
||||
raise ValueError("Task delay must not be negative")
|
||||
|
||||
self._slots = [Slot() for _ in range(limit)]
|
||||
self._delay = delay
|
||||
self._slots = [Slot() for _ in range(task_limit)]
|
||||
self._downloads = download_limit
|
||||
self._delay = task_delay
|
||||
|
||||
self._free = asyncio.Condition()
|
||||
self._condition = asyncio.Condition()
|
||||
|
||||
def _acquire_slot(self) -> Optional[Slot]:
|
||||
for slot in self._slots:
|
||||
@ -29,40 +41,57 @@ class SlotPool:
|
||||
|
||||
return None
|
||||
|
||||
def _release_slot(self, slot: Slot) -> None:
|
||||
slot.last_left = time.time()
|
||||
slot.active = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def limit(self) -> AsyncIterator[None]:
|
||||
slot: Slot
|
||||
async with self._free:
|
||||
while True:
|
||||
if found_slot := self._acquire_slot():
|
||||
slot = found_slot
|
||||
break
|
||||
await self._free.wait()
|
||||
|
||||
async def _wait_for_slot_delay(self, slot: Slot) -> None:
|
||||
if slot.last_left is not None:
|
||||
delay = slot.last_left + self._delay - time.time()
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
def _release_slot(self, slot: Slot) -> None:
|
||||
slot.last_left = time.time()
|
||||
slot.active = False
|
||||
|
||||
@asynccontextmanager
|
||||
async def limit_crawl(self) -> AsyncIterator[None]:
|
||||
slot: Slot
|
||||
async with self._condition:
|
||||
while True:
|
||||
if found_slot := self._acquire_slot():
|
||||
slot = found_slot
|
||||
break
|
||||
await self._condition.wait()
|
||||
|
||||
await self._wait_for_slot_delay(slot)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
async with self._free:
|
||||
async with self._condition:
|
||||
self._release_slot(slot)
|
||||
self._free.notify()
|
||||
self._condition.notify_all()
|
||||
|
||||
@asynccontextmanager
|
||||
async def limit_download(self) -> AsyncIterator[None]:
|
||||
slot: Slot
|
||||
async with self._condition:
|
||||
while True:
|
||||
if self._downloads <= 0:
|
||||
await self._condition.wait()
|
||||
continue
|
||||
|
||||
class Limiter:
|
||||
def __init__(self, crawl_limit: int, download_limit: int, delay: float):
|
||||
self._crawl_pool = SlotPool(crawl_limit, delay)
|
||||
self._download_pool = SlotPool(download_limit, delay)
|
||||
if found_slot := self._acquire_slot():
|
||||
slot = found_slot
|
||||
self._downloads -= 1
|
||||
break
|
||||
|
||||
def limit_crawl(self) -> AsyncContextManager[None]:
|
||||
return self._crawl_pool.limit()
|
||||
await self._condition.wait()
|
||||
|
||||
def limit_download(self) -> AsyncContextManager[None]:
|
||||
return self._crawl_pool.limit()
|
||||
await self._wait_for_slot_delay(slot)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
async with self._condition:
|
||||
self._release_slot(slot)
|
||||
self._downloads += 1
|
||||
self._condition.notify_all()
|
||||
|
Reference in New Issue
Block a user