diff --git a/library/telegram/BUILD.bazel b/library/telegram/BUILD.bazel new file mode 100644 index 0000000..3e05cb9 --- /dev/null +++ b/library/telegram/BUILD.bazel @@ -0,0 +1,18 @@ +load("@pip_modules//:requirements.bzl", "requirement") +load("@rules_python//python:defs.bzl", "py_library") + +py_library( + name = "telegram", + srcs = glob( + ["**/*.py"], + exclude = ["tests/**"], + ), + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + requirement("sqlalchemy"), + requirement("telethon"), + requirement("aiokit"), + "//library/logging", + ], +) diff --git a/library/telegram/__init__.py b/library/telegram/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/library/telegram/base.py b/library/telegram/base.py new file mode 100644 index 0000000..008d4c3 --- /dev/null +++ b/library/telegram/base.py @@ -0,0 +1,160 @@ +import datetime +import logging + +from aiokit import AioThing +from izihawa_utils.random import random_string +from library.logging import error_log +from telethon import ( + TelegramClient, + connection, + sessions, +) +from tenacity import ( # noqa + retry, + retry_if_exception_type, + wait_fixed, +) + +from .session_backend import AlchemySessionContainer + + +class BaseTelegramClient(AioThing): + def __init__(self, app_id, app_hash, database, bot_token=None, mtproxy=None, flood_sleep_threshold: int = 60): + AioThing.__init__(self) + self._telegram_client = TelegramClient( + self._get_session(database), + app_id, + app_hash, + flood_sleep_threshold=flood_sleep_threshold, + **self._get_proxy(mtproxy=mtproxy), + ) + self.bot_token = bot_token + + def _get_session(self, database): + if database.get('drivername') == 'postgresql': + self.container = AlchemySessionContainer( + f"{database['drivername']}://" + f"{database['username']}:" + f"{database['password']}@" + f"{database['host']}:" + f"{database['port']}/" + f"{database['database']}", + session=False, + manage_tables=False, + ) + return self.container.new_session(database['session_id']) + else: + return sessions.SQLiteSession(session_id=database['session_id']) + + def _get_proxy(self, mtproxy=None): + if mtproxy and mtproxy.get('enabled', True): + proxy_config = mtproxy + return { + 'connection': connection.tcpmtproxy.ConnectionTcpMTProxyRandomizedIntermediate, + 'proxy': (proxy_config['url'], proxy_config['port'], proxy_config['secret']) + } + return {} + + @retry(retry=retry_if_exception_type(ConnectionError), wait=wait_fixed(5)) + async def start(self): + await self._telegram_client.start(bot_token=self.bot_token) + + async def stop(self): + return await self.disconnect() + + def add_event_handler(self, *args, **kwargs): + return self._telegram_client.add_event_handler(*args, **kwargs) + + def catch_up(self): + return self._telegram_client.catch_up() + + def delete_messages(self, *args, **kwargs): + return self._telegram_client.delete_messages(*args, **kwargs) + + def disconnect(self): + return self._telegram_client.disconnect() + + @property + def disconnected(self): + return self._telegram_client.disconnected + + def download_document(self, *args, **kwargs): + return self._telegram_client._download_document( + *args, + date=datetime.datetime.now(), + thumb=None, + progress_callback=None, + msg_data=None, + **kwargs, + ) + + def edit_message(self, *args, **kwargs): + return self._telegram_client.edit_message(*args, **kwargs) + + def edit_permissions(self, *args, **kwargs): + return self._telegram_client.edit_permissions(*args, **kwargs) + + def forward_messages(self, *args, **kwargs): + return self._telegram_client.forward_messages(*args, **kwargs) + + def get_entity(self, *args, **kwargs): + return self._telegram_client.get_entity(*args, **kwargs) + + def get_input_entity(self, *args, **kwargs): + return self._telegram_client.get_input_entity(*args, **kwargs) + + def iter_admin_log(self, *args, **kwargs): + return self._telegram_client.iter_admin_log(*args, **kwargs) + + def iter_messages(self, *args, **kwargs): + return self._telegram_client.iter_messages(*args, **kwargs) + + def list_event_handlers(self): + return self._telegram_client.list_event_handlers() + + def remove_event_handlers(self): + for handler in reversed(self.list_event_handlers()): + self._telegram_client.remove_event_handler(*handler) + + def run_until_disconnected(self): + return self._telegram_client.run_until_disconnected() + + def send_message(self, *args, **kwargs): + return self._telegram_client.send_message(*args, **kwargs) + + def send_file(self, *args, **kwargs): + return self._telegram_client.send_file(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self._telegram_client(*args, **kwargs) + + +class RequestContext: + def __init__(self, bot_name, chat, request_id: str = None, request_id_length: int = 12): + self.bot_name = bot_name + self.chat = chat + self.request_id = request_id or RequestContext.generate_request_id(request_id_length) + self.default_fields = { + 'bot_name': self.bot_name, + 'chat_id': self.chat.id, + 'request_id': self.request_id, + } + + @staticmethod + def generate_request_id(length): + return random_string(length) + + def add_default_fields(self, **fields): + self.default_fields.update(fields) + + def statbox(self, **kwargs): + logging.getLogger('statbox').info( + msg=dict( + **self.default_fields, + **kwargs, + ), + ) + + def error_log(self, e, level=logging.ERROR, **fields): + all_fields = {**self.default_fields, **fields} + error_log(e, level=level, **all_fields) diff --git a/library/telegram/session_backend/__init__.py b/library/telegram/session_backend/__init__.py new file mode 100644 index 0000000..2192c14 --- /dev/null +++ b/library/telegram/session_backend/__init__.py @@ -0,0 +1,3 @@ +from .sqlalchemy import AlchemySessionContainer + +__all__ = ['AlchemySessionContainer'] diff --git a/library/telegram/session_backend/core.py b/library/telegram/session_backend/core.py new file mode 100644 index 0000000..476482b --- /dev/null +++ b/library/telegram/session_backend/core.py @@ -0,0 +1,183 @@ +import datetime +from typing import ( + Any, + Optional, + Tuple, + Union, +) + +from sqlalchemy import ( + and_, + select, +) +from sqlalchemy.dialects.postgresql import insert +from telethon import utils +from telethon.crypto import AuthKey +from telethon.sessions.memory import _SentFileType +from telethon.tl.types import ( + InputDocument, + InputPhoto, + PeerChannel, + PeerChat, + PeerUser, + updates, +) + +from .orm import AlchemySession + + +class AlchemyCoreSession(AlchemySession): + def _load_session(self) -> None: + t = self.Session.__table__ + rows = self.engine.execute(select([t.c.dc_id, t.c.server_address, t.c.port, t.c.auth_key]) + .where(t.c.session_id == self.session_id)) + try: + self._dc_id, self._server_address, self._port, auth_key = next(rows) + self._auth_key = AuthKey(data=auth_key) + except StopIteration: + pass + + def _get_auth_key(self) -> Optional[AuthKey]: + t = self.Session.__table__ + rows = self.engine.execute(select([t.c.auth_key]).where(t.c.session_id == self.session_id)) + try: + ak = next(rows)[0] + except (StopIteration, IndexError): + ak = None + return AuthKey(data=ak) if ak else None + + def get_update_state(self, entity_id: int) -> Optional[updates.State]: + t = self.UpdateState.__table__ + rows = self.engine.execute(select([t]) + .where(and_(t.c.session_id == self.session_id, + t.c.entity_id == entity_id))) + try: + _, _, pts, qts, date, seq, unread_count = next(rows) + date = datetime.datetime.utcfromtimestamp(date) + return updates.State(pts, qts, date, seq, unread_count) + except StopIteration: + return None + + def set_update_state(self, entity_id: int, row: Any) -> None: + t = self.UpdateState.__table__ + self.engine.execute(t.delete().where(and_(t.c.session_id == self.session_id, + t.c.entity_id == entity_id))) + self.engine.execute(t.insert() + .values(session_id=self.session_id, entity_id=entity_id, pts=row.pts, + qts=row.qts, date=row.date.timestamp(), seq=row.seq, + unread_count=row.unread_count)) + + def _update_session_table(self) -> None: + with self.engine.begin() as conn: + insert_stmt = insert(self.Session.__table__) + conn.execute( + insert_stmt.on_conflict_do_update( + index_elements=['session_id', 'dc_id'], + set_={ + 'session_id': insert_stmt.excluded.session_id, + 'dc_id': insert_stmt.excluded.dc_id, + 'server_address': insert_stmt.excluded.server_address, + 'port': insert_stmt.excluded.port, + 'auth_key': insert_stmt.excluded.auth_key, + } + ), + session_id=self.session_id, dc_id=self._dc_id, + server_address=self._server_address, port=self._port, + auth_key=(self._auth_key.key if self._auth_key else b'') + ) + + def save(self) -> None: + # engine.execute() autocommits + pass + + def delete(self) -> None: + with self.engine.begin() as conn: + conn.execute(self.Session.__table__.delete().where( + self.Session.__table__.c.session_id == self.session_id)) + conn.execute(self.Entity.__table__.delete().where( + self.Entity.__table__.c.session_id == self.session_id)) + conn.execute(self.SentFile.__table__.delete().where( + self.SentFile.__table__.c.session_id == self.session_id)) + conn.execute(self.UpdateState.__table__.delete().where( + self.UpdateState.__table__.c.session_id == self.session_id)) + + def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str + ) -> Any: + return id, hash, username, phone, name + + def process_entities(self, tlo: Any) -> None: + rows = self._entities_to_rows(tlo) + if not rows: + return + + t = self.Entity.__table__ + with self.engine.begin() as conn: + conn.execute(t.delete().where(and_(t.c.session_id == self.session_id, + t.c.id.in_([row[0] for row in rows])))) + conn.execute(t.insert(), [dict(session_id=self.session_id, id=row[0], hash=row[1], + username=row[2], phone=row[3], name=row[4]) + for row in rows]) + + def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]: + return self._get_entity_rows_by_condition(self.Entity.__table__.c.phone == key) + + def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]: + return self._get_entity_rows_by_condition(self.Entity.__table__.c.username == key) + + def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]: + return self._get_entity_rows_by_condition(self.Entity.__table__.c.name == key) + + def _get_entity_rows_by_condition(self, condition) -> Optional[Tuple[int, int]]: + t = self.Entity.__table__ + rows = self.engine.execute(select([t.c.id, t.c.hash]) + .where(and_(t.c.session_id == self.session_id, condition))) + try: + return next(rows) + except StopIteration: + return None + + def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]: + t = self.Entity.__table__ + if exact: + rows = self.engine.execute(select([t.c.id, t.c.hash]).where( + and_(t.c.session_id == self.session_id, t.c.id == key))) + else: + ids = ( + utils.get_peer_id(PeerUser(key)), + utils.get_peer_id(PeerChat(key)), + utils.get_peer_id(PeerChannel(key)) + ) + rows = self.engine.execute(select([t.c.id, t.c.hash]).where( + and_(t.c.session_id == self.session_id, t.c.id.in_(ids))) + ) + + try: + return next(rows) + except StopIteration: + return None + + def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]: + t = self.SentFile.__table__ + rows = (self.engine.execute(select([t.c.id, t.c.hash]) + .where(and_(t.c.session_id == self.session_id, + t.c.md5_digest == md5_digest, + t.c.file_size == file_size, + t.c.type == _SentFileType.from_type(cls).value)))) + try: + return next(rows) + except StopIteration: + return None + + def cache_file(self, md5_digest: str, file_size: int, + instance: Union[InputDocument, InputPhoto]) -> None: + if not isinstance(instance, (InputDocument, InputPhoto)): + raise TypeError("Cannot cache {} instance".format(type(instance))) + + t = self.SentFile.__table__ + file_type = _SentFileType.from_type(type(instance)).value + with self.engine.begin() as conn: + conn.execute(t.delete().where(session_id=self.session_id, md5_digest=md5_digest, + type=file_type, file_size=file_size)) + conn.execute(t.insert().values(session_id=self.session_id, md5_digest=md5_digest, + type=file_type, file_size=file_size, id=instance.id, + hash=instance.access_hash)) diff --git a/library/telegram/session_backend/core_postgres.py b/library/telegram/session_backend/core_postgres.py new file mode 100644 index 0000000..156ce87 --- /dev/null +++ b/library/telegram/session_backend/core_postgres.py @@ -0,0 +1,56 @@ +from typing import ( + Any, + Union, +) + +from sqlalchemy.dialects.postgresql import insert +from telethon.sessions.memory import _SentFileType +from telethon.tl.types import ( + InputDocument, + InputPhoto, +) + +from .core import AlchemyCoreSession + + +class AlchemyPostgresCoreSession(AlchemyCoreSession): + def set_update_state(self, entity_id: int, row: Any) -> None: + t = self.UpdateState.__table__ + values = dict(pts=row.pts, qts=row.qts, date=row.date.timestamp(), + seq=row.seq, unread_count=row.unread_count) + with self.engine.begin() as conn: + conn.execute(insert(t) + .values(session_id=self.session_id, entity_id=entity_id, **values) + .on_conflict_do_update(constraint=t.primary_key, set_=values)) + + def process_entities(self, tlo: Any) -> None: + rows = self._entities_to_rows(tlo) + if not rows: + return + + t = self.Entity.__table__ + ins = insert(t) + upsert = ins.on_conflict_do_update(constraint=t.primary_key, set_={ + "hash": ins.excluded.hash, + "username": ins.excluded.username, + "phone": ins.excluded.phone, + "name": ins.excluded.name, + }) + with self.engine.begin() as conn: + conn.execute(upsert, [dict(session_id=self.session_id, id=row[0], hash=row[1], + username=row[2], phone=row[3], name=row[4]) + for row in rows]) + + def cache_file(self, md5_digest: str, file_size: int, + instance: Union[InputDocument, InputPhoto]) -> None: + if not isinstance(instance, (InputDocument, InputPhoto)): + raise TypeError("Cannot cache {} instance".format(type(instance))) + + t = self.SentFile.__table__ + values = dict(id=instance.id, hash=instance.access_hash) + with self.engine.begin() as conn: + conn.execute(insert(t) + .values(session_id=self.session_id, md5_digest=md5_digest, + type=_SentFileType.from_type(type(instance)).value, + file_size=file_size, **values) + .on_conflict_do_update(constraint=t.primary_key, set_=values)) diff --git a/library/telegram/session_backend/orm.py b/library/telegram/session_backend/orm.py new file mode 100644 index 0000000..61bc6ea --- /dev/null +++ b/library/telegram/session_backend/orm.py @@ -0,0 +1,170 @@ +import datetime +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Tuple, + Union, +) + +from sqlalchemy import orm +from telethon import utils +from telethon.crypto import AuthKey +from telethon.sessions.memory import ( + MemorySession, + _SentFileType, +) +from telethon.tl.types import ( + InputDocument, + InputPhoto, + PeerChannel, + PeerChat, + PeerUser, + updates, +) + +if TYPE_CHECKING: + from .sqlalchemy import AlchemySessionContainer + + +class AlchemySession(MemorySession): + def __init__(self, container: 'AlchemySessionContainer', session_id: str) -> None: + super().__init__() + self.container = container + self.db = container.db + self.engine = container.db_engine + self.Version, self.Session, self.Entity, self.SentFile, self.UpdateState = ( + container.Version, container.Session, container.Entity, + container.SentFile, container.UpdateState) + self.session_id = session_id + self._load_session() + + def _load_session(self) -> None: + sessions = self._db_query(self.Session).all() + session = sessions[0] if sessions else None + if session: + self._dc_id = session.dc_id + self._server_address = session.server_address + self._port = session.port + self._auth_key = AuthKey(data=session.auth_key) + + def clone(self, to_instance=None) -> MemorySession: + return super().clone(MemorySession()) + + def _get_auth_key(self) -> Optional[AuthKey]: + sessions = self._db_query(self.Session).all() + session = sessions[0] if sessions else None + if session and session.auth_key: + return AuthKey(data=session.auth_key) + return None + + def set_dc(self, dc_id: str, server_address: str, port: int) -> None: + super().set_dc(dc_id, server_address, port) + self._update_session_table() + self._auth_key = self._get_auth_key() + + def get_update_state(self, entity_id: int) -> Optional[updates.State]: + row = self.UpdateState.query.get((self.session_id, entity_id)) + if row: + date = datetime.datetime.utcfromtimestamp(row.date) + return updates.State(row.pts, row.qts, date, row.seq, row.unread_count) + return None + + def set_update_state(self, entity_id: int, row: Any) -> None: + if row: + self.db.merge(self.UpdateState(session_id=self.session_id, entity_id=entity_id, + pts=row.pts, qts=row.qts, date=row.date.timestamp(), + seq=row.seq, + unread_count=row.unread_count)) + self.save() + + @MemorySession.auth_key.setter + def auth_key(self, value: AuthKey) -> None: + self._auth_key = value + self._update_session_table() + + def _update_session_table(self) -> None: + self.Session.query.filter(self.Session.session_id == self.session_id).delete() + self.db.add(self.Session(session_id=self.session_id, dc_id=self._dc_id, + server_address=self._server_address, port=self._port, + auth_key=(self._auth_key.key if self._auth_key else b''))) + + def _db_query(self, dbclass: Any, *args: Any) -> orm.Query: + return dbclass.query.filter( + dbclass.session_id == self.session_id, *args + ) + + def save(self) -> None: + self.container.save() + + def close(self) -> None: + # Nothing to do here, connection is managed by AlchemySessionContainer. + pass + + def delete(self) -> None: + self._db_query(self.Session).delete() + self._db_query(self.Entity).delete() + self._db_query(self.SentFile).delete() + self._db_query(self.UpdateState).delete() + + def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str + ) -> Any: + return self.Entity(session_id=self.session_id, id=id, hash=hash, + username=username, phone=phone, name=name) + + def process_entities(self, tlo: Any) -> None: + rows = self._entities_to_rows(tlo) + if not rows: + return + + for row in rows: + self.db.merge(row) + self.save() + + def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]: + row = self._db_query(self.Entity, + self.Entity.phone == key).one_or_none() + return (row.id, row.hash) if row else None + + def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]: + row = self._db_query(self.Entity, + self.Entity.username == key).one_or_none() + return (row.id, row.hash) if row else None + + def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]: + row = self._db_query(self.Entity, + self.Entity.name == key).one_or_none() + return (row.id, row.hash) if row else None + + def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]: + if exact: + query = self._db_query(self.Entity, self.Entity.id == key) + else: + ids = ( + utils.get_peer_id(PeerUser(key)), + utils.get_peer_id(PeerChat(key)), + utils.get_peer_id(PeerChannel(key)) + ) + query = self._db_query(self.Entity, self.Entity.id.in_(ids)) + + row = query.one_or_none() + return (row.id, row.hash) if row else None + + def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]: + row = self._db_query(self.SentFile, + self.SentFile.md5_digest == md5_digest, + self.SentFile.file_size == file_size, + self.SentFile.type == _SentFileType.from_type( + cls).value).one_or_none() + return (row.id, row.hash) if row else None + + def cache_file(self, md5_digest: str, file_size: int, + instance: Union[InputDocument, InputPhoto]) -> None: + if not isinstance(instance, (InputDocument, InputPhoto)): + raise TypeError("Cannot cache {} instance".format(type(instance))) + + self.db.merge( + self.SentFile(session_id=self.session_id, md5_digest=md5_digest, file_size=file_size, + type=_SentFileType.from_type(type(instance)).value, + id=instance.id, hash=instance.access_hash)) + self.save() diff --git a/library/telegram/session_backend/sqlalchemy.py b/library/telegram/session_backend/sqlalchemy.py new file mode 100644 index 0000000..29d84c2 --- /dev/null +++ b/library/telegram/session_backend/sqlalchemy.py @@ -0,0 +1,203 @@ +from typing import ( + Any, + Optional, + Tuple, + Union, +) + +import sqlalchemy as sql +from sqlalchemy import ( + BigInteger, + Column, + Integer, + LargeBinary, + String, + and_, + func, + orm, + select, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm.scoping import scoped_session + +from .core import AlchemyCoreSession +from .core_postgres import AlchemyPostgresCoreSession +from .orm import AlchemySession + +LATEST_VERSION = 2 + + +class AlchemySessionContainer: + def __init__(self, engine: Union[sql.engine.Engine, str] = None, + session: Optional[Union[orm.Session, scoped_session, bool]] = None, + table_prefix: str = "", table_base: Optional[declarative_base] = None, + manage_tables: bool = True) -> None: + if isinstance(engine, str): + engine = sql.create_engine(engine) + + self.db_engine = engine + if session is None: + db_factory = orm.sessionmaker(bind=self.db_engine) + self.db = orm.scoping.scoped_session(db_factory) + elif not session: + self.db = None + else: + self.db = session + + table_base = table_base or declarative_base() + (self.Version, self.Session, self.Entity, + self.SentFile, self.UpdateState) = self.create_table_classes(self.db, table_prefix, + table_base) + self.alchemy_session_class = AlchemySession + if not self.db: + # Implicit core mode if there's no ORM session. + self.core_mode = True + + if manage_tables: + if not self.db: + raise ValueError("Can't manage tables without an ORM session.") + table_base.metadata.bind = self.db_engine + if not self.db_engine.dialect.has_table(self.db_engine, + self.Version.__tablename__): + table_base.metadata.create_all() + self.db.add(self.Version(version=LATEST_VERSION)) + self.db.commit() + else: + self.check_and_upgrade_database() + + @property + def core_mode(self) -> bool: + return self.alchemy_session_class != AlchemySession + + @core_mode.setter + def core_mode(self, val: bool) -> None: + if val: + if self.db_engine.dialect.name == "postgresql": + self.alchemy_session_class = AlchemyPostgresCoreSession + else: + self.alchemy_session_class = AlchemyCoreSession + else: + if not self.db: + raise ValueError("Can't use ORM mode without an ORM session.") + self.alchemy_session_class = AlchemySession + + @staticmethod + def create_table_classes(db: scoped_session, prefix: str, base: declarative_base + ) -> Tuple[Any, Any, Any, Any, Any]: + qp = db.query_property() if db else None + + class Version(base): + query = qp + __tablename__ = "{prefix}version".format(prefix=prefix) + version = Column(Integer, primary_key=True) + + def __str__(self): + return "Version('{}')".format(self.version) + + class Session(base): + query = qp + __tablename__ = '{prefix}sessions'.format(prefix=prefix) + + session_id = Column(String(255), primary_key=True) + dc_id = Column(Integer, primary_key=True) + server_address = Column(String(255)) + port = Column(Integer) + auth_key = Column(LargeBinary) + + def __str__(self): + return "Session('{}', {}, '{}', {}, {})".format(self.session_id, self.dc_id, + self.server_address, self.port, + self.auth_key) + + class Entity(base): + query = qp + __tablename__ = '{prefix}entities'.format(prefix=prefix) + + session_id = Column(String(255), primary_key=True) + id = Column(BigInteger, primary_key=True) + hash = Column(BigInteger, nullable=False) + username = Column(String(32)) + phone = Column(BigInteger) + name = Column(String(255)) + + def __str__(self): + return "Entity('{}', {}, {}, '{}', '{}', '{}')".format(self.session_id, self.id, + self.hash, self.username, + self.phone, self.name) + + class SentFile(base): + query = qp + __tablename__ = '{prefix}sent_files'.format(prefix=prefix) + + session_id = Column(String(255), primary_key=True) + md5_digest = Column(LargeBinary, primary_key=True) + file_size = Column(Integer, primary_key=True) + type = Column(Integer, primary_key=True) + id = Column(BigInteger) + hash = Column(BigInteger) + + def __str__(self): + return "SentFile('{}', {}, {}, {}, {}, {})".format(self.session_id, + self.md5_digest, self.file_size, + self.type, self.id, self.hash) + + class UpdateState(base): + query = qp + __tablename__ = "{prefix}update_state".format(prefix=prefix) + + session_id = Column(String(255), primary_key=True) + entity_id = Column(BigInteger, primary_key=True) + pts = Column(BigInteger) + qts = Column(BigInteger) + date = Column(BigInteger) + seq = Column(BigInteger) + unread_count = Column(Integer) + + return Version, Session, Entity, SentFile, UpdateState + + def _add_column(self, table: Any, column: Column) -> None: + column_name = column.compile(dialect=self.db_engine.dialect) + column_type = column.type.compile(self.db_engine.dialect) + self.db_engine.execute("ALTER TABLE {} ADD COLUMN {} {}".format( + table.__tablename__, column_name, column_type)) + + def check_and_upgrade_database(self) -> None: + row = self.Version.query.all() + version = row[0].version if row else 1 + if version == LATEST_VERSION: + return + + self.Version.query.delete() + + if version == 1: + self.UpdateState.__table__.create(self.db_engine) + version = 3 + elif version == 2: + self._add_column(self.UpdateState, Column(type=Integer, name="unread_count")) + + self.db.add(self.Version(version=version)) + self.db.commit() + + def new_session(self, session_id: str) -> 'AlchemySession': + return self.alchemy_session_class(self, session_id) + + def has_session(self, session_id: str) -> bool: + if self.core_mode: + t = self.Session.__table__ + rows = self.db_engine.execute(select([func.count(t.c.auth_key)]) + .where(and_(t.c.session_id == session_id, + t.c.auth_key != b''))) + try: + count, = next(rows) + return count > 0 + except StopIteration: + return False + else: + return self.Session.query.filter(self.Session.session_id == session_id).count() > 0 + + def list_sessions(self): + return + + def save(self) -> None: + if self.db: + self.db.commit() diff --git a/library/telegram/utils.py b/library/telegram/utils.py new file mode 100644 index 0000000..d8e6ef8 --- /dev/null +++ b/library/telegram/utils.py @@ -0,0 +1,45 @@ +import logging +import traceback +from contextlib import asynccontextmanager +from typing import ( + Awaitable, + Callable, + Optional, +) + +from telethon import ( + errors, + events, +) + +from .base import RequestContext + + +@asynccontextmanager +async def safe_execution( + request_context: RequestContext, + on_fail: Optional[Callable[[], Awaitable]] = None, +): + try: + try: + yield + except events.StopPropagation: + raise + except ( + errors.UserIsBlockedError, + errors.QueryIdInvalidError, + errors.MessageDeleteForbiddenError, + errors.MessageIdInvalidError, + errors.MessageNotModifiedError, + errors.ChatAdminRequiredError, + ) as e: + request_context.error_log(e, level=logging.WARNING) + except Exception as e: + traceback.print_exc() + request_context.error_log(e) + if on_fail: + await on_fail() + except events.StopPropagation: + raise + except Exception as e: + request_context.error_log(e) diff --git a/nexus/actions/update_document_scimag.py b/nexus/actions/update_document_scimag.py index c5a6c67..54b8f56 100644 --- a/nexus/actions/update_document_scimag.py +++ b/nexus/actions/update_document_scimag.py @@ -199,8 +199,9 @@ class FillDocumentOperationUpdateDocumentScimagPbFromExternalSourceAction(BaseAc super().__init__() self.crossref_client = CrossrefClient( delay=1.0 / crossref['rps'], - max_retries=60, + max_retries=crossref.get('max_retries', 15), proxy_url=crossref.get('proxy_url'), + retry_delay=crossref.get('retry_delay', 0.5), timeout=crossref.get('timeout'), user_agent=crossref.get('user_agent'), ) diff --git a/nexus/bot/BUILD.bazel b/nexus/bot/BUILD.bazel index 4af0cf0..d700762 100644 --- a/nexus/bot/BUILD.bazel +++ b/nexus/bot/BUILD.bazel @@ -41,7 +41,7 @@ py3_image( requirement("aiokit"), "//library/configurator", "//library/logging", - "//library/metrics_server", + , "//library/telegram", "//nexus/hub/aioclient", "//nexus/meta_api/aioclient", diff --git a/nexus/bot/README.md b/nexus/bot/README.md index e9ed55a..66c5c12 100644 --- a/nexus/bot/README.md +++ b/nexus/bot/README.md @@ -5,4 +5,5 @@ The bot requires three other essential parts: - Nexus Hub API (managing files) - Nexus Meta API (rescoring API for Summa Search server) -or their substitutions \ No newline at end of file +or their substitutions + diff --git a/rules/packaging/BUILD.bazel b/rules/packaging/BUILD.bazel index 7752e39..c8d2c4c 100644 --- a/rules/packaging/BUILD.bazel +++ b/rules/packaging/BUILD.bazel @@ -9,5 +9,6 @@ py_binary( deps = [ "@bazel_tools//tools/build_defs/pkg:archive", requirement("python-gflags"), + requirement("six"), ], )