From eaba19a07862317338c6382e92f51131302ce7be Mon Sep 17 00:00:00 2001 From: "andrew (from workstation)" Date: Tue, 5 Nov 2019 19:48:14 +0100 Subject: [PATCH] schedule webhook data forwarding --- tasks/__init__.py | 4 ++- tasks/channel_history.py | 51 +++++++++++++------------------ tasks/mtproto_task_abstraction.py | 9 +++--- tasks/webhook.py | 24 +++++++++++++++ 4 files changed, 52 insertions(+), 36 deletions(-) create mode 100644 tasks/webhook.py diff --git a/tasks/__init__.py b/tasks/__init__.py index 3bb5182..7dbcb0f 100644 --- a/tasks/__init__.py +++ b/tasks/__init__.py @@ -1,7 +1,9 @@ from .channel_history import ChannelHistoryReadTask from .mtproto_task_abstraction import MtProtoTask +from .webhook import WebHookDataForward __all__ = [ "ChannelHistoryReadTask", - "MtProtoTask" + "MtProtoTask", + "WebHookDataForward" ] diff --git a/tasks/channel_history.py b/tasks/channel_history.py index 36184d7..d051cca 100644 --- a/tasks/channel_history.py +++ b/tasks/channel_history.py @@ -1,9 +1,9 @@ import pyrogram -import aiohttp import json import typing from .mtproto_task_abstraction import MtProtoTask +from .webhook import WebHookDataForward from pyrogram.api.types import ChannelMessagesFilterEmpty from pyrogram.api.types.updates import ChannelDifference, ChannelDifferenceTooLong, ChannelDifferenceEmpty @@ -61,42 +61,40 @@ class JsonSerializer: class ChannelHistoryReadTask(MtProtoTask): - channel: pyrogram.Chat - client: pyrogram.Client - pts: int - webhook: str - http: aiohttp.ClientSession + _channel: pyrogram.Chat + _client: pyrogram.Client + _pts: int + _webhook: str def setup(self, client: pyrogram.Client, channel: pyrogram.Chat, webhook: str): - self.client = client - self.channel = channel - self.pts = False - self.webhook = webhook - self.http = aiohttp.ClientSession() + self._pts = False - async def mt_process(self) -> typing.Union[bool, int]: - response = await self.client.send( + self._client = client + self._channel = channel + self._webhook = webhook + + async def process(self) -> typing.Union[bool, int]: + response = await self._client.send( GetChannelDifference( - channel=self.channel, + channel=self._channel, filter=ChannelMessagesFilterEmpty(), - pts=self.pts if self.pts else 0xFFFFFFF, + pts=self._pts if self._pts else 0xFFFFFFF, limit=0xFFFFFFF, force=True ) ) if isinstance(response, ChannelDifference): - self.pts = response.pts + self._pts = response.pts users = {i.id: i for i in response.users} chats = {i.id: i for i in response.chats} for message in response.new_messages: - message = await MessagePyrogram._parse(self.client, message, users, chats) - message = {"update_id": 1, "message": message} + message = await MessagePyrogram._parse(self._client, message, users, chats) data = json.dumps( - message, + {"update_id": 1, "message": message}, default=JsonSerializer.default, ensure_ascii=True, allow_nan=False, @@ -104,16 +102,9 @@ class ChannelHistoryReadTask(MtProtoTask): sort_keys=False ) - result = await self.http.post( - self.webhook, - data=data, - headers=[ - ("Content-Type", "application/json") - ] - ) - - await result.read() - result.close() + forwarder = WebHookDataForward() + forwarder.setup(self._webhook, data) + await self.future(forwarder) if not response.final: return 1 @@ -121,7 +112,7 @@ class ChannelHistoryReadTask(MtProtoTask): return response.timeout if isinstance(response, ChannelDifferenceEmpty): - self.pts = response.pts + self._pts = response.pts return response.timeout if isinstance(response, ChannelDifferenceTooLong): diff --git a/tasks/mtproto_task_abstraction.py b/tasks/mtproto_task_abstraction.py index 24cb88e..e959671 100644 --- a/tasks/mtproto_task_abstraction.py +++ b/tasks/mtproto_task_abstraction.py @@ -12,16 +12,15 @@ class MtProtoTask(AsyncTask, abc.ABC): raise NotImplementedError @abc.abstractmethod - async def mt_process(self) -> typing.Union[bool, int]: + async def process(self) -> typing.Union[bool, int]: raise NotImplementedError - async def process(self) -> int: + async def _process(self) -> typing.Union[bool, int]: try: - - return await self.mt_process() + return await self.process() * 1e9 except FloodWait as error: - return int(error.MESSAGE.split("_")[-1]) + return int(error.MESSAGE.split("_")[-1]) * 1e9 except (ChannelInvalid, ChannelPrivate, Unauthorized): return False diff --git a/tasks/webhook.py b/tasks/webhook.py new file mode 100644 index 0000000..b072c95 --- /dev/null +++ b/tasks/webhook.py @@ -0,0 +1,24 @@ +import typing +import aiohttp + +from async_worker import OneLoopAsyncTask + + +HEADERS = [("Content-Type", "application/json")] + + +class WebHookDataForward(OneLoopAsyncTask): + _webhook: str + _data: typing.Union[str, bytes] + _http: aiohttp.ClientSession + + async def process(self) -> typing.NoReturn: + res = await self._http.post(self._webhook, data=self._data, headers=HEADERS) + await res.read() + res.close() + + def setup(self, webhook: str, data: typing.Union[str, bytes]): + self._http = aiohttp.ClientSession() + + self._webhook = webhook + self._data = data