mirror of
https://github.com/nexus-stc/hyperboria
synced 2025-01-06 00:35:55 +01:00
Merge pull request #16 from the-superpirate/master
- feat: Added missing library
This commit is contained in:
commit
6aa7307fa0
18
library/telegram/BUILD.bazel
Normal file
18
library/telegram/BUILD.bazel
Normal file
@ -0,0 +1,18 @@
|
||||
load("@pip_modules//:requirements.bzl", "requirement")
|
||||
load("@rules_python//python:defs.bzl", "py_library")
|
||||
|
||||
py_library(
|
||||
name = "telegram",
|
||||
srcs = glob(
|
||||
["**/*.py"],
|
||||
exclude = ["tests/**"],
|
||||
),
|
||||
srcs_version = "PY3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
requirement("sqlalchemy"),
|
||||
requirement("telethon"),
|
||||
requirement("aiokit"),
|
||||
"//library/logging",
|
||||
],
|
||||
)
|
0
library/telegram/__init__.py
Normal file
0
library/telegram/__init__.py
Normal file
160
library/telegram/base.py
Normal file
160
library/telegram/base.py
Normal file
@ -0,0 +1,160 @@
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
from aiokit import AioThing
|
||||
from izihawa_utils.random import random_string
|
||||
from library.logging import error_log
|
||||
from telethon import (
|
||||
TelegramClient,
|
||||
connection,
|
||||
sessions,
|
||||
)
|
||||
from tenacity import ( # noqa
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
wait_fixed,
|
||||
)
|
||||
|
||||
from .session_backend import AlchemySessionContainer
|
||||
|
||||
|
||||
class BaseTelegramClient(AioThing):
|
||||
def __init__(self, app_id, app_hash, database, bot_token=None, mtproxy=None, flood_sleep_threshold: int = 60):
|
||||
AioThing.__init__(self)
|
||||
self._telegram_client = TelegramClient(
|
||||
self._get_session(database),
|
||||
app_id,
|
||||
app_hash,
|
||||
flood_sleep_threshold=flood_sleep_threshold,
|
||||
**self._get_proxy(mtproxy=mtproxy),
|
||||
)
|
||||
self.bot_token = bot_token
|
||||
|
||||
def _get_session(self, database):
|
||||
if database.get('drivername') == 'postgresql':
|
||||
self.container = AlchemySessionContainer(
|
||||
f"{database['drivername']}://"
|
||||
f"{database['username']}:"
|
||||
f"{database['password']}@"
|
||||
f"{database['host']}:"
|
||||
f"{database['port']}/"
|
||||
f"{database['database']}",
|
||||
session=False,
|
||||
manage_tables=False,
|
||||
)
|
||||
return self.container.new_session(database['session_id'])
|
||||
else:
|
||||
return sessions.SQLiteSession(session_id=database['session_id'])
|
||||
|
||||
def _get_proxy(self, mtproxy=None):
|
||||
if mtproxy and mtproxy.get('enabled', True):
|
||||
proxy_config = mtproxy
|
||||
return {
|
||||
'connection': connection.tcpmtproxy.ConnectionTcpMTProxyRandomizedIntermediate,
|
||||
'proxy': (proxy_config['url'], proxy_config['port'], proxy_config['secret'])
|
||||
}
|
||||
return {}
|
||||
|
||||
@retry(retry=retry_if_exception_type(ConnectionError), wait=wait_fixed(5))
|
||||
async def start(self):
|
||||
await self._telegram_client.start(bot_token=self.bot_token)
|
||||
|
||||
async def stop(self):
|
||||
return await self.disconnect()
|
||||
|
||||
def add_event_handler(self, *args, **kwargs):
|
||||
return self._telegram_client.add_event_handler(*args, **kwargs)
|
||||
|
||||
def catch_up(self):
|
||||
return self._telegram_client.catch_up()
|
||||
|
||||
def delete_messages(self, *args, **kwargs):
|
||||
return self._telegram_client.delete_messages(*args, **kwargs)
|
||||
|
||||
def disconnect(self):
|
||||
return self._telegram_client.disconnect()
|
||||
|
||||
@property
|
||||
def disconnected(self):
|
||||
return self._telegram_client.disconnected
|
||||
|
||||
def download_document(self, *args, **kwargs):
|
||||
return self._telegram_client._download_document(
|
||||
*args,
|
||||
date=datetime.datetime.now(),
|
||||
thumb=None,
|
||||
progress_callback=None,
|
||||
msg_data=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def edit_message(self, *args, **kwargs):
|
||||
return self._telegram_client.edit_message(*args, **kwargs)
|
||||
|
||||
def edit_permissions(self, *args, **kwargs):
|
||||
return self._telegram_client.edit_permissions(*args, **kwargs)
|
||||
|
||||
def forward_messages(self, *args, **kwargs):
|
||||
return self._telegram_client.forward_messages(*args, **kwargs)
|
||||
|
||||
def get_entity(self, *args, **kwargs):
|
||||
return self._telegram_client.get_entity(*args, **kwargs)
|
||||
|
||||
def get_input_entity(self, *args, **kwargs):
|
||||
return self._telegram_client.get_input_entity(*args, **kwargs)
|
||||
|
||||
def iter_admin_log(self, *args, **kwargs):
|
||||
return self._telegram_client.iter_admin_log(*args, **kwargs)
|
||||
|
||||
def iter_messages(self, *args, **kwargs):
|
||||
return self._telegram_client.iter_messages(*args, **kwargs)
|
||||
|
||||
def list_event_handlers(self):
|
||||
return self._telegram_client.list_event_handlers()
|
||||
|
||||
def remove_event_handlers(self):
|
||||
for handler in reversed(self.list_event_handlers()):
|
||||
self._telegram_client.remove_event_handler(*handler)
|
||||
|
||||
def run_until_disconnected(self):
|
||||
return self._telegram_client.run_until_disconnected()
|
||||
|
||||
def send_message(self, *args, **kwargs):
|
||||
return self._telegram_client.send_message(*args, **kwargs)
|
||||
|
||||
def send_file(self, *args, **kwargs):
|
||||
return self._telegram_client.send_file(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._telegram_client(*args, **kwargs)
|
||||
|
||||
|
||||
class RequestContext:
|
||||
def __init__(self, bot_name, chat, request_id: str = None, request_id_length: int = 12):
|
||||
self.bot_name = bot_name
|
||||
self.chat = chat
|
||||
self.request_id = request_id or RequestContext.generate_request_id(request_id_length)
|
||||
self.default_fields = {
|
||||
'bot_name': self.bot_name,
|
||||
'chat_id': self.chat.id,
|
||||
'request_id': self.request_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def generate_request_id(length):
|
||||
return random_string(length)
|
||||
|
||||
def add_default_fields(self, **fields):
|
||||
self.default_fields.update(fields)
|
||||
|
||||
def statbox(self, **kwargs):
|
||||
logging.getLogger('statbox').info(
|
||||
msg=dict(
|
||||
**self.default_fields,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
def error_log(self, e, level=logging.ERROR, **fields):
|
||||
all_fields = {**self.default_fields, **fields}
|
||||
error_log(e, level=level, **all_fields)
|
3
library/telegram/session_backend/__init__.py
Normal file
3
library/telegram/session_backend/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .sqlalchemy import AlchemySessionContainer
|
||||
|
||||
__all__ = ['AlchemySessionContainer']
|
183
library/telegram/session_backend/core.py
Normal file
183
library/telegram/session_backend/core.py
Normal file
@ -0,0 +1,183 @@
|
||||
import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from sqlalchemy import (
|
||||
and_,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from telethon import utils
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon.sessions.memory import _SentFileType
|
||||
from telethon.tl.types import (
|
||||
InputDocument,
|
||||
InputPhoto,
|
||||
PeerChannel,
|
||||
PeerChat,
|
||||
PeerUser,
|
||||
updates,
|
||||
)
|
||||
|
||||
from .orm import AlchemySession
|
||||
|
||||
|
||||
class AlchemyCoreSession(AlchemySession):
|
||||
def _load_session(self) -> None:
|
||||
t = self.Session.__table__
|
||||
rows = self.engine.execute(select([t.c.dc_id, t.c.server_address, t.c.port, t.c.auth_key])
|
||||
.where(t.c.session_id == self.session_id))
|
||||
try:
|
||||
self._dc_id, self._server_address, self._port, auth_key = next(rows)
|
||||
self._auth_key = AuthKey(data=auth_key)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
def _get_auth_key(self) -> Optional[AuthKey]:
|
||||
t = self.Session.__table__
|
||||
rows = self.engine.execute(select([t.c.auth_key]).where(t.c.session_id == self.session_id))
|
||||
try:
|
||||
ak = next(rows)[0]
|
||||
except (StopIteration, IndexError):
|
||||
ak = None
|
||||
return AuthKey(data=ak) if ak else None
|
||||
|
||||
def get_update_state(self, entity_id: int) -> Optional[updates.State]:
|
||||
t = self.UpdateState.__table__
|
||||
rows = self.engine.execute(select([t])
|
||||
.where(and_(t.c.session_id == self.session_id,
|
||||
t.c.entity_id == entity_id)))
|
||||
try:
|
||||
_, _, pts, qts, date, seq, unread_count = next(rows)
|
||||
date = datetime.datetime.utcfromtimestamp(date)
|
||||
return updates.State(pts, qts, date, seq, unread_count)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def set_update_state(self, entity_id: int, row: Any) -> None:
|
||||
t = self.UpdateState.__table__
|
||||
self.engine.execute(t.delete().where(and_(t.c.session_id == self.session_id,
|
||||
t.c.entity_id == entity_id)))
|
||||
self.engine.execute(t.insert()
|
||||
.values(session_id=self.session_id, entity_id=entity_id, pts=row.pts,
|
||||
qts=row.qts, date=row.date.timestamp(), seq=row.seq,
|
||||
unread_count=row.unread_count))
|
||||
|
||||
def _update_session_table(self) -> None:
|
||||
with self.engine.begin() as conn:
|
||||
insert_stmt = insert(self.Session.__table__)
|
||||
conn.execute(
|
||||
insert_stmt.on_conflict_do_update(
|
||||
index_elements=['session_id', 'dc_id'],
|
||||
set_={
|
||||
'session_id': insert_stmt.excluded.session_id,
|
||||
'dc_id': insert_stmt.excluded.dc_id,
|
||||
'server_address': insert_stmt.excluded.server_address,
|
||||
'port': insert_stmt.excluded.port,
|
||||
'auth_key': insert_stmt.excluded.auth_key,
|
||||
}
|
||||
),
|
||||
session_id=self.session_id, dc_id=self._dc_id,
|
||||
server_address=self._server_address, port=self._port,
|
||||
auth_key=(self._auth_key.key if self._auth_key else b'')
|
||||
)
|
||||
|
||||
def save(self) -> None:
|
||||
# engine.execute() autocommits
|
||||
pass
|
||||
|
||||
def delete(self) -> None:
|
||||
with self.engine.begin() as conn:
|
||||
conn.execute(self.Session.__table__.delete().where(
|
||||
self.Session.__table__.c.session_id == self.session_id))
|
||||
conn.execute(self.Entity.__table__.delete().where(
|
||||
self.Entity.__table__.c.session_id == self.session_id))
|
||||
conn.execute(self.SentFile.__table__.delete().where(
|
||||
self.SentFile.__table__.c.session_id == self.session_id))
|
||||
conn.execute(self.UpdateState.__table__.delete().where(
|
||||
self.UpdateState.__table__.c.session_id == self.session_id))
|
||||
|
||||
def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
|
||||
) -> Any:
|
||||
return id, hash, username, phone, name
|
||||
|
||||
def process_entities(self, tlo: Any) -> None:
|
||||
rows = self._entities_to_rows(tlo)
|
||||
if not rows:
|
||||
return
|
||||
|
||||
t = self.Entity.__table__
|
||||
with self.engine.begin() as conn:
|
||||
conn.execute(t.delete().where(and_(t.c.session_id == self.session_id,
|
||||
t.c.id.in_([row[0] for row in rows]))))
|
||||
conn.execute(t.insert(), [dict(session_id=self.session_id, id=row[0], hash=row[1],
|
||||
username=row[2], phone=row[3], name=row[4])
|
||||
for row in rows])
|
||||
|
||||
def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
|
||||
return self._get_entity_rows_by_condition(self.Entity.__table__.c.phone == key)
|
||||
|
||||
def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
|
||||
return self._get_entity_rows_by_condition(self.Entity.__table__.c.username == key)
|
||||
|
||||
def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
|
||||
return self._get_entity_rows_by_condition(self.Entity.__table__.c.name == key)
|
||||
|
||||
def _get_entity_rows_by_condition(self, condition) -> Optional[Tuple[int, int]]:
|
||||
t = self.Entity.__table__
|
||||
rows = self.engine.execute(select([t.c.id, t.c.hash])
|
||||
.where(and_(t.c.session_id == self.session_id, condition)))
|
||||
try:
|
||||
return next(rows)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
|
||||
t = self.Entity.__table__
|
||||
if exact:
|
||||
rows = self.engine.execute(select([t.c.id, t.c.hash]).where(
|
||||
and_(t.c.session_id == self.session_id, t.c.id == key)))
|
||||
else:
|
||||
ids = (
|
||||
utils.get_peer_id(PeerUser(key)),
|
||||
utils.get_peer_id(PeerChat(key)),
|
||||
utils.get_peer_id(PeerChannel(key))
|
||||
)
|
||||
rows = self.engine.execute(select([t.c.id, t.c.hash]).where(
|
||||
and_(t.c.session_id == self.session_id, t.c.id.in_(ids)))
|
||||
)
|
||||
|
||||
try:
|
||||
return next(rows)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
|
||||
t = self.SentFile.__table__
|
||||
rows = (self.engine.execute(select([t.c.id, t.c.hash])
|
||||
.where(and_(t.c.session_id == self.session_id,
|
||||
t.c.md5_digest == md5_digest,
|
||||
t.c.file_size == file_size,
|
||||
t.c.type == _SentFileType.from_type(cls).value))))
|
||||
try:
|
||||
return next(rows)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def cache_file(self, md5_digest: str, file_size: int,
|
||||
instance: Union[InputDocument, InputPhoto]) -> None:
|
||||
if not isinstance(instance, (InputDocument, InputPhoto)):
|
||||
raise TypeError("Cannot cache {} instance".format(type(instance)))
|
||||
|
||||
t = self.SentFile.__table__
|
||||
file_type = _SentFileType.from_type(type(instance)).value
|
||||
with self.engine.begin() as conn:
|
||||
conn.execute(t.delete().where(session_id=self.session_id, md5_digest=md5_digest,
|
||||
type=file_type, file_size=file_size))
|
||||
conn.execute(t.insert().values(session_id=self.session_id, md5_digest=md5_digest,
|
||||
type=file_type, file_size=file_size, id=instance.id,
|
||||
hash=instance.access_hash))
|
56
library/telegram/session_backend/core_postgres.py
Normal file
56
library/telegram/session_backend/core_postgres.py
Normal file
@ -0,0 +1,56 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Union,
|
||||
)
|
||||
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from telethon.sessions.memory import _SentFileType
|
||||
from telethon.tl.types import (
|
||||
InputDocument,
|
||||
InputPhoto,
|
||||
)
|
||||
|
||||
from .core import AlchemyCoreSession
|
||||
|
||||
|
||||
class AlchemyPostgresCoreSession(AlchemyCoreSession):
|
||||
def set_update_state(self, entity_id: int, row: Any) -> None:
|
||||
t = self.UpdateState.__table__
|
||||
values = dict(pts=row.pts, qts=row.qts, date=row.date.timestamp(),
|
||||
seq=row.seq, unread_count=row.unread_count)
|
||||
with self.engine.begin() as conn:
|
||||
conn.execute(insert(t)
|
||||
.values(session_id=self.session_id, entity_id=entity_id, **values)
|
||||
.on_conflict_do_update(constraint=t.primary_key, set_=values))
|
||||
|
||||
def process_entities(self, tlo: Any) -> None:
|
||||
rows = self._entities_to_rows(tlo)
|
||||
if not rows:
|
||||
return
|
||||
|
||||
t = self.Entity.__table__
|
||||
ins = insert(t)
|
||||
upsert = ins.on_conflict_do_update(constraint=t.primary_key, set_={
|
||||
"hash": ins.excluded.hash,
|
||||
"username": ins.excluded.username,
|
||||
"phone": ins.excluded.phone,
|
||||
"name": ins.excluded.name,
|
||||
})
|
||||
with self.engine.begin() as conn:
|
||||
conn.execute(upsert, [dict(session_id=self.session_id, id=row[0], hash=row[1],
|
||||
username=row[2], phone=row[3], name=row[4])
|
||||
for row in rows])
|
||||
|
||||
def cache_file(self, md5_digest: str, file_size: int,
|
||||
instance: Union[InputDocument, InputPhoto]) -> None:
|
||||
if not isinstance(instance, (InputDocument, InputPhoto)):
|
||||
raise TypeError("Cannot cache {} instance".format(type(instance)))
|
||||
|
||||
t = self.SentFile.__table__
|
||||
values = dict(id=instance.id, hash=instance.access_hash)
|
||||
with self.engine.begin() as conn:
|
||||
conn.execute(insert(t)
|
||||
.values(session_id=self.session_id, md5_digest=md5_digest,
|
||||
type=_SentFileType.from_type(type(instance)).value,
|
||||
file_size=file_size, **values)
|
||||
.on_conflict_do_update(constraint=t.primary_key, set_=values))
|
170
library/telegram/session_backend/orm.py
Normal file
170
library/telegram/session_backend/orm.py
Normal file
@ -0,0 +1,170 @@
|
||||
import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from sqlalchemy import orm
|
||||
from telethon import utils
|
||||
from telethon.crypto import AuthKey
|
||||
from telethon.sessions.memory import (
|
||||
MemorySession,
|
||||
_SentFileType,
|
||||
)
|
||||
from telethon.tl.types import (
|
||||
InputDocument,
|
||||
InputPhoto,
|
||||
PeerChannel,
|
||||
PeerChat,
|
||||
PeerUser,
|
||||
updates,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sqlalchemy import AlchemySessionContainer
|
||||
|
||||
|
||||
class AlchemySession(MemorySession):
|
||||
def __init__(self, container: 'AlchemySessionContainer', session_id: str) -> None:
|
||||
super().__init__()
|
||||
self.container = container
|
||||
self.db = container.db
|
||||
self.engine = container.db_engine
|
||||
self.Version, self.Session, self.Entity, self.SentFile, self.UpdateState = (
|
||||
container.Version, container.Session, container.Entity,
|
||||
container.SentFile, container.UpdateState)
|
||||
self.session_id = session_id
|
||||
self._load_session()
|
||||
|
||||
def _load_session(self) -> None:
|
||||
sessions = self._db_query(self.Session).all()
|
||||
session = sessions[0] if sessions else None
|
||||
if session:
|
||||
self._dc_id = session.dc_id
|
||||
self._server_address = session.server_address
|
||||
self._port = session.port
|
||||
self._auth_key = AuthKey(data=session.auth_key)
|
||||
|
||||
def clone(self, to_instance=None) -> MemorySession:
|
||||
return super().clone(MemorySession())
|
||||
|
||||
def _get_auth_key(self) -> Optional[AuthKey]:
|
||||
sessions = self._db_query(self.Session).all()
|
||||
session = sessions[0] if sessions else None
|
||||
if session and session.auth_key:
|
||||
return AuthKey(data=session.auth_key)
|
||||
return None
|
||||
|
||||
def set_dc(self, dc_id: str, server_address: str, port: int) -> None:
|
||||
super().set_dc(dc_id, server_address, port)
|
||||
self._update_session_table()
|
||||
self._auth_key = self._get_auth_key()
|
||||
|
||||
def get_update_state(self, entity_id: int) -> Optional[updates.State]:
|
||||
row = self.UpdateState.query.get((self.session_id, entity_id))
|
||||
if row:
|
||||
date = datetime.datetime.utcfromtimestamp(row.date)
|
||||
return updates.State(row.pts, row.qts, date, row.seq, row.unread_count)
|
||||
return None
|
||||
|
||||
def set_update_state(self, entity_id: int, row: Any) -> None:
|
||||
if row:
|
||||
self.db.merge(self.UpdateState(session_id=self.session_id, entity_id=entity_id,
|
||||
pts=row.pts, qts=row.qts, date=row.date.timestamp(),
|
||||
seq=row.seq,
|
||||
unread_count=row.unread_count))
|
||||
self.save()
|
||||
|
||||
@MemorySession.auth_key.setter
|
||||
def auth_key(self, value: AuthKey) -> None:
|
||||
self._auth_key = value
|
||||
self._update_session_table()
|
||||
|
||||
def _update_session_table(self) -> None:
|
||||
self.Session.query.filter(self.Session.session_id == self.session_id).delete()
|
||||
self.db.add(self.Session(session_id=self.session_id, dc_id=self._dc_id,
|
||||
server_address=self._server_address, port=self._port,
|
||||
auth_key=(self._auth_key.key if self._auth_key else b'')))
|
||||
|
||||
def _db_query(self, dbclass: Any, *args: Any) -> orm.Query:
|
||||
return dbclass.query.filter(
|
||||
dbclass.session_id == self.session_id, *args
|
||||
)
|
||||
|
||||
def save(self) -> None:
|
||||
self.container.save()
|
||||
|
||||
def close(self) -> None:
|
||||
# Nothing to do here, connection is managed by AlchemySessionContainer.
|
||||
pass
|
||||
|
||||
def delete(self) -> None:
|
||||
self._db_query(self.Session).delete()
|
||||
self._db_query(self.Entity).delete()
|
||||
self._db_query(self.SentFile).delete()
|
||||
self._db_query(self.UpdateState).delete()
|
||||
|
||||
def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
|
||||
) -> Any:
|
||||
return self.Entity(session_id=self.session_id, id=id, hash=hash,
|
||||
username=username, phone=phone, name=name)
|
||||
|
||||
def process_entities(self, tlo: Any) -> None:
|
||||
rows = self._entities_to_rows(tlo)
|
||||
if not rows:
|
||||
return
|
||||
|
||||
for row in rows:
|
||||
self.db.merge(row)
|
||||
self.save()
|
||||
|
||||
def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
|
||||
row = self._db_query(self.Entity,
|
||||
self.Entity.phone == key).one_or_none()
|
||||
return (row.id, row.hash) if row else None
|
||||
|
||||
def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
|
||||
row = self._db_query(self.Entity,
|
||||
self.Entity.username == key).one_or_none()
|
||||
return (row.id, row.hash) if row else None
|
||||
|
||||
def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
|
||||
row = self._db_query(self.Entity,
|
||||
self.Entity.name == key).one_or_none()
|
||||
return (row.id, row.hash) if row else None
|
||||
|
||||
def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
|
||||
if exact:
|
||||
query = self._db_query(self.Entity, self.Entity.id == key)
|
||||
else:
|
||||
ids = (
|
||||
utils.get_peer_id(PeerUser(key)),
|
||||
utils.get_peer_id(PeerChat(key)),
|
||||
utils.get_peer_id(PeerChannel(key))
|
||||
)
|
||||
query = self._db_query(self.Entity, self.Entity.id.in_(ids))
|
||||
|
||||
row = query.one_or_none()
|
||||
return (row.id, row.hash) if row else None
|
||||
|
||||
def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
|
||||
row = self._db_query(self.SentFile,
|
||||
self.SentFile.md5_digest == md5_digest,
|
||||
self.SentFile.file_size == file_size,
|
||||
self.SentFile.type == _SentFileType.from_type(
|
||||
cls).value).one_or_none()
|
||||
return (row.id, row.hash) if row else None
|
||||
|
||||
def cache_file(self, md5_digest: str, file_size: int,
|
||||
instance: Union[InputDocument, InputPhoto]) -> None:
|
||||
if not isinstance(instance, (InputDocument, InputPhoto)):
|
||||
raise TypeError("Cannot cache {} instance".format(type(instance)))
|
||||
|
||||
self.db.merge(
|
||||
self.SentFile(session_id=self.session_id, md5_digest=md5_digest, file_size=file_size,
|
||||
type=_SentFileType.from_type(type(instance)).value,
|
||||
id=instance.id, hash=instance.access_hash))
|
||||
self.save()
|
203
library/telegram/session_backend/sqlalchemy.py
Normal file
203
library/telegram/session_backend/sqlalchemy.py
Normal file
@ -0,0 +1,203 @@
|
||||
from typing import (
|
||||
Any,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import sqlalchemy as sql
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
Integer,
|
||||
LargeBinary,
|
||||
String,
|
||||
and_,
|
||||
func,
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm.scoping import scoped_session
|
||||
|
||||
from .core import AlchemyCoreSession
|
||||
from .core_postgres import AlchemyPostgresCoreSession
|
||||
from .orm import AlchemySession
|
||||
|
||||
LATEST_VERSION = 2
|
||||
|
||||
|
||||
class AlchemySessionContainer:
|
||||
def __init__(self, engine: Union[sql.engine.Engine, str] = None,
|
||||
session: Optional[Union[orm.Session, scoped_session, bool]] = None,
|
||||
table_prefix: str = "", table_base: Optional[declarative_base] = None,
|
||||
manage_tables: bool = True) -> None:
|
||||
if isinstance(engine, str):
|
||||
engine = sql.create_engine(engine)
|
||||
|
||||
self.db_engine = engine
|
||||
if session is None:
|
||||
db_factory = orm.sessionmaker(bind=self.db_engine)
|
||||
self.db = orm.scoping.scoped_session(db_factory)
|
||||
elif not session:
|
||||
self.db = None
|
||||
else:
|
||||
self.db = session
|
||||
|
||||
table_base = table_base or declarative_base()
|
||||
(self.Version, self.Session, self.Entity,
|
||||
self.SentFile, self.UpdateState) = self.create_table_classes(self.db, table_prefix,
|
||||
table_base)
|
||||
self.alchemy_session_class = AlchemySession
|
||||
if not self.db:
|
||||
# Implicit core mode if there's no ORM session.
|
||||
self.core_mode = True
|
||||
|
||||
if manage_tables:
|
||||
if not self.db:
|
||||
raise ValueError("Can't manage tables without an ORM session.")
|
||||
table_base.metadata.bind = self.db_engine
|
||||
if not self.db_engine.dialect.has_table(self.db_engine,
|
||||
self.Version.__tablename__):
|
||||
table_base.metadata.create_all()
|
||||
self.db.add(self.Version(version=LATEST_VERSION))
|
||||
self.db.commit()
|
||||
else:
|
||||
self.check_and_upgrade_database()
|
||||
|
||||
@property
|
||||
def core_mode(self) -> bool:
|
||||
return self.alchemy_session_class != AlchemySession
|
||||
|
||||
@core_mode.setter
|
||||
def core_mode(self, val: bool) -> None:
|
||||
if val:
|
||||
if self.db_engine.dialect.name == "postgresql":
|
||||
self.alchemy_session_class = AlchemyPostgresCoreSession
|
||||
else:
|
||||
self.alchemy_session_class = AlchemyCoreSession
|
||||
else:
|
||||
if not self.db:
|
||||
raise ValueError("Can't use ORM mode without an ORM session.")
|
||||
self.alchemy_session_class = AlchemySession
|
||||
|
||||
@staticmethod
|
||||
def create_table_classes(db: scoped_session, prefix: str, base: declarative_base
|
||||
) -> Tuple[Any, Any, Any, Any, Any]:
|
||||
qp = db.query_property() if db else None
|
||||
|
||||
class Version(base):
|
||||
query = qp
|
||||
__tablename__ = "{prefix}version".format(prefix=prefix)
|
||||
version = Column(Integer, primary_key=True)
|
||||
|
||||
def __str__(self):
|
||||
return "Version('{}')".format(self.version)
|
||||
|
||||
class Session(base):
|
||||
query = qp
|
||||
__tablename__ = '{prefix}sessions'.format(prefix=prefix)
|
||||
|
||||
session_id = Column(String(255), primary_key=True)
|
||||
dc_id = Column(Integer, primary_key=True)
|
||||
server_address = Column(String(255))
|
||||
port = Column(Integer)
|
||||
auth_key = Column(LargeBinary)
|
||||
|
||||
def __str__(self):
|
||||
return "Session('{}', {}, '{}', {}, {})".format(self.session_id, self.dc_id,
|
||||
self.server_address, self.port,
|
||||
self.auth_key)
|
||||
|
||||
class Entity(base):
|
||||
query = qp
|
||||
__tablename__ = '{prefix}entities'.format(prefix=prefix)
|
||||
|
||||
session_id = Column(String(255), primary_key=True)
|
||||
id = Column(BigInteger, primary_key=True)
|
||||
hash = Column(BigInteger, nullable=False)
|
||||
username = Column(String(32))
|
||||
phone = Column(BigInteger)
|
||||
name = Column(String(255))
|
||||
|
||||
def __str__(self):
|
||||
return "Entity('{}', {}, {}, '{}', '{}', '{}')".format(self.session_id, self.id,
|
||||
self.hash, self.username,
|
||||
self.phone, self.name)
|
||||
|
||||
class SentFile(base):
|
||||
query = qp
|
||||
__tablename__ = '{prefix}sent_files'.format(prefix=prefix)
|
||||
|
||||
session_id = Column(String(255), primary_key=True)
|
||||
md5_digest = Column(LargeBinary, primary_key=True)
|
||||
file_size = Column(Integer, primary_key=True)
|
||||
type = Column(Integer, primary_key=True)
|
||||
id = Column(BigInteger)
|
||||
hash = Column(BigInteger)
|
||||
|
||||
def __str__(self):
|
||||
return "SentFile('{}', {}, {}, {}, {}, {})".format(self.session_id,
|
||||
self.md5_digest, self.file_size,
|
||||
self.type, self.id, self.hash)
|
||||
|
||||
class UpdateState(base):
|
||||
query = qp
|
||||
__tablename__ = "{prefix}update_state".format(prefix=prefix)
|
||||
|
||||
session_id = Column(String(255), primary_key=True)
|
||||
entity_id = Column(BigInteger, primary_key=True)
|
||||
pts = Column(BigInteger)
|
||||
qts = Column(BigInteger)
|
||||
date = Column(BigInteger)
|
||||
seq = Column(BigInteger)
|
||||
unread_count = Column(Integer)
|
||||
|
||||
return Version, Session, Entity, SentFile, UpdateState
|
||||
|
||||
def _add_column(self, table: Any, column: Column) -> None:
|
||||
column_name = column.compile(dialect=self.db_engine.dialect)
|
||||
column_type = column.type.compile(self.db_engine.dialect)
|
||||
self.db_engine.execute("ALTER TABLE {} ADD COLUMN {} {}".format(
|
||||
table.__tablename__, column_name, column_type))
|
||||
|
||||
def check_and_upgrade_database(self) -> None:
|
||||
row = self.Version.query.all()
|
||||
version = row[0].version if row else 1
|
||||
if version == LATEST_VERSION:
|
||||
return
|
||||
|
||||
self.Version.query.delete()
|
||||
|
||||
if version == 1:
|
||||
self.UpdateState.__table__.create(self.db_engine)
|
||||
version = 3
|
||||
elif version == 2:
|
||||
self._add_column(self.UpdateState, Column(type=Integer, name="unread_count"))
|
||||
|
||||
self.db.add(self.Version(version=version))
|
||||
self.db.commit()
|
||||
|
||||
def new_session(self, session_id: str) -> 'AlchemySession':
|
||||
return self.alchemy_session_class(self, session_id)
|
||||
|
||||
def has_session(self, session_id: str) -> bool:
|
||||
if self.core_mode:
|
||||
t = self.Session.__table__
|
||||
rows = self.db_engine.execute(select([func.count(t.c.auth_key)])
|
||||
.where(and_(t.c.session_id == session_id,
|
||||
t.c.auth_key != b'')))
|
||||
try:
|
||||
count, = next(rows)
|
||||
return count > 0
|
||||
except StopIteration:
|
||||
return False
|
||||
else:
|
||||
return self.Session.query.filter(self.Session.session_id == session_id).count() > 0
|
||||
|
||||
def list_sessions(self):
|
||||
return
|
||||
|
||||
def save(self) -> None:
|
||||
if self.db:
|
||||
self.db.commit()
|
45
library/telegram/utils.py
Normal file
45
library/telegram/utils.py
Normal file
@ -0,0 +1,45 @@
|
||||
import logging
|
||||
import traceback
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from telethon import (
|
||||
errors,
|
||||
events,
|
||||
)
|
||||
|
||||
from .base import RequestContext
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def safe_execution(
|
||||
request_context: RequestContext,
|
||||
on_fail: Optional[Callable[[], Awaitable]] = None,
|
||||
):
|
||||
try:
|
||||
try:
|
||||
yield
|
||||
except events.StopPropagation:
|
||||
raise
|
||||
except (
|
||||
errors.UserIsBlockedError,
|
||||
errors.QueryIdInvalidError,
|
||||
errors.MessageDeleteForbiddenError,
|
||||
errors.MessageIdInvalidError,
|
||||
errors.MessageNotModifiedError,
|
||||
errors.ChatAdminRequiredError,
|
||||
) as e:
|
||||
request_context.error_log(e, level=logging.WARNING)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
request_context.error_log(e)
|
||||
if on_fail:
|
||||
await on_fail()
|
||||
except events.StopPropagation:
|
||||
raise
|
||||
except Exception as e:
|
||||
request_context.error_log(e)
|
@ -199,8 +199,9 @@ class FillDocumentOperationUpdateDocumentScimagPbFromExternalSourceAction(BaseAc
|
||||
super().__init__()
|
||||
self.crossref_client = CrossrefClient(
|
||||
delay=1.0 / crossref['rps'],
|
||||
max_retries=60,
|
||||
max_retries=crossref.get('max_retries', 15),
|
||||
proxy_url=crossref.get('proxy_url'),
|
||||
retry_delay=crossref.get('retry_delay', 0.5),
|
||||
timeout=crossref.get('timeout'),
|
||||
user_agent=crossref.get('user_agent'),
|
||||
)
|
||||
|
@ -41,7 +41,7 @@ py3_image(
|
||||
requirement("aiokit"),
|
||||
"//library/configurator",
|
||||
"//library/logging",
|
||||
"//library/metrics_server",
|
||||
,
|
||||
"//library/telegram",
|
||||
"//nexus/hub/aioclient",
|
||||
"//nexus/meta_api/aioclient",
|
||||
|
@ -5,4 +5,5 @@ The bot requires three other essential parts:
|
||||
- Nexus Hub API (managing files)
|
||||
- Nexus Meta API (rescoring API for Summa Search server)
|
||||
|
||||
or their substitutions
|
||||
or their substitutions
|
||||
|
||||
|
@ -9,5 +9,6 @@ py_binary(
|
||||
deps = [
|
||||
"@bazel_tools//tools/build_defs/pkg:archive",
|
||||
requirement("python-gflags"),
|
||||
requirement("six"),
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user