# pyrogram-rockserver-storage/pyrogram_rockserver_storage/__init__.py
# 2024-03-30 22:26:50 +01:00
#
# 321 lines
# 13 KiB
# Python
#
# Package metadata.
__author__ = 'Andrea Cavalli'
__version__ = '0.2'
import asyncio
import pathlib
import time
import urllib
import warnings
from enum import Enum
from itertools import chain
from string import digits
from typing import Any, List, Tuple, Dict, Optional
from pyrogram import raw, utils
from pyrogram.storage import Storage
from thriftpy2.contrib.aio.socket import TAsyncSocket
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
from lru import LRU
import thriftpy2
import bson
from thriftpy2.contrib.aio.client import TAsyncClient
from pyrogram_rockserver_storage.TParallelAsyncClient import TParallelAsyncClient
# Key of the single session row; the session column is created with a 1-byte fixed key in open().
SESSION_KEY = [bytes([0])]
# NOTE(review): not referenced in the visible part of this file — confirm it is still used.
DIGITS = set(digits)
# Shared thrift transport / protocol factories passed to make_parallel_aio_client() by open().
TRANSPORT_FACTORY = TAsyncFramedTransportFactory()
PROTOCOL_FACTORY = TAsyncBinaryProtocolFactory()
# Thrift IDL (rocksdb.thrift, located next to this module) loaded at import time; provides the
# RocksDB service and ColumnSchema types used below.
rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift")
class PeerType(Enum):
    """Kinds of Telegram peers, as used by pyrogram.

    The values are the strings matched against a peer record's ``peer_type`` field.
    """

    USER = "user"
    BOT = "bot"
    GROUP = "group"
    CHANNEL = "channel"
    SUPERGROUP = "supergroup"
def encode_peer_info(access_hash: int, peer_type: str, username: str, phone_number: str, last_update_on: int):
    """Pack peer attributes into the document that update_peers BSON-encodes as a peer row value.

    The key order (access_hash, peer_type, username, phone_number, last_update_on) is kept
    fixed so the serialized layout stays stable.
    """
    return dict(
        access_hash=access_hash,
        peer_type=peer_type,
        username=username,
        phone_number=phone_number,
        last_update_on=last_update_on,
    )
def decode_peer_info(peer_id: int, value):
    """Rebuild a peer record from a stored peer-row document.

    :param peer_id: id of the peer; it is the row key and is not stored inside ``value``
    :param value: mapping with the keys written by ``encode_peer_info``, or None
    :return: dict with id, access_hash, peer_type, username, phone_number and
        last_update_on, or None when ``value`` is None
    """
    if value is None:
        return None
    return {
        "id": peer_id,
        "access_hash": value["access_hash"],
        "peer_type": value["peer_type"],
        "username": value["username"],
        # Read the stored value; this used to be the literal list ["phone_number"].
        "phone_number": value["phone_number"],
        "last_update_on": value["last_update_on"],
    }
def get_input_peer(peer):
    """Build the raw Telegram InputPeer for a decoded peer record.

    Mirrors the helper in pyrogram's sqlite storage.

    :param peer: dict with at least ``id``, ``peer_type`` and ``access_hash`` keys
        (as returned by ``decode_peer_info``)
    :return: ``InputPeerUser``, ``InputPeerChat`` or ``InputPeerChannel``
    :raises ValueError: if ``peer_type`` is not one of the :class:`PeerType` values
    """
    peer_id, peer_type, access_hash = peer['id'], peer['peer_type'], peer['access_hash']
    if peer_type in {PeerType.USER.value, PeerType.BOT.value}:
        return raw.types.InputPeerUser(user_id=peer_id, access_hash=access_hash)
    if peer_type == PeerType.GROUP.value:
        # Group ids are negated before being passed as chat_id.
        return raw.types.InputPeerChat(chat_id=-peer_id)
    if peer_type in {PeerType.CHANNEL.value, PeerType.SUPERGROUP.value}:
        return raw.types.InputPeerChannel(
            channel_id=utils.get_channel_id(peer_id),
            access_hash=access_hash
        )
    # The record key is 'peer_type'; reading peer['type'] here raised KeyError instead.
    raise ValueError(f"Invalid peer type: {peer_type}")
