mirror of
https://github.com/nexus-stc/hyperboria
synced 2024-12-24 10:35:49 +01:00
- feat: Added missing library
- feat: Set up delays for Crossref API GitOrigin-RevId: 3448e1c2a9fdfca2f8bf95d37e1d62cc5a221577
This commit is contained in:
parent
681817ceae
commit
cca9e8be47
18
library/telegram/BUILD.bazel
Normal file
18
library/telegram/BUILD.bazel
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
load("@pip_modules//:requirements.bzl", "requirement")
|
||||||
|
load("@rules_python//python:defs.bzl", "py_library")
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "telegram",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*.py"],
|
||||||
|
exclude = ["tests/**"],
|
||||||
|
),
|
||||||
|
srcs_version = "PY3",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
requirement("sqlalchemy"),
|
||||||
|
requirement("telethon"),
|
||||||
|
requirement("aiokit"),
|
||||||
|
"//library/logging",
|
||||||
|
],
|
||||||
|
)
|
0
library/telegram/__init__.py
Normal file
0
library/telegram/__init__.py
Normal file
160
library/telegram/base.py
Normal file
160
library/telegram/base.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiokit import AioThing
|
||||||
|
from izihawa_utils.random import random_string
|
||||||
|
from library.logging import error_log
|
||||||
|
from telethon import (
|
||||||
|
TelegramClient,
|
||||||
|
connection,
|
||||||
|
sessions,
|
||||||
|
)
|
||||||
|
from tenacity import ( # noqa
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
wait_fixed,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .session_backend import AlchemySessionContainer
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTelegramClient(AioThing):
|
||||||
|
def __init__(self, app_id, app_hash, database, bot_token=None, mtproxy=None, flood_sleep_threshold: int = 60):
|
||||||
|
AioThing.__init__(self)
|
||||||
|
self._telegram_client = TelegramClient(
|
||||||
|
self._get_session(database),
|
||||||
|
app_id,
|
||||||
|
app_hash,
|
||||||
|
flood_sleep_threshold=flood_sleep_threshold,
|
||||||
|
**self._get_proxy(mtproxy=mtproxy),
|
||||||
|
)
|
||||||
|
self.bot_token = bot_token
|
||||||
|
|
||||||
|
def _get_session(self, database):
|
||||||
|
if database.get('drivername') == 'postgresql':
|
||||||
|
self.container = AlchemySessionContainer(
|
||||||
|
f"{database['drivername']}://"
|
||||||
|
f"{database['username']}:"
|
||||||
|
f"{database['password']}@"
|
||||||
|
f"{database['host']}:"
|
||||||
|
f"{database['port']}/"
|
||||||
|
f"{database['database']}",
|
||||||
|
session=False,
|
||||||
|
manage_tables=False,
|
||||||
|
)
|
||||||
|
return self.container.new_session(database['session_id'])
|
||||||
|
else:
|
||||||
|
return sessions.SQLiteSession(session_id=database['session_id'])
|
||||||
|
|
||||||
|
def _get_proxy(self, mtproxy=None):
|
||||||
|
if mtproxy and mtproxy.get('enabled', True):
|
||||||
|
proxy_config = mtproxy
|
||||||
|
return {
|
||||||
|
'connection': connection.tcpmtproxy.ConnectionTcpMTProxyRandomizedIntermediate,
|
||||||
|
'proxy': (proxy_config['url'], proxy_config['port'], proxy_config['secret'])
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@retry(retry=retry_if_exception_type(ConnectionError), wait=wait_fixed(5))
|
||||||
|
async def start(self):
|
||||||
|
await self._telegram_client.start(bot_token=self.bot_token)
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
return await self.disconnect()
|
||||||
|
|
||||||
|
def add_event_handler(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.add_event_handler(*args, **kwargs)
|
||||||
|
|
||||||
|
def catch_up(self):
|
||||||
|
return self._telegram_client.catch_up()
|
||||||
|
|
||||||
|
def delete_messages(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.delete_messages(*args, **kwargs)
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
return self._telegram_client.disconnect()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def disconnected(self):
|
||||||
|
return self._telegram_client.disconnected
|
||||||
|
|
||||||
|
def download_document(self, *args, **kwargs):
|
||||||
|
return self._telegram_client._download_document(
|
||||||
|
*args,
|
||||||
|
date=datetime.datetime.now(),
|
||||||
|
thumb=None,
|
||||||
|
progress_callback=None,
|
||||||
|
msg_data=None,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def edit_message(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.edit_message(*args, **kwargs)
|
||||||
|
|
||||||
|
def edit_permissions(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.edit_permissions(*args, **kwargs)
|
||||||
|
|
||||||
|
def forward_messages(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.forward_messages(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_entity(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.get_entity(*args, **kwargs)
|
||||||
|
|
||||||
|
def get_input_entity(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.get_input_entity(*args, **kwargs)
|
||||||
|
|
||||||
|
def iter_admin_log(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.iter_admin_log(*args, **kwargs)
|
||||||
|
|
||||||
|
def iter_messages(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.iter_messages(*args, **kwargs)
|
||||||
|
|
||||||
|
def list_event_handlers(self):
|
||||||
|
return self._telegram_client.list_event_handlers()
|
||||||
|
|
||||||
|
def remove_event_handlers(self):
|
||||||
|
for handler in reversed(self.list_event_handlers()):
|
||||||
|
self._telegram_client.remove_event_handler(*handler)
|
||||||
|
|
||||||
|
def run_until_disconnected(self):
|
||||||
|
return self._telegram_client.run_until_disconnected()
|
||||||
|
|
||||||
|
def send_message(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.send_message(*args, **kwargs)
|
||||||
|
|
||||||
|
def send_file(self, *args, **kwargs):
|
||||||
|
return self._telegram_client.send_file(*args, **kwargs)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self._telegram_client(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RequestContext:
|
||||||
|
def __init__(self, bot_name, chat, request_id: str = None, request_id_length: int = 12):
|
||||||
|
self.bot_name = bot_name
|
||||||
|
self.chat = chat
|
||||||
|
self.request_id = request_id or RequestContext.generate_request_id(request_id_length)
|
||||||
|
self.default_fields = {
|
||||||
|
'bot_name': self.bot_name,
|
||||||
|
'chat_id': self.chat.id,
|
||||||
|
'request_id': self.request_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_request_id(length):
|
||||||
|
return random_string(length)
|
||||||
|
|
||||||
|
def add_default_fields(self, **fields):
|
||||||
|
self.default_fields.update(fields)
|
||||||
|
|
||||||
|
def statbox(self, **kwargs):
|
||||||
|
logging.getLogger('statbox').info(
|
||||||
|
msg=dict(
|
||||||
|
**self.default_fields,
|
||||||
|
**kwargs,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def error_log(self, e, level=logging.ERROR, **fields):
|
||||||
|
all_fields = {**self.default_fields, **fields}
|
||||||
|
error_log(e, level=level, **all_fields)
|
3
library/telegram/session_backend/__init__.py
Normal file
3
library/telegram/session_backend/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from .sqlalchemy import AlchemySessionContainer
|
||||||
|
|
||||||
|
__all__ = ['AlchemySessionContainer']
|
183
library/telegram/session_backend/core.py
Normal file
183
library/telegram/session_backend/core.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
import datetime
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
and_,
|
||||||
|
select,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
from telethon import utils
|
||||||
|
from telethon.crypto import AuthKey
|
||||||
|
from telethon.sessions.memory import _SentFileType
|
||||||
|
from telethon.tl.types import (
|
||||||
|
InputDocument,
|
||||||
|
InputPhoto,
|
||||||
|
PeerChannel,
|
||||||
|
PeerChat,
|
||||||
|
PeerUser,
|
||||||
|
updates,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .orm import AlchemySession
|
||||||
|
|
||||||
|
|
||||||
|
class AlchemyCoreSession(AlchemySession):
|
||||||
|
def _load_session(self) -> None:
|
||||||
|
t = self.Session.__table__
|
||||||
|
rows = self.engine.execute(select([t.c.dc_id, t.c.server_address, t.c.port, t.c.auth_key])
|
||||||
|
.where(t.c.session_id == self.session_id))
|
||||||
|
try:
|
||||||
|
self._dc_id, self._server_address, self._port, auth_key = next(rows)
|
||||||
|
self._auth_key = AuthKey(data=auth_key)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_auth_key(self) -> Optional[AuthKey]:
|
||||||
|
t = self.Session.__table__
|
||||||
|
rows = self.engine.execute(select([t.c.auth_key]).where(t.c.session_id == self.session_id))
|
||||||
|
try:
|
||||||
|
ak = next(rows)[0]
|
||||||
|
except (StopIteration, IndexError):
|
||||||
|
ak = None
|
||||||
|
return AuthKey(data=ak) if ak else None
|
||||||
|
|
||||||
|
def get_update_state(self, entity_id: int) -> Optional[updates.State]:
|
||||||
|
t = self.UpdateState.__table__
|
||||||
|
rows = self.engine.execute(select([t])
|
||||||
|
.where(and_(t.c.session_id == self.session_id,
|
||||||
|
t.c.entity_id == entity_id)))
|
||||||
|
try:
|
||||||
|
_, _, pts, qts, date, seq, unread_count = next(rows)
|
||||||
|
date = datetime.datetime.utcfromtimestamp(date)
|
||||||
|
return updates.State(pts, qts, date, seq, unread_count)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_update_state(self, entity_id: int, row: Any) -> None:
|
||||||
|
t = self.UpdateState.__table__
|
||||||
|
self.engine.execute(t.delete().where(and_(t.c.session_id == self.session_id,
|
||||||
|
t.c.entity_id == entity_id)))
|
||||||
|
self.engine.execute(t.insert()
|
||||||
|
.values(session_id=self.session_id, entity_id=entity_id, pts=row.pts,
|
||||||
|
qts=row.qts, date=row.date.timestamp(), seq=row.seq,
|
||||||
|
unread_count=row.unread_count))
|
||||||
|
|
||||||
|
def _update_session_table(self) -> None:
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
insert_stmt = insert(self.Session.__table__)
|
||||||
|
conn.execute(
|
||||||
|
insert_stmt.on_conflict_do_update(
|
||||||
|
index_elements=['session_id', 'dc_id'],
|
||||||
|
set_={
|
||||||
|
'session_id': insert_stmt.excluded.session_id,
|
||||||
|
'dc_id': insert_stmt.excluded.dc_id,
|
||||||
|
'server_address': insert_stmt.excluded.server_address,
|
||||||
|
'port': insert_stmt.excluded.port,
|
||||||
|
'auth_key': insert_stmt.excluded.auth_key,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
session_id=self.session_id, dc_id=self._dc_id,
|
||||||
|
server_address=self._server_address, port=self._port,
|
||||||
|
auth_key=(self._auth_key.key if self._auth_key else b'')
|
||||||
|
)
|
||||||
|
|
||||||
|
def save(self) -> None:
|
||||||
|
# engine.execute() autocommits
|
||||||
|
pass
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(self.Session.__table__.delete().where(
|
||||||
|
self.Session.__table__.c.session_id == self.session_id))
|
||||||
|
conn.execute(self.Entity.__table__.delete().where(
|
||||||
|
self.Entity.__table__.c.session_id == self.session_id))
|
||||||
|
conn.execute(self.SentFile.__table__.delete().where(
|
||||||
|
self.SentFile.__table__.c.session_id == self.session_id))
|
||||||
|
conn.execute(self.UpdateState.__table__.delete().where(
|
||||||
|
self.UpdateState.__table__.c.session_id == self.session_id))
|
||||||
|
|
||||||
|
def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
|
||||||
|
) -> Any:
|
||||||
|
return id, hash, username, phone, name
|
||||||
|
|
||||||
|
def process_entities(self, tlo: Any) -> None:
|
||||||
|
rows = self._entities_to_rows(tlo)
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
t = self.Entity.__table__
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(t.delete().where(and_(t.c.session_id == self.session_id,
|
||||||
|
t.c.id.in_([row[0] for row in rows]))))
|
||||||
|
conn.execute(t.insert(), [dict(session_id=self.session_id, id=row[0], hash=row[1],
|
||||||
|
username=row[2], phone=row[3], name=row[4])
|
||||||
|
for row in rows])
|
||||||
|
|
||||||
|
def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
|
||||||
|
return self._get_entity_rows_by_condition(self.Entity.__table__.c.phone == key)
|
||||||
|
|
||||||
|
def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
|
||||||
|
return self._get_entity_rows_by_condition(self.Entity.__table__.c.username == key)
|
||||||
|
|
||||||
|
def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
|
||||||
|
return self._get_entity_rows_by_condition(self.Entity.__table__.c.name == key)
|
||||||
|
|
||||||
|
def _get_entity_rows_by_condition(self, condition) -> Optional[Tuple[int, int]]:
|
||||||
|
t = self.Entity.__table__
|
||||||
|
rows = self.engine.execute(select([t.c.id, t.c.hash])
|
||||||
|
.where(and_(t.c.session_id == self.session_id, condition)))
|
||||||
|
try:
|
||||||
|
return next(rows)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
|
||||||
|
t = self.Entity.__table__
|
||||||
|
if exact:
|
||||||
|
rows = self.engine.execute(select([t.c.id, t.c.hash]).where(
|
||||||
|
and_(t.c.session_id == self.session_id, t.c.id == key)))
|
||||||
|
else:
|
||||||
|
ids = (
|
||||||
|
utils.get_peer_id(PeerUser(key)),
|
||||||
|
utils.get_peer_id(PeerChat(key)),
|
||||||
|
utils.get_peer_id(PeerChannel(key))
|
||||||
|
)
|
||||||
|
rows = self.engine.execute(select([t.c.id, t.c.hash]).where(
|
||||||
|
and_(t.c.session_id == self.session_id, t.c.id.in_(ids)))
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return next(rows)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
|
||||||
|
t = self.SentFile.__table__
|
||||||
|
rows = (self.engine.execute(select([t.c.id, t.c.hash])
|
||||||
|
.where(and_(t.c.session_id == self.session_id,
|
||||||
|
t.c.md5_digest == md5_digest,
|
||||||
|
t.c.file_size == file_size,
|
||||||
|
t.c.type == _SentFileType.from_type(cls).value))))
|
||||||
|
try:
|
||||||
|
return next(rows)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def cache_file(self, md5_digest: str, file_size: int,
|
||||||
|
instance: Union[InputDocument, InputPhoto]) -> None:
|
||||||
|
if not isinstance(instance, (InputDocument, InputPhoto)):
|
||||||
|
raise TypeError("Cannot cache {} instance".format(type(instance)))
|
||||||
|
|
||||||
|
t = self.SentFile.__table__
|
||||||
|
file_type = _SentFileType.from_type(type(instance)).value
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(t.delete().where(session_id=self.session_id, md5_digest=md5_digest,
|
||||||
|
type=file_type, file_size=file_size))
|
||||||
|
conn.execute(t.insert().values(session_id=self.session_id, md5_digest=md5_digest,
|
||||||
|
type=file_type, file_size=file_size, id=instance.id,
|
||||||
|
hash=instance.access_hash))
|
56
library/telegram/session_backend/core_postgres.py
Normal file
56
library/telegram/session_backend/core_postgres.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
from telethon.sessions.memory import _SentFileType
|
||||||
|
from telethon.tl.types import (
|
||||||
|
InputDocument,
|
||||||
|
InputPhoto,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .core import AlchemyCoreSession
|
||||||
|
|
||||||
|
|
||||||
|
class AlchemyPostgresCoreSession(AlchemyCoreSession):
|
||||||
|
def set_update_state(self, entity_id: int, row: Any) -> None:
|
||||||
|
t = self.UpdateState.__table__
|
||||||
|
values = dict(pts=row.pts, qts=row.qts, date=row.date.timestamp(),
|
||||||
|
seq=row.seq, unread_count=row.unread_count)
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(insert(t)
|
||||||
|
.values(session_id=self.session_id, entity_id=entity_id, **values)
|
||||||
|
.on_conflict_do_update(constraint=t.primary_key, set_=values))
|
||||||
|
|
||||||
|
def process_entities(self, tlo: Any) -> None:
|
||||||
|
rows = self._entities_to_rows(tlo)
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
t = self.Entity.__table__
|
||||||
|
ins = insert(t)
|
||||||
|
upsert = ins.on_conflict_do_update(constraint=t.primary_key, set_={
|
||||||
|
"hash": ins.excluded.hash,
|
||||||
|
"username": ins.excluded.username,
|
||||||
|
"phone": ins.excluded.phone,
|
||||||
|
"name": ins.excluded.name,
|
||||||
|
})
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(upsert, [dict(session_id=self.session_id, id=row[0], hash=row[1],
|
||||||
|
username=row[2], phone=row[3], name=row[4])
|
||||||
|
for row in rows])
|
||||||
|
|
||||||
|
def cache_file(self, md5_digest: str, file_size: int,
|
||||||
|
instance: Union[InputDocument, InputPhoto]) -> None:
|
||||||
|
if not isinstance(instance, (InputDocument, InputPhoto)):
|
||||||
|
raise TypeError("Cannot cache {} instance".format(type(instance)))
|
||||||
|
|
||||||
|
t = self.SentFile.__table__
|
||||||
|
values = dict(id=instance.id, hash=instance.access_hash)
|
||||||
|
with self.engine.begin() as conn:
|
||||||
|
conn.execute(insert(t)
|
||||||
|
.values(session_id=self.session_id, md5_digest=md5_digest,
|
||||||
|
type=_SentFileType.from_type(type(instance)).value,
|
||||||
|
file_size=file_size, **values)
|
||||||
|
.on_conflict_do_update(constraint=t.primary_key, set_=values))
|
170
library/telegram/session_backend/orm.py
Normal file
170
library/telegram/session_backend/orm.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import datetime
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sqlalchemy import orm
|
||||||
|
from telethon import utils
|
||||||
|
from telethon.crypto import AuthKey
|
||||||
|
from telethon.sessions.memory import (
|
||||||
|
MemorySession,
|
||||||
|
_SentFileType,
|
||||||
|
)
|
||||||
|
from telethon.tl.types import (
|
||||||
|
InputDocument,
|
||||||
|
InputPhoto,
|
||||||
|
PeerChannel,
|
||||||
|
PeerChat,
|
||||||
|
PeerUser,
|
||||||
|
updates,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .sqlalchemy import AlchemySessionContainer
|
||||||
|
|
||||||
|
|
||||||
|
class AlchemySession(MemorySession):
|
||||||
|
def __init__(self, container: 'AlchemySessionContainer', session_id: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.container = container
|
||||||
|
self.db = container.db
|
||||||
|
self.engine = container.db_engine
|
||||||
|
self.Version, self.Session, self.Entity, self.SentFile, self.UpdateState = (
|
||||||
|
container.Version, container.Session, container.Entity,
|
||||||
|
container.SentFile, container.UpdateState)
|
||||||
|
self.session_id = session_id
|
||||||
|
self._load_session()
|
||||||
|
|
||||||
|
def _load_session(self) -> None:
|
||||||
|
sessions = self._db_query(self.Session).all()
|
||||||
|
session = sessions[0] if sessions else None
|
||||||
|
if session:
|
||||||
|
self._dc_id = session.dc_id
|
||||||
|
self._server_address = session.server_address
|
||||||
|
self._port = session.port
|
||||||
|
self._auth_key = AuthKey(data=session.auth_key)
|
||||||
|
|
||||||
|
def clone(self, to_instance=None) -> MemorySession:
|
||||||
|
return super().clone(MemorySession())
|
||||||
|
|
||||||
|
def _get_auth_key(self) -> Optional[AuthKey]:
|
||||||
|
sessions = self._db_query(self.Session).all()
|
||||||
|
session = sessions[0] if sessions else None
|
||||||
|
if session and session.auth_key:
|
||||||
|
return AuthKey(data=session.auth_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_dc(self, dc_id: str, server_address: str, port: int) -> None:
|
||||||
|
super().set_dc(dc_id, server_address, port)
|
||||||
|
self._update_session_table()
|
||||||
|
self._auth_key = self._get_auth_key()
|
||||||
|
|
||||||
|
def get_update_state(self, entity_id: int) -> Optional[updates.State]:
|
||||||
|
row = self.UpdateState.query.get((self.session_id, entity_id))
|
||||||
|
if row:
|
||||||
|
date = datetime.datetime.utcfromtimestamp(row.date)
|
||||||
|
return updates.State(row.pts, row.qts, date, row.seq, row.unread_count)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_update_state(self, entity_id: int, row: Any) -> None:
|
||||||
|
if row:
|
||||||
|
self.db.merge(self.UpdateState(session_id=self.session_id, entity_id=entity_id,
|
||||||
|
pts=row.pts, qts=row.qts, date=row.date.timestamp(),
|
||||||
|
seq=row.seq,
|
||||||
|
unread_count=row.unread_count))
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
@MemorySession.auth_key.setter
|
||||||
|
def auth_key(self, value: AuthKey) -> None:
|
||||||
|
self._auth_key = value
|
||||||
|
self._update_session_table()
|
||||||
|
|
||||||
|
def _update_session_table(self) -> None:
|
||||||
|
self.Session.query.filter(self.Session.session_id == self.session_id).delete()
|
||||||
|
self.db.add(self.Session(session_id=self.session_id, dc_id=self._dc_id,
|
||||||
|
server_address=self._server_address, port=self._port,
|
||||||
|
auth_key=(self._auth_key.key if self._auth_key else b'')))
|
||||||
|
|
||||||
|
def _db_query(self, dbclass: Any, *args: Any) -> orm.Query:
|
||||||
|
return dbclass.query.filter(
|
||||||
|
dbclass.session_id == self.session_id, *args
|
||||||
|
)
|
||||||
|
|
||||||
|
def save(self) -> None:
|
||||||
|
self.container.save()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
# Nothing to do here, connection is managed by AlchemySessionContainer.
|
||||||
|
pass
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
self._db_query(self.Session).delete()
|
||||||
|
self._db_query(self.Entity).delete()
|
||||||
|
self._db_query(self.SentFile).delete()
|
||||||
|
self._db_query(self.UpdateState).delete()
|
||||||
|
|
||||||
|
def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
|
||||||
|
) -> Any:
|
||||||
|
return self.Entity(session_id=self.session_id, id=id, hash=hash,
|
||||||
|
username=username, phone=phone, name=name)
|
||||||
|
|
||||||
|
def process_entities(self, tlo: Any) -> None:
|
||||||
|
rows = self._entities_to_rows(tlo)
|
||||||
|
if not rows:
|
||||||
|
return
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
self.db.merge(row)
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
|
||||||
|
row = self._db_query(self.Entity,
|
||||||
|
self.Entity.phone == key).one_or_none()
|
||||||
|
return (row.id, row.hash) if row else None
|
||||||
|
|
||||||
|
def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
|
||||||
|
row = self._db_query(self.Entity,
|
||||||
|
self.Entity.username == key).one_or_none()
|
||||||
|
return (row.id, row.hash) if row else None
|
||||||
|
|
||||||
|
def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
|
||||||
|
row = self._db_query(self.Entity,
|
||||||
|
self.Entity.name == key).one_or_none()
|
||||||
|
return (row.id, row.hash) if row else None
|
||||||
|
|
||||||
|
def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
|
||||||
|
if exact:
|
||||||
|
query = self._db_query(self.Entity, self.Entity.id == key)
|
||||||
|
else:
|
||||||
|
ids = (
|
||||||
|
utils.get_peer_id(PeerUser(key)),
|
||||||
|
utils.get_peer_id(PeerChat(key)),
|
||||||
|
utils.get_peer_id(PeerChannel(key))
|
||||||
|
)
|
||||||
|
query = self._db_query(self.Entity, self.Entity.id.in_(ids))
|
||||||
|
|
||||||
|
row = query.one_or_none()
|
||||||
|
return (row.id, row.hash) if row else None
|
||||||
|
|
||||||
|
def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
|
||||||
|
row = self._db_query(self.SentFile,
|
||||||
|
self.SentFile.md5_digest == md5_digest,
|
||||||
|
self.SentFile.file_size == file_size,
|
||||||
|
self.SentFile.type == _SentFileType.from_type(
|
||||||
|
cls).value).one_or_none()
|
||||||
|
return (row.id, row.hash) if row else None
|
||||||
|
|
||||||
|
def cache_file(self, md5_digest: str, file_size: int,
|
||||||
|
instance: Union[InputDocument, InputPhoto]) -> None:
|
||||||
|
if not isinstance(instance, (InputDocument, InputPhoto)):
|
||||||
|
raise TypeError("Cannot cache {} instance".format(type(instance)))
|
||||||
|
|
||||||
|
self.db.merge(
|
||||||
|
self.SentFile(session_id=self.session_id, md5_digest=md5_digest, file_size=file_size,
|
||||||
|
type=_SentFileType.from_type(type(instance)).value,
|
||||||
|
id=instance.id, hash=instance.access_hash))
|
||||||
|
self.save()
|
203
library/telegram/session_backend/sqlalchemy.py
Normal file
203
library/telegram/session_backend/sqlalchemy.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import sqlalchemy as sql
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Column,
|
||||||
|
Integer,
|
||||||
|
LargeBinary,
|
||||||
|
String,
|
||||||
|
and_,
|
||||||
|
func,
|
||||||
|
orm,
|
||||||
|
select,
|
||||||
|
)
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm.scoping import scoped_session
|
||||||
|
|
||||||
|
from .core import AlchemyCoreSession
|
||||||
|
from .core_postgres import AlchemyPostgresCoreSession
|
||||||
|
from .orm import AlchemySession
|
||||||
|
|
||||||
|
LATEST_VERSION = 2
|
||||||
|
|
||||||
|
|
||||||
|
class AlchemySessionContainer:
|
||||||
|
def __init__(self, engine: Union[sql.engine.Engine, str] = None,
|
||||||
|
session: Optional[Union[orm.Session, scoped_session, bool]] = None,
|
||||||
|
table_prefix: str = "", table_base: Optional[declarative_base] = None,
|
||||||
|
manage_tables: bool = True) -> None:
|
||||||
|
if isinstance(engine, str):
|
||||||
|
engine = sql.create_engine(engine)
|
||||||
|
|
||||||
|
self.db_engine = engine
|
||||||
|
if session is None:
|
||||||
|
db_factory = orm.sessionmaker(bind=self.db_engine)
|
||||||
|
self.db = orm.scoping.scoped_session(db_factory)
|
||||||
|
elif not session:
|
||||||
|
self.db = None
|
||||||
|
else:
|
||||||
|
self.db = session
|
||||||
|
|
||||||
|
table_base = table_base or declarative_base()
|
||||||
|
(self.Version, self.Session, self.Entity,
|
||||||
|
self.SentFile, self.UpdateState) = self.create_table_classes(self.db, table_prefix,
|
||||||
|
table_base)
|
||||||
|
self.alchemy_session_class = AlchemySession
|
||||||
|
if not self.db:
|
||||||
|
# Implicit core mode if there's no ORM session.
|
||||||
|
self.core_mode = True
|
||||||
|
|
||||||
|
if manage_tables:
|
||||||
|
if not self.db:
|
||||||
|
raise ValueError("Can't manage tables without an ORM session.")
|
||||||
|
table_base.metadata.bind = self.db_engine
|
||||||
|
if not self.db_engine.dialect.has_table(self.db_engine,
|
||||||
|
self.Version.__tablename__):
|
||||||
|
table_base.metadata.create_all()
|
||||||
|
self.db.add(self.Version(version=LATEST_VERSION))
|
||||||
|
self.db.commit()
|
||||||
|
else:
|
||||||
|
self.check_and_upgrade_database()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def core_mode(self) -> bool:
|
||||||
|
return self.alchemy_session_class != AlchemySession
|
||||||
|
|
||||||
|
@core_mode.setter
|
||||||
|
def core_mode(self, val: bool) -> None:
|
||||||
|
if val:
|
||||||
|
if self.db_engine.dialect.name == "postgresql":
|
||||||
|
self.alchemy_session_class = AlchemyPostgresCoreSession
|
||||||
|
else:
|
||||||
|
self.alchemy_session_class = AlchemyCoreSession
|
||||||
|
else:
|
||||||
|
if not self.db:
|
||||||
|
raise ValueError("Can't use ORM mode without an ORM session.")
|
||||||
|
self.alchemy_session_class = AlchemySession
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_table_classes(db: scoped_session, prefix: str, base: declarative_base
|
||||||
|
) -> Tuple[Any, Any, Any, Any, Any]:
|
||||||
|
qp = db.query_property() if db else None
|
||||||
|
|
||||||
|
class Version(base):
|
||||||
|
query = qp
|
||||||
|
__tablename__ = "{prefix}version".format(prefix=prefix)
|
||||||
|
version = Column(Integer, primary_key=True)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "Version('{}')".format(self.version)
|
||||||
|
|
||||||
|
class Session(base):
|
||||||
|
query = qp
|
||||||
|
__tablename__ = '{prefix}sessions'.format(prefix=prefix)
|
||||||
|
|
||||||
|
session_id = Column(String(255), primary_key=True)
|
||||||
|
dc_id = Column(Integer, primary_key=True)
|
||||||
|
server_address = Column(String(255))
|
||||||
|
port = Column(Integer)
|
||||||
|
auth_key = Column(LargeBinary)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "Session('{}', {}, '{}', {}, {})".format(self.session_id, self.dc_id,
|
||||||
|
self.server_address, self.port,
|
||||||
|
self.auth_key)
|
||||||
|
|
||||||
|
class Entity(base):
|
||||||
|
query = qp
|
||||||
|
__tablename__ = '{prefix}entities'.format(prefix=prefix)
|
||||||
|
|
||||||
|
session_id = Column(String(255), primary_key=True)
|
||||||
|
id = Column(BigInteger, primary_key=True)
|
||||||
|
hash = Column(BigInteger, nullable=False)
|
||||||
|
username = Column(String(32))
|
||||||
|
phone = Column(BigInteger)
|
||||||
|
name = Column(String(255))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "Entity('{}', {}, {}, '{}', '{}', '{}')".format(self.session_id, self.id,
|
||||||
|
self.hash, self.username,
|
||||||
|
self.phone, self.name)
|
||||||
|
|
||||||
|
class SentFile(base):
|
||||||
|
query = qp
|
||||||
|
__tablename__ = '{prefix}sent_files'.format(prefix=prefix)
|
||||||
|
|
||||||
|
session_id = Column(String(255), primary_key=True)
|
||||||
|
md5_digest = Column(LargeBinary, primary_key=True)
|
||||||
|
file_size = Column(Integer, primary_key=True)
|
||||||
|
type = Column(Integer, primary_key=True)
|
||||||
|
id = Column(BigInteger)
|
||||||
|
hash = Column(BigInteger)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "SentFile('{}', {}, {}, {}, {}, {})".format(self.session_id,
|
||||||
|
self.md5_digest, self.file_size,
|
||||||
|
self.type, self.id, self.hash)
|
||||||
|
|
||||||
|
class UpdateState(base):
|
||||||
|
query = qp
|
||||||
|
__tablename__ = "{prefix}update_state".format(prefix=prefix)
|
||||||
|
|
||||||
|
session_id = Column(String(255), primary_key=True)
|
||||||
|
entity_id = Column(BigInteger, primary_key=True)
|
||||||
|
pts = Column(BigInteger)
|
||||||
|
qts = Column(BigInteger)
|
||||||
|
date = Column(BigInteger)
|
||||||
|
seq = Column(BigInteger)
|
||||||
|
unread_count = Column(Integer)
|
||||||
|
|
||||||
|
return Version, Session, Entity, SentFile, UpdateState
|
||||||
|
|
||||||
|
def _add_column(self, table: Any, column: Column) -> None:
|
||||||
|
column_name = column.compile(dialect=self.db_engine.dialect)
|
||||||
|
column_type = column.type.compile(self.db_engine.dialect)
|
||||||
|
self.db_engine.execute("ALTER TABLE {} ADD COLUMN {} {}".format(
|
||||||
|
table.__tablename__, column_name, column_type))
|
||||||
|
|
||||||
|
def check_and_upgrade_database(self) -> None:
|
||||||
|
row = self.Version.query.all()
|
||||||
|
version = row[0].version if row else 1
|
||||||
|
if version == LATEST_VERSION:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.Version.query.delete()
|
||||||
|
|
||||||
|
if version == 1:
|
||||||
|
self.UpdateState.__table__.create(self.db_engine)
|
||||||
|
version = 3
|
||||||
|
elif version == 2:
|
||||||
|
self._add_column(self.UpdateState, Column(type=Integer, name="unread_count"))
|
||||||
|
|
||||||
|
self.db.add(self.Version(version=version))
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
def new_session(self, session_id: str) -> 'AlchemySession':
|
||||||
|
return self.alchemy_session_class(self, session_id)
|
||||||
|
|
||||||
|
def has_session(self, session_id: str) -> bool:
|
||||||
|
if self.core_mode:
|
||||||
|
t = self.Session.__table__
|
||||||
|
rows = self.db_engine.execute(select([func.count(t.c.auth_key)])
|
||||||
|
.where(and_(t.c.session_id == session_id,
|
||||||
|
t.c.auth_key != b'')))
|
||||||
|
try:
|
||||||
|
count, = next(rows)
|
||||||
|
return count > 0
|
||||||
|
except StopIteration:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return self.Session.query.filter(self.Session.session_id == session_id).count() > 0
|
||||||
|
|
||||||
|
def list_sessions(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def save(self) -> None:
|
||||||
|
if self.db:
|
||||||
|
self.db.commit()
|
45
library/telegram/utils.py
Normal file
45
library/telegram/utils.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import (
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
)
|
||||||
|
|
||||||
|
from telethon import (
|
||||||
|
errors,
|
||||||
|
events,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import RequestContext
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def safe_execution(
|
||||||
|
request_context: RequestContext,
|
||||||
|
on_fail: Optional[Callable[[], Awaitable]] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except events.StopPropagation:
|
||||||
|
raise
|
||||||
|
except (
|
||||||
|
errors.UserIsBlockedError,
|
||||||
|
errors.QueryIdInvalidError,
|
||||||
|
errors.MessageDeleteForbiddenError,
|
||||||
|
errors.MessageIdInvalidError,
|
||||||
|
errors.MessageNotModifiedError,
|
||||||
|
errors.ChatAdminRequiredError,
|
||||||
|
) as e:
|
||||||
|
request_context.error_log(e, level=logging.WARNING)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
request_context.error_log(e)
|
||||||
|
if on_fail:
|
||||||
|
await on_fail()
|
||||||
|
except events.StopPropagation:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
request_context.error_log(e)
|
@ -199,8 +199,9 @@ class FillDocumentOperationUpdateDocumentScimagPbFromExternalSourceAction(BaseAc
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.crossref_client = CrossrefClient(
|
self.crossref_client = CrossrefClient(
|
||||||
delay=1.0 / crossref['rps'],
|
delay=1.0 / crossref['rps'],
|
||||||
max_retries=60,
|
max_retries=crossref.get('max_retries', 15),
|
||||||
proxy_url=crossref.get('proxy_url'),
|
proxy_url=crossref.get('proxy_url'),
|
||||||
|
retry_delay=crossref.get('retry_delay', 0.5),
|
||||||
timeout=crossref.get('timeout'),
|
timeout=crossref.get('timeout'),
|
||||||
user_agent=crossref.get('user_agent'),
|
user_agent=crossref.get('user_agent'),
|
||||||
)
|
)
|
||||||
|
@ -41,7 +41,7 @@ py3_image(
|
|||||||
requirement("aiokit"),
|
requirement("aiokit"),
|
||||||
"//library/configurator",
|
"//library/configurator",
|
||||||
"//library/logging",
|
"//library/logging",
|
||||||
"//library/metrics_server",
|
,
|
||||||
"//library/telegram",
|
"//library/telegram",
|
||||||
"//nexus/hub/aioclient",
|
"//nexus/hub/aioclient",
|
||||||
"//nexus/meta_api/aioclient",
|
"//nexus/meta_api/aioclient",
|
||||||
|
@ -6,3 +6,4 @@ The bot requires three other essential parts:
|
|||||||
- Nexus Meta API (rescoring API for Summa Search server)
|
- Nexus Meta API (rescoring API for Summa Search server)
|
||||||
|
|
||||||
or their substitutions
|
or their substitutions
|
||||||
|
|
||||||
|
@ -9,5 +9,6 @@ py_binary(
|
|||||||
deps = [
|
deps = [
|
||||||
"@bazel_tools//tools/build_defs/pkg:archive",
|
"@bazel_tools//tools/build_defs/pkg:archive",
|
||||||
requirement("python-gflags"),
|
requirement("python-gflags"),
|
||||||
|
requirement("six"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user