diff --git a/tasks/__init__.py b/tasks/__init__.py index 2b35e7f..3bb5182 100644 --- a/tasks/__init__.py +++ b/tasks/__init__.py @@ -1,5 +1,7 @@ from .channel_history import ChannelHistoryReadTask +from .mtproto_task_abstraction import MtProtoTask __all__ = [ - "ChannelHistoryReadTask" + "ChannelHistoryReadTask", + "MtProtoTask" ] diff --git a/tasks/channel_history.py b/tasks/channel_history.py index d3cbba8..36184d7 100644 --- a/tasks/channel_history.py +++ b/tasks/channel_history.py @@ -1,8 +1,10 @@ import pyrogram import aiohttp import json +import typing + +from .mtproto_task_abstraction import MtProtoTask -from async_worker.async_worker import AsyncTask from pyrogram.api.types import ChannelMessagesFilterEmpty from pyrogram.api.types.updates import ChannelDifference, ChannelDifferenceTooLong, ChannelDifferenceEmpty from pyrogram.api.functions.updates.get_channel_difference import GetChannelDifference @@ -58,7 +60,7 @@ class JsonSerializer: return result -class ChannelHistoryReadTask(AsyncTask): +class ChannelHistoryReadTask(MtProtoTask): channel: pyrogram.Chat client: pyrogram.Client pts: int @@ -72,7 +74,7 @@ class ChannelHistoryReadTask(AsyncTask): self.webhook = webhook self.http = aiohttp.ClientSession() - async def process(self): + async def mt_process(self) -> typing.Union[bool, int]: response = await self.client.send( GetChannelDifference( channel=self.channel, diff --git a/tasks/mtproto_task_abstraction.py b/tasks/mtproto_task_abstraction.py new file mode 100644 index 0000000..24cb88e --- /dev/null +++ b/tasks/mtproto_task_abstraction.py @@ -0,0 +1,27 @@ +import abc +import typing + +from async_worker import AsyncTask + +from pyrogram.errors.exceptions import FloodWait, ChannelInvalid, ChannelPrivate, Unauthorized + + +class MtProtoTask(AsyncTask, abc.ABC): + @abc.abstractmethod + def setup(self, *args, **kwargs): + raise NotImplementedError + + @abc.abstractmethod + async def mt_process(self) -> typing.Union[bool, int]: + raise NotImplementedError + + async def process(self) -> int: + try: + + return await self.mt_process() + + except FloodWait as error: + return int(error.MESSAGE.split("_")[-1]) + + except (ChannelInvalid, ChannelPrivate, Unauthorized): + return False