diff --git a/__main__.py b/__main__.py index f3bd20c..84e2eab 100644 --- a/__main__.py +++ b/__main__.py @@ -77,10 +77,7 @@ class FetcherAPI: self.already_watching.append(user) client = self.clients.get_client() - task = ChannelHistoryReadTask() - - peer = await client.resolve_peer(user) - task.setup(client, peer, config.webhook) + task = ChannelHistoryReadTask(client, user, config.webhook) await self.scheduler.submit(task) return aiohttp.web.Response(status=200) diff --git a/async_worker/async_worker.py b/async_worker/async_worker.py index da9f18d..e26e0e2 100644 --- a/async_worker/async_worker.py +++ b/async_worker/async_worker.py @@ -36,12 +36,14 @@ class AsyncTask(abc.ABC): __slots__ = [ "_next", "_locked", - "_scheduler" + "_scheduler", + "_ready" ] def __init__(self): self._next = 0 self._locked = False + self._ready = False def set_scheduler(self, scheduler: 'AsyncTaskScheduler'): self._scheduler = scheduler @@ -64,9 +66,17 @@ class AsyncTask(abc.ABC): async def process(self) -> typing.Union[bool, int]: raise NotImplementedError - @abc.abstractmethod - def setup(self, *args, **kwargs): - raise NotImplementedError + async def _setup(self) -> bool: + result = await self.setup() + self._ready = True + + if result is None: + return True + + return result + + async def setup(self) -> bool: + pass def set_next(self, _next: int): self._next = time.time_ns() + _next @@ -77,6 +87,9 @@ class AsyncTask(abc.ABC): def get_delay(self) -> int: return self.get_next() - time.time_ns() + def is_ready(self) -> bool: + return self._ready + def lock(self): self._locked = True @@ -164,8 +177,8 @@ class SchedulerConfig: ] def __init__(self, - imprecise_delay: int = 2 * 1e+8, - skippable_delay: int = 3 * 1e+8, + imprecise_delay: int = 2 * 1e8, + skippable_delay: int = 3 * 1e8, max_fast_submit_tasks: int = 50 ): self.imprecise_delay = imprecise_delay @@ -239,29 +252,28 @@ class AsyncTaskScheduler: if not runnable_tasks: await self._wait_unlock.lock() - await asyncio.sleep(0) continue - fast_submit_tasks = [*filter(lambda x: x.get_delay() <= self._config.imprecise_delay, runnable_tasks)] + submittable = [*filter(lambda x: x.get_delay() <= self._config.imprecise_delay, runnable_tasks)] - if fast_submit_tasks: - for task in fast_submit_tasks: + if submittable: + + for task in submittable: task.lock() task.set_scheduler(self) - while fast_submit_tasks: + while submittable: futures = [] - for task in fast_submit_tasks[:self._config.max_fast_submit_tasks]: + for task in submittable[:self._config.max_fast_submit_tasks]: on_done = functools.partial(on_complete, task, self._queue, self._wait_unlock) + future = asyncio.ensure_future(task._process() if task.is_ready() else task._setup()) - future = asyncio.ensure_future(task._process()) future.add_done_callback(on_done) - futures.append(future) await asyncio.gather(*futures) - fast_submit_tasks = fast_submit_tasks[self._config.max_fast_submit_tasks:] + submittable = submittable[self._config.max_fast_submit_tasks:] continue diff --git a/tasks/executors/mtproto_task_abstraction.py b/tasks/executors/mtproto_task_abstraction.py index d55ec13..cb11071 100644 --- a/tasks/executors/mtproto_task_abstraction.py +++ b/tasks/executors/mtproto_task_abstraction.py @@ -3,7 +3,7 @@ import typing from async_worker import AsyncTask -from pyrogram.errors.exceptions import FloodWait, RPCError +from pyrogram.errors.exceptions import FloodWait, RPCError, InternalServerError class MtProtoTask(AsyncTask, abc.ABC): @@ -14,8 +14,11 @@ class MtProtoTask(AsyncTask, abc.ABC): except FloodWait as error: return int(error.MESSAGE.split("_")[-1]) * 1e9 - except RPCError: - return False + except InternalServerError: + return 1e9 + + except RPCError as error: + return False if abs(error.CODE) < 500 else 1e9 * 2 async def _setup(self) -> bool: return await self._execute(super()._setup) diff --git a/tasks/webhook.py b/tasks/webhook.py index bb9537e..2bca2c9 100644 --- a/tasks/webhook.py +++ b/tasks/webhook.py @@ -17,9 +17,11 @@ class WebHookDataForward(OneLoopAsyncTask): await res.read() res.close() + async def setup(self) -> bool: + self._http = aiohttp.ClientSession() + def __init__(self, webhook: str, data: typing.Union[str, bytes]): super().__init__() - self._http = aiohttp.ClientSession() self._webhook = webhook self._data = data