async init using scheduler

This commit is contained in:
andrew (from workstation) 2019-11-06 15:53:46 +01:00
parent a570afedde
commit 89c8575634
8 changed files with 110 additions and 102 deletions

View File

@ -1,9 +1,7 @@
from .channel_history import ChannelHistoryReadTask
from .mtproto_task_abstraction import MtProtoTask
from .webhook import WebHookDataForward
__all__ = [
"ChannelHistoryReadTask",
"MtProtoTask",
"WebHookDataForward"
]

View File

@ -1,78 +1,32 @@
import pyrogram
import json
import typing
from .mtproto_task_abstraction import MtProtoTask
from .executors import MtProtoTask
from .webhook import WebHookDataForward
from .reformat import message_to_bot_api
from pyrogram.api.types import ChannelMessagesFilterEmpty
from pyrogram.api.types.updates import ChannelDifference, ChannelDifferenceTooLong, ChannelDifferenceEmpty
from pyrogram.api.functions.updates.get_channel_difference import GetChannelDifference
from pyrogram.client.types.messages_and_media import Message as MessagePyrogram
class JsonSerializerFixes:
@staticmethod
def user(obj):
obj.type = "private"
return obj
@staticmethod
def user_type(obj):
obj.type = "channel" if obj.id < 0 else "private"
return obj
class JsonSerializer:
fixes = {
"from_user": {
"new_name": "from",
"patch": JsonSerializerFixes.user
},
"user": {
"new_name": "user",
"patch": JsonSerializerFixes.user_type
}
}
@staticmethod
def default(obj):
if isinstance(obj, bytes):
return repr(obj)
cls = JsonSerializer
result = {}
for name in filter(lambda x: not x.startswith("_"), obj.__dict__):
value = getattr(obj, name)
if value is None:
continue
if name in cls.fixes:
value = cls.fixes[name]["patch"](value)
name = cls.fixes[name]["new_name"]
result[name] = value
return result
class ChannelHistoryReadTask(MtProtoTask):
_channel: pyrogram.Chat
_channel: typing.Union[pyrogram.Chat, str]
_client: pyrogram.Client
_pts: int
_webhook: str
def setup(self, client: pyrogram.Client, channel: pyrogram.Chat, webhook: str):
def __init__(self, client: pyrogram.Client, channel: str, webhook: str):
super().__init__()
self._pts = False
self._client = client
self._channel = channel
self._webhook = webhook
async def setup(self):
self._channel = await self._client.resolve_peer(self._channel)
async def process(self) -> typing.Union[bool, int]:
response = await self._client.send(
GetChannelDifference(
@ -91,19 +45,8 @@ class ChannelHistoryReadTask(MtProtoTask):
chats = {i.id: i for i in response.chats}
for message in response.new_messages:
message = await MessagePyrogram._parse(self._client, message, users, chats)
data = json.dumps(
{"update_id": 1, "message": message},
default=JsonSerializer.default,
ensure_ascii=True,
allow_nan=False,
check_circular=True,
sort_keys=False
)
forwarder = WebHookDataForward()
forwarder.setup(self._webhook, data)
data = await message_to_bot_api(self._client, users, chats, message)
forwarder = WebHookDataForward(self._webhook, data)
await self.future(forwarder)
if not response.final:

View File

@ -0,0 +1,5 @@
from tasks.executors.mtproto_task_abstraction import MtProtoTask
__all__ = [
"MtProtoTask"
]

View File

@ -0,0 +1,24 @@
import abc
import typing
from async_worker import AsyncTask
from pyrogram.errors.exceptions import FloodWait, RPCError
class MtProtoTask(AsyncTask, abc.ABC):
async def _execute(self, func: typing.Callable) -> typing.Any:
try:
return await func()
except FloodWait as error:
return int(error.MESSAGE.split("_")[-1]) * 1e9
except RPCError:
return False
async def _setup(self) -> bool:
return await self._execute(super()._setup)
async def _process(self) -> typing.Union[bool, int]:
return await self._execute(super()._process)

View File

@ -1,32 +0,0 @@
import abc
import typing
from async_worker import AsyncTask
from pyrogram.errors.exceptions import FloodWait, ChannelInvalid, ChannelPrivate, Unauthorized
class MtProtoTask(AsyncTask, abc.ABC):
@abc.abstractmethod
def setup(self, *args, **kwargs):
raise NotImplementedError
@abc.abstractmethod
async def process(self) -> typing.Union[bool, int]:
raise NotImplementedError
async def _process(self) -> typing.Union[bool, int]:
try:
result = await self.process()
if result is False:
return False
return result * 1e9
except FloodWait as error:
return int(error.MESSAGE.split("_")[-1]) * 1e9
except (ChannelInvalid, ChannelPrivate, Unauthorized):
return False

View File

@ -0,0 +1,5 @@
from .mtproto_bot_api import message_to_bot_api
__all__ = [
"message_to_bot_api"
]

View File

@ -0,0 +1,64 @@
import json
from pyrogram.client.types.messages_and_media import Message
class JsonSerializerFixes:
@staticmethod
def user(obj):
obj.type = "private"
return obj
@staticmethod
def user_type(obj):
obj.type = "channel" if obj.id < 0 else "private"
return obj
class JsonSerializer:
fixes = {
"from_user": {
"new_name": "from",
"patch": JsonSerializerFixes.user
},
"user": {
"new_name": "user",
"patch": JsonSerializerFixes.user_type
}
}
@staticmethod
def default(obj):
if isinstance(obj, bytes):
return repr(obj)
cls = JsonSerializer
result = {}
for name in filter(lambda x: not x.startswith("_"), obj.__dict__):
value = getattr(obj, name)
if value is None:
continue
if name in cls.fixes:
value = cls.fixes[name]["patch"](value)
name = cls.fixes[name]["new_name"]
result[name] = value
return result
async def message_to_bot_api(client, users, chats, message) -> str:
message = await Message._parse(client, message, users, chats)
return json.dumps(
{"update_id": 1, "message": message},
default=JsonSerializer.default,
ensure_ascii=True,
allow_nan=False,
check_circular=True,
sort_keys=False
)

View File

@ -17,7 +17,8 @@ class WebHookDataForward(OneLoopAsyncTask):
await res.read()
res.close()
def setup(self, webhook: str, data: typing.Union[str, bytes]):
def __init__(self, webhook: str, data: typing.Union[str, bytes]):
super().__init__()
self._http = aiohttp.ClientSession()
self._webhook = webhook