mirror of
https://github.com/nexus-stc/hyperboria
synced 2024-12-20 16:47:56 +01:00
cca9e8be47
- feat: Set up delays for Crossref API GitOrigin-RevId: 3448e1c2a9fdfca2f8bf95d37e1d62cc5a221577
171 lines
6.5 KiB
Python
171 lines
6.5 KiB
Python
import datetime
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Optional,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
from sqlalchemy import orm
|
|
from telethon import utils
|
|
from telethon.crypto import AuthKey
|
|
from telethon.sessions.memory import (
|
|
MemorySession,
|
|
_SentFileType,
|
|
)
|
|
from telethon.tl.types import (
|
|
InputDocument,
|
|
InputPhoto,
|
|
PeerChannel,
|
|
PeerChat,
|
|
PeerUser,
|
|
updates,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from .sqlalchemy import AlchemySessionContainer
|
|
|
|
|
|
class AlchemySession(MemorySession):
    """Telethon session whose state is stored through SQLAlchemy models.

    Every row this class reads or writes is filtered by ``session_id``, so
    several sessions can share the tables/models exposed by an
    ``AlchemySessionContainer``. The ``MemorySession`` fields
    (``_dc_id``, ``_server_address``, ``_port``, ``_auth_key``) are loaded
    from the ``Session`` model on construction and written back when the DC
    or the auth key changes.
    """

    def __init__(self, container: 'AlchemySessionContainer', session_id: str) -> None:
        """Bind to ``container``'s database and models and load ``session_id``.

        :param container: owner of the SQLAlchemy session (``db``), engine
            (``db_engine``) and the model classes used below.
        :param session_id: key that scopes every row belonging to this session.
        """
        super().__init__()
        self.container = container
        self.db = container.db
        self.engine = container.db_engine
        self.Version, self.Session, self.Entity, self.SentFile, self.UpdateState = (
            container.Version, container.Session, container.Entity,
            container.SentFile, container.UpdateState)
        self.session_id = session_id
        self._load_session()

    def _load_session(self) -> None:
        """Populate DC address and auth key from the stored session row, if any."""
        session = self._db_query(self.Session).first()
        if session is not None:
            self._dc_id = session.dc_id
            self._server_address = session.server_address
            self._port = session.port
            self._auth_key = AuthKey(data=session.auth_key)

    def clone(self, to_instance=None) -> MemorySession:
        """Copy this session's state into ``to_instance`` (or a new MemorySession).

        A plain ``MemorySession`` is the default target because constructing
        another ``AlchemySession`` would require a container and session id.
        """
        # The caller-supplied target used to be silently ignored.
        target = to_instance if to_instance is not None else MemorySession()
        return super().clone(target)

    def _get_auth_key(self) -> Optional[AuthKey]:
        """Return the stored auth key, or ``None`` if no non-empty key is stored."""
        session = self._db_query(self.Session).first()
        if session is not None and session.auth_key:
            return AuthKey(data=session.auth_key)
        return None

    def set_dc(self, dc_id: int, server_address: str, port: int) -> None:
        """Switch to another DC, persist it, and reload the auth key from storage."""
        super().set_dc(dc_id, server_address, port)
        self._update_session_table()
        # Re-read from the table so the in-memory key matches what is stored.
        self._auth_key = self._get_auth_key()

    def get_update_state(self, entity_id: int) -> Optional[updates.State]:
        """Return the saved update state for ``entity_id``, or ``None``.

        The stored ``date`` is a POSIX timestamp (see ``set_update_state``).
        It is rebuilt as a timezone-aware UTC datetime. A naive value would be
        reinterpreted as *local* time by ``datetime.timestamp()``, which shifts
        the state by the host's UTC offset on every save or request.
        """
        row = self.UpdateState.query.get((self.session_id, entity_id))
        if row is None:
            return None
        date = datetime.datetime.fromtimestamp(row.date, tz=datetime.timezone.utc)
        return updates.State(row.pts, row.qts, date, row.seq, row.unread_count)

    def set_update_state(self, entity_id: int, row: Any) -> None:
        """Upsert the update state for ``entity_id`` and save.

        :param row: state object exposing ``pts``, ``qts``, ``date``
            (a datetime), ``seq`` and ``unread_count`` (e.g. ``updates.State``);
            ``None`` is ignored.
        """
        if row is None:
            return
        self.db.merge(self.UpdateState(session_id=self.session_id, entity_id=entity_id,
                                       pts=row.pts, qts=row.qts, date=row.date.timestamp(),
                                       seq=row.seq,
                                       unread_count=row.unread_count))
        self.save()

    @MemorySession.auth_key.setter
    def auth_key(self, value: AuthKey) -> None:
        """Set the auth key and rewrite this session's row (no save is issued)."""
        self._auth_key = value
        self._update_session_table()

    def _update_session_table(self) -> None:
        """Replace this session's single ``Session`` row with the in-memory values.

        A missing auth key is stored as ``b''``; ``_get_auth_key`` treats that
        as "no key". Persistence is left to a later ``save()``.
        """
        self._db_query(self.Session).delete()
        self.db.add(self.Session(session_id=self.session_id, dc_id=self._dc_id,
                                 server_address=self._server_address, port=self._port,
                                 auth_key=(self._auth_key.key if self._auth_key else b'')))

    def _db_query(self, dbclass: Any, *args: Any) -> orm.Query:
        """Return a query on ``dbclass`` restricted to this session plus ``args`` filters."""
        return dbclass.query.filter(
            dbclass.session_id == self.session_id, *args
        )

    def save(self) -> None:
        """Delegate persistence to the container (shared by all its sessions)."""
        self.container.save()

    def close(self) -> None:
        """No-op: the database connection is owned by the container."""
        # Nothing to do here, connection is managed by AlchemySessionContainer.
        pass

    def delete(self) -> None:
        """Remove every row belonging to this session and persist the removal."""
        self._db_query(self.Session).delete()
        self._db_query(self.Entity).delete()
        self._db_query(self.SentFile).delete()
        self._db_query(self.UpdateState).delete()
        # Every other mutator saves; without this the wipe could be lost.
        self.save()

    def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
                              ) -> Any:
        """Build an ``Entity`` model instance for this session (parameter names
        mirror the ``MemorySession`` hook this overrides)."""
        return self.Entity(session_id=self.session_id, id=id, hash=hash,
                           username=username, phone=phone, name=name)

    def process_entities(self, tlo: Any) -> None:
        """Extract entities from a TL object, upsert them, and save if any were found."""
        rows = self._entities_to_rows(tlo)
        if not rows:
            return

        for row in rows:
            self.db.merge(row)
        self.save()

    @staticmethod
    def _id_hash(row: Any) -> Optional[Tuple[int, int]]:
        """Return ``(row.id, row.hash)``, or ``None`` when ``row`` is ``None``."""
        return (row.id, row.hash) if row is not None else None

    # The lookups below use ``.first()`` rather than ``.one_or_none()``:
    # phone, username and especially name are not guaranteed unique, and
    # ``one_or_none()`` raises ``MultipleResultsFound`` on duplicates.

    def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of an entity with phone ``key``, or ``None``."""
        return self._id_hash(
            self._db_query(self.Entity, self.Entity.phone == key).first())

    def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of an entity with username ``key``, or ``None``."""
        return self._id_hash(
            self._db_query(self.Entity, self.Entity.username == key).first())

    def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of an entity named ``key``, or ``None``."""
        return self._id_hash(
            self._db_query(self.Entity, self.Entity.name == key).first())

    def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` for entity ``key``, or ``None``.

        :param exact: when ``False``, ``key`` is a bare id and is matched
            against its user, chat and channel peer-id forms; more than one of
            those may be stored, in which case the first match is returned.
        """
        if exact:
            query = self._db_query(self.Entity, self.Entity.id == key)
        else:
            ids = (
                utils.get_peer_id(PeerUser(key)),
                utils.get_peer_id(PeerChat(key)),
                utils.get_peer_id(PeerChannel(key)),
            )
            query = self._db_query(self.Entity, self.Entity.id.in_(ids))

        return self._id_hash(query.first())

    def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of a cached upload matching digest/size/type, or ``None``.

        NOTE(review): Telethon's bundled sessions return ``cls(id, hash)``
        here rather than a tuple — confirm what callers of this method expect.
        """
        row = self._db_query(self.SentFile,
                             self.SentFile.md5_digest == md5_digest,
                             self.SentFile.file_size == file_size,
                             self.SentFile.type == _SentFileType.from_type(
                                 cls).value).first()
        return self._id_hash(row)

    def cache_file(self, md5_digest: str, file_size: int,
                   instance: Union[InputDocument, InputPhoto]) -> None:
        """Remember an uploaded document/photo keyed by digest, size and type.

        :raises TypeError: if ``instance`` is not an ``InputDocument`` or
            ``InputPhoto``.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError("Cannot cache {} instance".format(type(instance)))

        self.db.merge(
            self.SentFile(session_id=self.session_id, md5_digest=md5_digest, file_size=file_size,
                          type=_SentFileType.from_type(type(instance)).value,
                          id=instance.id, hash=instance.access_hash))
        self.save()
|