commit 269bb7a5923cdb1de69a4a3e42e05ab1bb4064c0 Author: Cavallium Date: Sun Nov 3 22:11:45 2019 +0100 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..818fabc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ + +config.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..fe2758f --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +outside-fetcher diff --git a/__main__.py b/__main__.py new file mode 100644 index 0000000..f3bd20c --- /dev/null +++ b/__main__.py @@ -0,0 +1,93 @@ +import typing +import pyrogram +import aiohttp.web +import asyncio +import multiprocessing +import config + + +from async_worker.async_worker import AsyncTaskScheduler +from tasks.channel_history import ChannelHistoryReadTask + + +class ClientRefCount: + client: pyrogram.Client + count: int = 0 + + def __init__(self, api_id: int, api_hash: str, name: str): + self.client = pyrogram.Client(name, api_id, api_hash) + + async def start(self): + await self.client.start() + + def get_client(self) -> pyrogram.Client: + self.count += 1 + return self.client + + def __gt__(self, other: 'ClientRefCount'): + return self.count < other.count + + +class ClientRoundRobin: + clients: typing.List[ClientRefCount] + + def __init__(self): + self.clients = [] + + def get_client(self) -> pyrogram.Client: + return min(self.clients).get_client() + + async def add_client(self, api_id: int, api_hash: str, name: str): + client = ClientRefCount(api_id, api_hash, name) + self.clients.append(client) + await client.start() + + +class FetcherAPI: + clients: ClientRoundRobin + scheduler: AsyncTaskScheduler + already_watching: typing.List[str] + + def __init__(self): + self.clients = ClientRoundRobin() + self.scheduler = AsyncTaskScheduler() + self.already_watching = [] + + async def setup(self): + for name in config.sessions: + await self.clients.add_client(config.app_id, config.app_hash, name) + + app = aiohttp.web.Application() + app.add_routes([aiohttp.web.get("/tasks/watch/add/{username}", self.add_channel)]) + + await asyncio.gather( + aiohttp.web._run_app(app, host=config.listen_host, port=config.listen_port), + *( + self.scheduler.loop() + for _ in range(multiprocessing.cpu_count()) + ) + ) + + async def add_channel(self, request): + user = request.match_info["username"] + + if user in self.already_watching: + return aiohttp.web.Response(status=400) + + self.already_watching.append(user) + + client = self.clients.get_client() + task = ChannelHistoryReadTask() + + peer = await client.resolve_peer(user) + task.setup(client, peer, config.webhook) + + await self.scheduler.submit(task) + return aiohttp.web.Response(status=200) + + +if __name__ == "__main__": + api = FetcherAPI() + + _loop = asyncio.get_event_loop() + _loop.run_until_complete(api.setup()) diff --git a/async_worker/__init__.py b/async_worker/__init__.py new file mode 100644 index 0000000..ebbca31 --- /dev/null +++ b/async_worker/__init__.py @@ -0,0 +1,6 @@ +from .async_worker import AsyncTaskScheduler, AsyncTask + +__all__ = [ + "AsyncTaskScheduler", + "AsyncTask" +] diff --git a/async_worker/async_worker.py b/async_worker/async_worker.py new file mode 100644 index 0000000..209c147 --- /dev/null +++ b/async_worker/async_worker.py @@ -0,0 +1,198 @@ +# MIT License +# +# Copyright (c) [2019] [andrew-ld] +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import asyncio +import abc +import typing +import time +import operator + + +class AsyncTask(abc.ABC): + _next: int + _locked: bool + + __slots__ = [ + "_next", + "_locked" + ] + + def __init__(self): + self._next = 0 + self._locked = False + + @abc.abstractmethod + async def process(self) -> int: + raise NotImplementedError + + @abc.abstractmethod + def setup(self, *args, **kwargs): + raise NotImplementedError + + def set_next(self, _next: int): + self._next = time.time() + _next + + def get_next(self) -> int: + return self._next + + def get_delay(self) -> int: + return self.get_next() - time.time() + + def lock(self): + self._locked = True + + def unlock(self): + self._locked = False + + def is_locked(self): + return self._locked + + +class AsyncTaskDelay: + _task: asyncio.Task + _delay_end: int + + __slots__ = [ + "_task", + "_delay_end", + ] + + def __init__(self): + self._task = asyncio.ensure_future(asyncio.Future()) + self._delay_end = 0 + + async def sleep(self, _time) -> bool: + self._delay_end = _time + time.time() + + self._task = asyncio.ensure_future( + asyncio.sleep(_time) + ) + + try: + + await self._task + + except asyncio.CancelledError: + return False + + return True + + def is_sleeping(self) -> bool: + return not (self._task.done() or self._task.cancelled()) + + def cancel(self): + self._task.cancel() + + def __gt__(self, other: 'AsyncTaskDelay'): + return self._delay_end > other._delay_end + + +class AsyncMultipleEvent: + _events: typing.List[asyncio.Event] + + __slots__ = [ + "_events" + ] + + def __init__(self): + self._events = [] + + async def lock(self): + event = asyncio.Event() + self._events.append(event) + await event.wait() + + def unlock_first(self): + if self._events: + self._events.pop(0).set() + + +class AsyncTaskScheduler: + _queue: typing.List[AsyncTask] + _wait_enqueue: AsyncMultipleEvent + _wait_unlock: AsyncMultipleEvent + _sleep_tasks: typing.List[AsyncTaskDelay] + + __slots__ = [ + "_queue", + "_sleep_tasks", + "_wait_enqueue", + "_wait_unlock" + ] + + def __init__(self): + self._queue = [] + self._sleep_tasks = [] + + self._wait_enqueue = AsyncMultipleEvent() + self._wait_unlock = AsyncMultipleEvent() + + async def submit(self, task: AsyncTask): + self._queue.append(task) + self._wait_enqueue.unlock_first() + self._wait_unlock.unlock_first() + + cancellable_tasks = [*filter(lambda x: x.is_sleeping(), self._sleep_tasks)] + + if cancellable_tasks: + max(cancellable_tasks).cancel() + + async def loop(self): + sleeper = AsyncTaskDelay() + self._sleep_tasks.append(sleeper) + + while True: + if not self._queue: + await self._wait_enqueue.lock() + + while self._queue: + runnable_tasks = [*filter(lambda x: not x.is_locked(), self._queue)] + + if not runnable_tasks: + await self._wait_unlock.lock() + continue + + task, delay = min( + ( + (task, task.get_next()) + for task in runnable_tasks + ), + + key=operator.itemgetter(1) + ) + + delay -= time.time() + task.lock() + + if delay > 0 and not await sleeper.sleep(delay): + task.unlock() + continue + + next_delay = await task.process() + + if next_delay is False: + self._queue.remove(task) + + else: + task.set_next(next_delay) + task.unlock() + self._wait_unlock.unlock_first() diff --git a/config.py.example b/config.py.example new file mode 100644 index 0000000..3c7d597 --- /dev/null +++ b/config.py.example @@ -0,0 +1,13 @@ +app_id = 6 +app_hash = "eb06d4abfb49dc3eeb1aeb98ae0f581e" +sessions = ["sessions/test1", "sessions/test2"] + +listen_host = "127.0.0.1" +listen_port = 8080 + +webhook = "https://example.org/webhook" + + + +# Usage: +# http://listen_host:listen_port/tasks/watch/add/{groupname} diff --git a/tasks/__init__.py b/tasks/__init__.py new file mode 100644 index 0000000..2b35e7f --- /dev/null +++ b/tasks/__init__.py @@ -0,0 +1,5 @@ +from .channel_history import ChannelHistoryReadTask + +__all__ = [ + "ChannelHistoryReadTask" +] diff --git a/tasks/channel_history.py b/tasks/channel_history.py new file mode 100644 index 0000000..f7a545e --- /dev/null +++ b/tasks/channel_history.py @@ -0,0 +1,58 @@ +import pyrogram +import aiohttp + +from async_worker.async_worker import AsyncTask +from pyrogram.api.types import ChannelMessagesFilterEmpty +from pyrogram.api.types.updates import ChannelDifference, ChannelDifferenceTooLong, ChannelDifferenceEmpty +from pyrogram.api.functions.updates.get_channel_difference import GetChannelDifference + +from pyrogram.client.types.messages_and_media import Message as MessagePyrogram + + +class ChannelHistoryReadTask(AsyncTask): + channel: pyrogram.Chat + client: pyrogram.Client + pts: int + webhook: str + + def setup(self, client: pyrogram.Client, channel: pyrogram.Chat, webhook: str): + self.client = client + self.channel = channel + self.pts = False + self.webhook = webhook + + async def process(self): + response = await self.client.send( + GetChannelDifference( + channel=self.channel, + filter=ChannelMessagesFilterEmpty(), + pts=self.pts if self.pts else 0xFFFFFFF, + limit=0xFFFFFFF, + force=True + ) + ) + + if isinstance(response, ChannelDifference): + self.pts = response.pts + + users = {i.id: i for i in response.users} + chats = {i.id: i for i in response.chats} + http = aiohttp.ClientSession() + + for message in response.new_messages: + message = await MessagePyrogram._parse(self.client, message, users, chats) + await http.post(self.webhook, data=bytes(str(message), "utf8")) + + await http.close() + + if not response.final: + return 1 + + return response.timeout + + if isinstance(response, ChannelDifferenceEmpty): + self.pts = response.pts + return response.timeout + + if isinstance(response, ChannelDifferenceTooLong): + return False