async def fetchone(client: TAsyncClient, column: int, keys: Any) -> Optional[Dict]:
    """Read a single row and return its BSON value as a dict.

    :param client: connected rockserver client
    :param column: column id returned by ``createColumn``
    :param keys: list of key parts identifying the row
    :return: the decoded document as a dict, or None when the stored value or the
        decoded document is empty/missing
    """
    # First argument 0: presumably "no pending update transaction" (cf. updateId in _set) — confirm.
    response = await client.get(0, column, keys)
    stored_bytes = response.value
    if not stored_bytes:
        return None
    document = bson.loads(stored_bytes)
    if not document:
        return None
    return dict(document)
class RockServerStorage(Storage):
"""
Implementation of pyrogram storage backed by a RockServer (RocksDB over thrift) instance.
Session data lives in one column and peers in another; both are created in open().
Example usage:
>>> from pyrogram import Client
>>>
>>> session = RockServerStorage(hostname=..., port=5332, session_unique_name=..., save_user_peers=...)
>>> pyrogram = Client(session_name=session)  # NOTE(review): how a Storage instance is passed to Client depends on the pyrogram version/fork — confirm
>>> await pyrogram.connect()
>>> ...
"""
USERNAME_TTL = 8 * 60 * 60 # seconds (8 hours); pyrogram constant, checked against last_update_on in get_peer_by_username
def __init__(self,
             hostname: str,
             port: int,
             session_unique_name: str,
             save_user_peers: bool):
    """Configure the storage; no connection is made until :meth:`open` is awaited.

    :param hostname: rockserver host name
    :param port: rockserver port
    :param session_unique_name: unique session identifier (e.g. the telegram session phone);
        it is embedded in the column names
    :param save_user_peers: when False, peers of type "user" are not stored by update_peers
        and positive ids are rejected by get_peer_by_id
    """
    self._hostname = hostname
    self._port = port
    self._save_user_peers = save_user_peers
    self._session_id = f'{session_unique_name}'
    # Filled in by open().
    self._client = None
    self._session_col = None
    self._peer_col = None
    # Defaults, used until open() loads a stored session document.
    self._session_data = {"dc_id": 2, "api_id": None, "test_mode": None, "auth_key": None, "date": 0, "user_id": None, "is_bot": None, "phone": None}
    # Bounded in-memory indexes: username -> peer id and phone number -> peer id.
    self._username_to_id = LRU(100_000)
    self._phone_to_id = LRU(100_000)
    super().__init__(name=self._session_id)
async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None,
                                   proto_factory=TAsyncBinaryProtocolFactory(),
                                   trans_factory=TAsyncBufferedTransportFactory(),
                                   timeout=3000, connect_timeout=None,
                                   cafile=None, ssl_context=None,
                                   certfile=None, keyfile=None,
                                   validate=True, url='',
                                   socket_timeout=None):
    """Open a thrift transport and wrap it in a :class:`TParallelAsyncClient`.

    Appears adapted from thriftpy2's aio ``make_client``.

    :param service: thrift service definition (e.g. ``rocksdb_thrift.RocksDB``)
    :param host: server host; overridden by the host in ``url`` if present
    :param port: server port; overridden by the port in ``url`` if present
    :param unix_socket: path of a unix socket; takes precedence over host/port
    :param proto_factory: protocol factory applied to the transport
    :param trans_factory: transport factory applied to the socket
    :param timeout: socket timeout passed to ``TAsyncSocket``
    :param connect_timeout: connect timeout passed to ``TAsyncSocket``
    :param cafile: TLS options forwarded to ``TAsyncSocket`` (host/port only);
        likewise ``ssl_context``, ``certfile``, ``keyfile`` and ``validate``
    :param url: optional URL to take host and port from
    :param socket_timeout: deprecated alias of ``timeout``
    :return: client bound to an already-opened transport
    :raises ValueError: when neither unix_socket nor host/port (or url) is usable
    """
    if socket_timeout is not None:
        warnings.warn(
            "The 'socket_timeout' argument is deprecated. "
            "Please use 'timeout' instead.",
            DeprecationWarning,
        )
        timeout = socket_timeout
    if url:
        # A bare ``import urllib`` does not load the ``urllib.parse`` submodule.
        from urllib.parse import urlparse
        parsed_url = urlparse(url)
        host = parsed_url.hostname or host
        port = parsed_url.port or port
    if unix_socket:
        sock = TAsyncSocket(unix_socket=unix_socket,
                            connect_timeout=connect_timeout,
                            socket_timeout=timeout)
        if certfile:
            warnings.warn("SSL only works with host:port, not unix_socket.")
    elif host and port:
        sock = TAsyncSocket(
            host, port,
            socket_timeout=timeout, connect_timeout=connect_timeout,
            cafile=cafile, ssl_context=ssl_context,
            certfile=certfile, keyfile=keyfile, validate=validate)
    else:
        raise ValueError("Either host/port or unix_socket"
                         " or url must be provided.")
    transport = trans_factory.get_transport(sock)
    protocol = proto_factory.get_protocol(transport)
    await transport.open()
    return TParallelAsyncClient(service, protocol)
