hyperboria/library/telegram/session_backend/orm.py

171 lines
6.5 KiB
Python
Raw Normal View History

import datetime
from typing import (
TYPE_CHECKING,
Any,
Optional,
Tuple,
Union,
)
from sqlalchemy import orm
from telethon import utils
from telethon.crypto import AuthKey
from telethon.sessions.memory import (
MemorySession,
_SentFileType,
)
from telethon.tl.types import (
InputDocument,
InputPhoto,
PeerChannel,
PeerChat,
PeerUser,
updates,
)
if TYPE_CHECKING:
from .sqlalchemy import AlchemySessionContainer
class AlchemySession(MemorySession):
    """Telethon session whose state is persisted through SQLAlchemy models.

    The model classes (``Version``, ``Session``, ``Entity``, ``SentFile``,
    ``UpdateState``) and the database handle are taken from the owning
    ``AlchemySessionContainer``. Every query issued through ``_db_query`` is
    filtered by ``session_id``.
    """

    def __init__(self, container: 'AlchemySessionContainer', session_id: str) -> None:
        """Bind this session to *container* and load any stored state.

        Args:
            container: Provides ``db``, ``db_engine``, the model classes and
                ``save()``.
            session_id: Value of the ``session_id`` column identifying the
                rows that belong to this session.
        """
        super().__init__()
        self.container = container
        self.db = container.db
        self.engine = container.db_engine
        self.Version, self.Session, self.Entity, self.SentFile, self.UpdateState = (
            container.Version, container.Session, container.Entity,
            container.SentFile, container.UpdateState)
        self.session_id = session_id
        self._load_session()

    def _load_session(self) -> None:
        """Copy the stored DC address and auth key onto this instance.

        Does nothing when no ``Session`` row exists for ``session_id``.
        """
        session = self._db_query(self.Session).first()
        if session:
            self._dc_id = session.dc_id
            self._server_address = session.server_address
            self._port = session.port
            # _update_session_table() stores b'' when there is no key, so an
            # empty value means "no auth key" (same guard as _get_auth_key).
            self._auth_key = (
                AuthKey(data=session.auth_key) if session.auth_key else None
            )

    def clone(self, to_instance=None) -> MemorySession:
        """Clone this session via ``MemorySession.clone``.

        Args:
            to_instance: Session object to receive the copy. When omitted, a
                fresh ``MemorySession`` is used as the target.

        Returns:
            Whatever ``MemorySession.clone`` returns for the chosen target.
        """
        target = MemorySession() if to_instance is None else to_instance
        return super().clone(target)

    def _get_auth_key(self) -> Optional[AuthKey]:
        """Return the auth key stored for this session, or ``None``.

        ``None`` is returned both when there is no ``Session`` row and when
        the stored key is empty.
        """
        session = self._db_query(self.Session).first()
        if session and session.auth_key:
            return AuthKey(data=session.auth_key)
        return None

    def set_dc(self, dc_id: str, server_address: str, port: int) -> None:
        """Set the data-center address, rewrite the session row and re-read the key."""
        super().set_dc(dc_id, server_address, port)
        self._update_session_table()
        self._auth_key = self._get_auth_key()

    def get_update_state(self, entity_id: int) -> Optional[updates.State]:
        """Return the stored update state for *entity_id*, or ``None``.

        The stored ``date`` is a POSIX timestamp (see ``set_update_state``);
        it is returned as a timezone-aware UTC ``datetime``.
        """
        row = self.UpdateState.query.get((self.session_id, entity_id))
        if not row:
            return None
        # An aware datetime round-trips through .timestamp() unchanged; a
        # naive one would be interpreted as local time when stored again.
        date = datetime.datetime.fromtimestamp(row.date, tz=datetime.timezone.utc)
        return updates.State(row.pts, row.qts, date, row.seq, row.unread_count)

    def set_update_state(self, entity_id: int, row: Any) -> None:
        """Merge *row* as the update state for *entity_id* and save.

        Args:
            entity_id: Entity the state belongs to.
            row: Object exposing ``pts``, ``qts``, ``date`` (a ``datetime``),
                ``seq`` and ``unread_count``. A falsy value is ignored.
        """
        if row:
            self.db.merge(self.UpdateState(session_id=self.session_id, entity_id=entity_id,
                                           pts=row.pts, qts=row.qts,
                                           date=row.date.timestamp(),
                                           seq=row.seq,
                                           unread_count=row.unread_count))
            self.save()

    # Reuses the getter of MemorySession.auth_key and replaces only the setter.
    @MemorySession.auth_key.setter
    def auth_key(self, value: AuthKey) -> None:
        """Set the auth key and rewrite the session row."""
        self._auth_key = value
        self._update_session_table()

    def _update_session_table(self) -> None:
        """Replace this session's ``Session`` row with the in-memory values.

        The existing row is deleted and a new one is added to ``self.db``; an
        absent auth key is stored as ``b''``. This method does not call
        ``save()`` itself.
        """
        self.Session.query.filter(self.Session.session_id == self.session_id).delete()
        self.db.add(self.Session(session_id=self.session_id, dc_id=self._dc_id,
                                 server_address=self._server_address, port=self._port,
                                 auth_key=(self._auth_key.key if self._auth_key else b'')))

    def _db_query(self, dbclass: Any, *args: Any) -> orm.Query:
        """Return ``dbclass.query`` filtered by this ``session_id`` and *args*."""
        return dbclass.query.filter(
            dbclass.session_id == self.session_id, *args
        )

    def save(self) -> None:
        """Delegate to the container's ``save()``."""
        self.container.save()

    def close(self) -> None:
        """No-op."""
        # Nothing to do here, connection is managed by AlchemySessionContainer.
        pass

    def delete(self) -> None:
        """Delete this session's Session, Entity, SentFile and UpdateState rows.

        This method does not call ``save()`` itself.
        """
        self._db_query(self.Session).delete()
        self._db_query(self.Entity).delete()
        self._db_query(self.SentFile).delete()
        self._db_query(self.UpdateState).delete()

    def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
                              ) -> Any:
        """Build an ``Entity`` model instance for this session from raw values."""
        return self.Entity(session_id=self.session_id, id=id, hash=hash,
                           username=username, phone=phone, name=name)

    def process_entities(self, tlo: Any) -> None:
        """Merge the entity rows derived from *tlo* into the database and save.

        Rows come from ``_entities_to_rows`` (not defined in this class).
        Nothing is written or saved when it yields no rows.
        """
        rows = self._entities_to_rows(tlo)
        if not rows:
            return

        for row in rows:
            self.db.merge(row)
        self.save()

    def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of an entity whose phone equals *key*, or ``None``."""
        # first() rather than one_or_none(): several rows may match, and that
        # must not raise MultipleResultsFound.
        row = self._db_query(self.Entity,
                             self.Entity.phone == key).first()
        return (row.id, row.hash) if row else None

    def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of an entity whose username equals *key*, or ``None``."""
        row = self._db_query(self.Entity,
                             self.Entity.username == key).first()
        return (row.id, row.hash) if row else None

    def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of an entity whose name equals *key*, or ``None``."""
        # Names are not unique, so take the first match instead of raising.
        row = self._db_query(self.Entity,
                             self.Entity.name == key).first()
        return (row.id, row.hash) if row else None

    def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of an entity matching id *key*, or ``None``.

        Args:
            key: Entity id to look up.
            exact: When true, match ``Entity.id == key``. Otherwise match any
                of the peer ids that ``utils.get_peer_id`` yields for *key*
                wrapped as ``PeerUser``, ``PeerChat`` and ``PeerChannel``.
        """
        if exact:
            query = self._db_query(self.Entity, self.Entity.id == key)
        else:
            ids = (
                utils.get_peer_id(PeerUser(key)),
                utils.get_peer_id(PeerChat(key)),
                utils.get_peer_id(PeerChannel(key))
            )
            query = self._db_query(self.Entity, self.Entity.id.in_(ids))

        # Up to three candidate ids are searched at once, so more than one
        # row can match; take the first instead of raising.
        row = query.first()
        return (row.id, row.hash) if row else None

    def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of a cached file, or ``None``.

        The lookup matches on digest, size and the ``_SentFileType`` value
        derived from *cls*.
        """
        # NOTE(review): one_or_none() assumes these columns identify at most
        # one SentFile row per session - confirm against the model definition.
        row = self._db_query(self.SentFile,
                             self.SentFile.md5_digest == md5_digest,
                             self.SentFile.file_size == file_size,
                             self.SentFile.type == _SentFileType.from_type(
                                 cls).value).one_or_none()
        return (row.id, row.hash) if row else None

    def cache_file(self, md5_digest: str, file_size: int,
                   instance: Union[InputDocument, InputPhoto]) -> None:
        """Merge a ``SentFile`` row describing *instance* and save.

        Raises:
            TypeError: If *instance* is neither ``InputDocument`` nor
                ``InputPhoto``.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError("Cannot cache {} instance".format(type(instance)))

        self.db.merge(
            self.SentFile(session_id=self.session_id, md5_digest=md5_digest, file_size=file_size,
                          type=_SentFileType.from_type(type(instance)).value,
                          id=instance.id, hash=instance.access_hash))
        self.save()