Initial commit
This commit is contained in:
commit
269bb7a592
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
|
||||
config.py
|
93
__main__.py
Normal file
93
__main__.py
Normal file
@ -0,0 +1,93 @@
|
||||
import typing
|
||||
import pyrogram
|
||||
import aiohttp.web
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import config
|
||||
|
||||
|
||||
from async_worker.async_worker import AsyncTaskScheduler
|
||||
from tasks.channel_history import ChannelHistoryReadTask
|
||||
|
||||
|
||||
class ClientRefCount:
|
||||
client: pyrogram.Client
|
||||
count: int = 0
|
||||
|
||||
def __init__(self, api_id: int, api_hash: str, name: str):
|
||||
self.client = pyrogram.Client(name, api_id, api_hash)
|
||||
|
||||
async def start(self):
|
||||
await self.client.start()
|
||||
|
||||
def get_client(self) -> pyrogram.Client:
|
||||
self.count += 1
|
||||
return self.client
|
||||
|
||||
def __gt__(self, other: 'ClientRefCount'):
|
||||
return self.count < other.count
|
||||
|
||||
|
||||
class ClientRoundRobin:
|
||||
clients: typing.List[ClientRefCount]
|
||||
|
||||
def __init__(self):
|
||||
self.clients = []
|
||||
|
||||
def get_client(self) -> pyrogram.Client:
|
||||
return min(self.clients).get_client()
|
||||
|
||||
async def add_client(self, api_id: int, api_hash: str, name: str):
|
||||
client = ClientRefCount(api_id, api_hash, name)
|
||||
self.clients.append(client)
|
||||
await client.start()
|
||||
|
||||
|
||||
class FetcherAPI:
|
||||
clients: ClientRoundRobin
|
||||
scheduler: AsyncTaskScheduler
|
||||
already_watching: typing.List[str]
|
||||
|
||||
def __init__(self):
|
||||
self.clients = ClientRoundRobin()
|
||||
self.scheduler = AsyncTaskScheduler()
|
||||
self.already_watching = []
|
||||
|
||||
async def setup(self):
|
||||
for name in config.sessions:
|
||||
await self.clients.add_client(config.app_id, config.app_hash, name)
|
||||
|
||||
app = aiohttp.web.Application()
|
||||
app.add_routes([aiohttp.web.get("/tasks/watch/add/{username}", self.add_channel)])
|
||||
|
||||
await asyncio.gather(
|
||||
aiohttp.web._run_app(app, host=config.listen_host, port=config.listen_port),
|
||||
*(
|
||||
self.scheduler.loop()
|
||||
for _ in range(multiprocessing.cpu_count())
|
||||
)
|
||||
)
|
||||
|
||||
async def add_channel(self, request):
|
||||
user = request.match_info["username"]
|
||||
|
||||
if user in self.already_watching:
|
||||
return aiohttp.web.Response(status=400)
|
||||
|
||||
self.already_watching.append(user)
|
||||
|
||||
client = self.clients.get_client()
|
||||
task = ChannelHistoryReadTask()
|
||||
|
||||
peer = await client.resolve_peer(user)
|
||||
task.setup(client, peer, config.webhook)
|
||||
|
||||
await self.scheduler.submit(task)
|
||||
return aiohttp.web.Response(status=200)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = FetcherAPI()
|
||||
|
||||
_loop = asyncio.get_event_loop()
|
||||
_loop.run_until_complete(api.setup())
|
6
async_worker/__init__.py
Normal file
6
async_worker/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
from .async_worker import AsyncTaskScheduler, AsyncTask
|
||||
|
||||
__all__ = [
|
||||
"AsyncTaskScheduler",
|
||||
"AsyncTask"
|
||||
]
|
198
async_worker/async_worker.py
Normal file
198
async_worker/async_worker.py
Normal file
@ -0,0 +1,198 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) [2019] [andrew-ld]
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import asyncio
|
||||
import abc
|
||||
import typing
|
||||
import time
|
||||
import operator
|
||||
|
||||
|
||||
class AsyncTask(abc.ABC):
|
||||
_next: int
|
||||
_locked: bool
|
||||
|
||||
__slots__ = [
|
||||
"_next",
|
||||
"_locked"
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._next = 0
|
||||
self._locked = False
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_next(self, _next: int):
|
||||
self._next = time.time() + _next
|
||||
|
||||
def get_next(self) -> int:
|
||||
return self._next
|
||||
|
||||
def get_delay(self) -> int:
|
||||
return self.get_next() - time.time()
|
||||
|
||||
def lock(self):
|
||||
self._locked = True
|
||||
|
||||
def unlock(self):
|
||||
self._locked = False
|
||||
|
||||
def is_locked(self):
|
||||
return self._locked
|
||||
|
||||
|
||||
class AsyncTaskDelay:
|
||||
_task: asyncio.Task
|
||||
_delay_end: int
|
||||
|
||||
__slots__ = [
|
||||
"_task",
|
||||
"_delay_end",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._task = asyncio.ensure_future(asyncio.Future())
|
||||
self._delay_end = 0
|
||||
|
||||
async def sleep(self, _time) -> bool:
|
||||
self._delay_end = _time + time.time()
|
||||
|
||||
self._task = asyncio.ensure_future(
|
||||
asyncio.sleep(_time)
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
await self._task
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
return not (self._task.done() or self._task.cancelled())
|
||||
|
||||
def cancel(self):
|
||||
self._task.cancel()
|
||||
|
||||
def __gt__(self, other: 'AsyncTaskDelay'):
|
||||
return self._delay_end > other._delay_end
|
||||
|
||||
|
||||
class AsyncMultipleEvent:
|
||||
_events: typing.List[asyncio.Event]
|
||||
|
||||
__slots__ = [
|
||||
"_events"
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._events = []
|
||||
|
||||
async def lock(self):
|
||||
event = asyncio.Event()
|
||||
self._events.append(event)
|
||||
await event.wait()
|
||||
|
||||
def unlock_first(self):
|
||||
if self._events:
|
||||
self._events.pop(0).set()
|
||||
|
||||
|
||||
class AsyncTaskScheduler:
|
||||
_queue: typing.List[AsyncTask]
|
||||
_wait_enqueue: AsyncMultipleEvent
|
||||
_wait_unlock: AsyncMultipleEvent
|
||||
_sleep_tasks: typing.List[AsyncTaskDelay]
|
||||
|
||||
__slots__ = [
|
||||
"_queue",
|
||||
"_sleep_tasks",
|
||||
"_wait_enqueue",
|
||||
"_wait_unlock"
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._queue = []
|
||||
self._sleep_tasks = []
|
||||
|
||||
self._wait_enqueue = AsyncMultipleEvent()
|
||||
self._wait_unlock = AsyncMultipleEvent()
|
||||
|
||||
async def submit(self, task: AsyncTask):
|
||||
self._queue.append(task)
|
||||
self._wait_enqueue.unlock_first()
|
||||
self._wait_unlock.unlock_first()
|
||||
|
||||
cancellable_tasks = [*filter(lambda x: x.is_sleeping(), self._sleep_tasks)]
|
||||
|
||||
if cancellable_tasks:
|
||||
max(cancellable_tasks).cancel()
|
||||
|
||||
async def loop(self):
|
||||
sleeper = AsyncTaskDelay()
|
||||
self._sleep_tasks.append(sleeper)
|
||||
|
||||
while True:
|
||||
if not self._queue:
|
||||
await self._wait_enqueue.lock()
|
||||
|
||||
while self._queue:
|
||||
runnable_tasks = [*filter(lambda x: not x.is_locked(), self._queue)]
|
||||
|
||||
if not runnable_tasks:
|
||||
await self._wait_unlock.lock()
|
||||
continue
|
||||
|
||||
task, delay = min(
|
||||
(
|
||||
(task, task.get_next())
|
||||
for task in runnable_tasks
|
||||
),
|
||||
|
||||
key=operator.itemgetter(1)
|
||||
)
|
||||
|
||||
delay -= time.time()
|
||||
task.lock()
|
||||
|
||||
if delay > 0 and not await sleeper.sleep(delay):
|
||||
task.unlock()
|
||||
continue
|
||||
|
||||
next_delay = await task.process()
|
||||
|
||||
if next_delay is False:
|
||||
self._queue.remove(task)
|
||||
|
||||
else:
|
||||
task.set_next(next_delay)
|
||||
task.unlock()
|
||||
self._wait_unlock.unlock_first()
|
13
config.py.example
Normal file
13
config.py.example
Normal file
@ -0,0 +1,13 @@
|
||||
app_id = 6
|
||||
app_hash = "eb06d4abfb49dc3eeb1aeb98ae0f581e"
|
||||
sessions = ["sessions/test1", "sessions/test2"]
|
||||
|
||||
listen_host = "127.0.0.1"
|
||||
listen_port = 8080
|
||||
|
||||
webhook = "https://example.org/webhook"
|
||||
|
||||
|
||||
|
||||
# Usage:
|
||||
# http://listen_host:listen_port/tasks/watch/add/{groupname}
|
5
tasks/__init__.py
Normal file
5
tasks/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
from .channel_history import ChannelHistoryReadTask
|
||||
|
||||
__all__ = [
|
||||
"ChannelHistoryReadTask"
|
||||
]
|
58
tasks/channel_history.py
Normal file
58
tasks/channel_history.py
Normal file
@ -0,0 +1,58 @@
|
||||
import pyrogram
|
||||
import aiohttp
|
||||
|
||||
from async_worker.async_worker import AsyncTask
|
||||
from pyrogram.api.types import ChannelMessagesFilterEmpty
|
||||
from pyrogram.api.types.updates import ChannelDifference, ChannelDifferenceTooLong, ChannelDifferenceEmpty
|
||||
from pyrogram.api.functions.updates.get_channel_difference import GetChannelDifference
|
||||
|
||||
from pyrogram.client.types.messages_and_media import Message as MessagePyrogram
|
||||
|
||||
|
||||
class ChannelHistoryReadTask(AsyncTask):
|
||||
channel: pyrogram.Chat
|
||||
client: pyrogram.Client
|
||||
pts: int
|
||||
webhook: str
|
||||
|
||||
def setup(self, client: pyrogram.Client, channel: pyrogram.Chat, webhook: str):
|
||||
self.client = client
|
||||
self.channel = channel
|
||||
self.pts = False
|
||||
self.webhook = webhook
|
||||
|
||||
async def process(self):
|
||||
response = await self.client.send(
|
||||
GetChannelDifference(
|
||||
channel=self.channel,
|
||||
filter=ChannelMessagesFilterEmpty(),
|
||||
pts=self.pts if self.pts else 0xFFFFFFF,
|
||||
limit=0xFFFFFFF,
|
||||
force=True
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(response, ChannelDifference):
|
||||
self.pts = response.pts
|
||||
|
||||
users = {i.id: i for i in response.users}
|
||||
chats = {i.id: i for i in response.chats}
|
||||
http = aiohttp.ClientSession()
|
||||
|
||||
for message in response.new_messages:
|
||||
message = await MessagePyrogram._parse(self.client, message, users, chats)
|
||||
await http.post(self.webhook, data=bytes(str(message), "utf8"))
|
||||
|
||||
await http.close()
|
||||
|
||||
if not response.final:
|
||||
return 1
|
||||
|
||||
return response.timeout
|
||||
|
||||
if isinstance(response, ChannelDifferenceEmpty):
|
||||
self.pts = response.pts
|
||||
return response.timeout
|
||||
|
||||
if isinstance(response, ChannelDifferenceTooLong):
|
||||
return False
|
Loading…
Reference in New Issue
Block a user