async def open(self):
    """Connect to rockserver, set up this session's columns and load any stored session data."""
    self._client = await self.make_parallel_aio_client(
        rocksdb_thrift.RocksDB,
        host=self._hostname,
        port=self._port,
        trans_factory=TRANSPORT_FACTORY,
        proto_factory=PROTOCOL_FACTORY,
        connect_timeout=8000,
    )
    # Session column: one row under a 1-byte key (SESSION_KEY) whose BSON value holds
    # dc_id, api_id, test_mode, auth_key, date, user_id, is_bot and phone.
    # Called on every open, so presumably createColumn also returns an existing column — confirm.
    self._session_col = await self._client.createColumn(
        name=f'pyrogram_session_{self._session_id}',
        schema=rocksdb_thrift.ColumnSchema(fixedKeys=[1], variableTailKeys=[], hasValue=True),
    )
    # Peer column: rows keyed by the 8-byte big-endian signed peer id; each BSON value holds
    # access_hash, peer_type, username, phone_number and last_update_on.
    self._peer_col = await self._client.createColumn(
        name=f'peers_{self._session_id}',
        schema=rocksdb_thrift.ColumnSchema(fixedKeys=[8], variableTailKeys=[], hasValue=True),
    )
    stored_session = await fetchone(self._client, self._session_col, SESSION_KEY)
    if stored_session is not None:
        self._session_data = stored_session
async def save(self):
    """Persist the session by writing the current Unix time (whole seconds) into ``date``."""
    now = int(time.time())
    await self.date(now)
async def close(self):
    """Close the rockserver client transport.

    Does nothing when no client exists (open() was never awaited or failed before
    assigning it), instead of failing with AttributeError on None.
    """
    if self._client is None:
        return
    await self._client.close()
async def delete(self):
    """Drop this session's session column and peer column from rockserver."""
    # Pass the column id returned by createColumn, as done for the peer column. The session
    # name (self._session_id) is neither a column id nor the column's name
    # ('pyrogram_session_<name>').
    await self._client.deleteColumn(self._session_col)
    await self._client.deleteColumn(self._peer_col)
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
    """Store a batch of peers (adapted from pyrogram's sqlite storage).

    Each entry's first five fields are (id, access_hash, peer_type, username, phone_number).
    Peers of type "user" are skipped when save_user_peers is False. Duplicate ids keep their
    first occurrence. All rows get the same last_update_on timestamp. The in-memory
    username/phone indexes are refreshed before the rows are written with one putMulti call.
    """
    if not peers:
        return
    timestamp = int(time.time())
    # Phase 1: filter and deduplicate by id (dict preserves first-seen order).
    unique_rows = {}
    for peer in peers:
        if not self._save_user_peers and peer[2] == "user":
            continue
        peer_id, *_ = peer
        if peer_id not in unique_rows:
            # Append the shared timestamp as an extra trailing field.
            unique_rows[peer_id] = tuple(chain(peer, (timestamp,)))
    if not unique_rows:
        return
    # Phase 2: encode keys/values and refresh the lookup indexes.
    keys_multi = []
    value_multi = []
    for row in unique_rows.values():
        row_id = row[0]
        username = row[3]
        phone_number = row[4]
        keys_multi.append([row_id.to_bytes(8, byteorder='big', signed=True)])
        value_multi.append(bson.dumps(encode_peer_info(row[1], row[2], username, phone_number, row[5])))
        if username is not None:
            self._username_to_id[username] = row_id
        if phone_number is not None:
            self._phone_to_id[phone_number] = row_id
    await self._client.putMulti(0, self._peer_col, keys_multi, value_multi)
async def get_peer_by_id(self, peer_id: int):
    """Resolve a peer id to its raw InputPeer.

    :raises KeyError: if peer_id is a str, if it is positive while user peers are not
        saved, or if no row is stored for it
    """
    if isinstance(peer_id, str):
        raise KeyError(f"ID not found: {peer_id}")
    if not self._save_user_peers and peer_id > 0:
        raise KeyError(f"ID not found: {peer_id}")
    row_key = [peer_id.to_bytes(8, byteorder='big', signed=True)]
    record = decode_peer_info(peer_id, await fetchone(self._client, self._peer_col, row_key))
    if record is None:
        raise KeyError(f"ID not found: {peer_id}")
    return get_input_peer(record)
async def get_peer_by_username(self, username: str):
    """Resolve a username to a raw InputPeer using the in-memory username index.

    :raises KeyError: if the username is not indexed, its peer row is missing, the row no
        longer carries this username, or the row is older than USERNAME_TTL seconds
    """
    peer_id = self._username_to_id.get(username)
    if peer_id is None:
        raise KeyError(f"Username not found: {username}")
    keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
    encoded_value = await fetchone(self._client, self._peer_col, keys)
    value_tuple = decode_peer_info(peer_id, encoded_value)
    # The index is never pruned when a peer drops or changes its username, so make sure the
    # stored row still has this username before trusting the mapping.
    if value_tuple is None or value_tuple['username'] != username:
        raise KeyError(f"Username not found: {username}")
    if int(time.time() - value_tuple['last_update_on']) > self.USERNAME_TTL:
        raise KeyError(f"Username expired: {username}")
    return get_input_peer(value_tuple)
async def get_peer_by_phone_number(self, phone_number: str):
    """Resolve a phone number to a raw InputPeer using the in-memory phone index.

    :raises KeyError: if the phone number is not indexed or its peer row is missing
    """
    peer_id = self._phone_to_id.get(phone_number)
    if peer_id is None:
        raise KeyError(f"Phone number not found: {phone_number}")
    keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
    encoded_value = await fetchone(self._client, self._peer_col, keys)
    value_tuple = decode_peer_info(peer_id, encoded_value)
    if value_tuple is None:
        # The index is filled before putMulti runs in update_peers, so it can point at a row
        # that was never written. Report it as missing instead of failing in get_input_peer(None).
        raise KeyError(f"Phone number not found: {phone_number}")
    return get_input_peer(value_tuple)
async def _set(self, column, value: Any):
    """Write one session field to rockserver and mirror it in the local copy.

    The session document is read with getForUpdate and written back with put under the
    returned updateId. On failure a message is printed and the transaction is closed with
    closeFailedUpdate. Failure is treated as best-effort for ordinary exceptions: the local
    copy is still updated. Non-``Exception`` errors (e.g. ``asyncio.CancelledError``,
    ``KeyboardInterrupt``) are re-raised after cleanup so cancellation is not swallowed.
    """
    update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY)
    try:
        previous = update_begin.previous
        decoded_bson_session_data = bson.loads(previous) if previous is not None else None
        session_data = decoded_bson_session_data if decoded_bson_session_data is not None else self._session_data
        session_data[column] = value
        encoded_session_data = bson.dumps(session_data)
        await self._client.put(update_begin.updateId, self._session_col, SESSION_KEY, encoded_session_data)
    except BaseException as exc:
        print("Failed to update session in rocksdb, cancelling the update transaction...")
        try:
            await self._client.closeFailedUpdate(update_begin.updateId)
        except Exception:
            pass  # best-effort cleanup; the original failure is the one that matters
        if not isinstance(exc, Exception):
            raise
    self._session_data[column] = value  # update local copy
async def _accessor(self, column, value: Any = object):
    """Read or write one session field.

    The class ``object`` itself is the "no value given" sentinel, so ``None`` can still be
    written. It is compared by identity.

    :return: the current value of ``column`` when called without ``value``; otherwise the
        result of ``_set`` after persisting ``value``
    """
    if value is object:
        return self._session_data[column]
    return await self._set(column, value)
async def dc_id(self, value: int = object):
    """Return the stored data-center id, or persist ``value`` when one is given."""
    return await self._accessor(column='dc_id', value=value)
async def api_id(self, value: int = object):
    """Return the stored API id, or persist ``value`` when one is given."""
    return await self._accessor(column='api_id', value=value)
async def test_mode(self, value: bool = object):
return await self._accessor('test_mode', value)
async def auth_key(self, value: bytes = object):
    """Return the stored auth key, or persist ``value`` when one is given."""
    return await self._accessor(column='auth_key', value=value)
async def date(self, value: int = object):
    """Return the stored session date, or persist ``value`` when one is given."""
    return await self._accessor(column='date', value=value)
async def user_id(self, value: int = object):
    """Return the stored user id, or persist ``value`` when one is given."""
    return await self._accessor(column='user_id', value=value)
async def is_bot(self, value: bool = object):
    """Return the stored is-bot flag, or persist ``value`` when one is given."""
    return await self._accessor(column='is_bot', value=value)