commit 9563e40ba676fdddabba62d2ad9cc16785a3d85c
Author: Dmytro Smyk
Date:   Thu Apr 15 01:38:37 2021 +0300

    initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..99607af
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,62 @@
+# Created by .ignore support plugin (hsz.mobi)
+### JetBrains template
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+.idea
+
+# User-specific stuff
+.idea/**/tasks.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# CMake
+cmake-build-debug/
+cmake-build-release/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+*.pyc
+dist
+*.egg-info
+MANIFEST
+
+.venv
+
+.vscode/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a9a5a85
--- /dev/null
+++ b/README.md
@@ -0,0 +1,16 @@
+Pyrogram RockServer storage
+======================
+
+Usage
+-----
+
+```python
+
+from pyrogram import Client
+from pyrogram_rockserver_storage import RockServerStorage
+
+session = RockServerStorage(hostname=..., port=..., session_unique_name=..., save_user_peers=...)
+pyrogram = Client(session)
+await pyrogram.connect()
+
+```
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..af2c143
--- /dev/null
+++ b/main.py
@@ -0,0 +1,18 @@
+from pyrogram_rockserver_storage import RockServerStorage
+import asyncio
+
+
+async def main():
+    await storage.open()
+    await storage.update_peers([(-1001946993950, 2, "channel", "4", "5"), (-6846993950, 2, "group", "4", "5")])
+    print("Peer channel", await storage.get_peer_by_id(-1001946993950))
+    print("Peer group", await storage.get_peer_by_id(-6846993950))
+    print("dc_id", await storage.dc_id())
+    await storage.dc_id(4)
+
+print("ciao")
+storage = RockServerStorage(save_user_peers=False, hostname="127.0.0.1", port=5332, session_unique_name="test_99")
+print("ciao2")
+loop = asyncio.get_event_loop()
+loop.run_until_complete(main())
+print("ciao3")
diff --git a/pyrogram_rockserver_storage/__init__.py b/pyrogram_rockserver_storage/__init__.py
new file mode 100644
index 0000000..93dc66a
--- /dev/null
+++ b/pyrogram_rockserver_storage/__init__.py
@@ -0,0 +1,251 @@
+__author__ = 'Andrea Cavalli'
+__version__ = '0.1'
+
+import asyncio
+import pathlib
+import time
+from enum import Enum
+from itertools import chain
+from string import digits
+from typing import Any, List, Tuple, Dict, Optional
+
+from pyrogram import raw, utils
+from pyrogram.storage import Storage
+from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory
+from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
+
+from thriftpy2.transport.framed import TFramedTransportFactory
+
+import thriftpy2
+
+import bson
+from thriftpy2.contrib.aio.client import TAsyncClient
+
+SESSION_KEY = [bytes([0])]
+DIGITS = set(digits)
+TRANSPORT_FACTORY = TAsyncFramedTransportFactory()
+PROTOCOL_FACTORY = TAsyncBinaryProtocolFactory()
+
+rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift")
+
+from thriftpy2.rpc import make_aio_client
+
+
+class PeerType(Enum):
+    """ Pyrogram peer types """
+    USER = 'user'
+    BOT = 'bot'
+    GROUP = 'group'
+    CHANNEL = 'channel'
+    SUPERGROUP = 'supergroup'
+
+def encode_peer_info(access_hash: int, peer_type: str, username: str, phone_number: str, last_update_on: int):
+    """ Build the dict that is BSON-encoded and stored as the value of a peer """
+    return {"access_hash": access_hash, "peer_type": peer_type, "username": username, "phone_number": phone_number, "last_update_on": last_update_on}
+
+def decode_peer_info(peer_id: int, value):
+    """ Build the dict expected by get_input_peer from a stored value, or return None when value is None """
+    return {"id": peer_id, "access_hash": value["access_hash"], "peer_type": value["peer_type"]} if value is not None else None
+
+def get_input_peer(peer):
+    """ This function is almost blindly copied from pyrogram sqlite storage"""
+    peer_id, peer_type, access_hash = peer['id'], peer['peer_type'], peer['access_hash']
+
+    if peer_type in {PeerType.USER.value, PeerType.BOT.value}:
+        return raw.types.InputPeerUser(user_id=peer_id, access_hash=access_hash)
+
+    if peer_type == PeerType.GROUP.value:
+        return raw.types.InputPeerChat(chat_id=-peer_id)
+
+    if peer_type in {PeerType.CHANNEL.value, PeerType.SUPERGROUP.value}:
+        return raw.types.InputPeerChannel(
+            channel_id=utils.get_channel_id(peer_id),
+            access_hash=access_hash
+        )
+
+    raise ValueError(f"Invalid peer type: {peer_type}")
+
+
+async def fetchone(client: TAsyncClient, column: int, keys: Any) -> Optional[Dict]:
+    """ Small helper - fetches the value stored under keys in column and decodes it from BSON (None if empty) """
+    value = (await client.get(0, column, keys)).value
+    value = bson.loads(value) if value else None
+    return dict(value) if value else None
+
+
+class RockServerStorage(Storage):
+    """
+    Implementation of RockServer storage.
+
+    Example usage:
+
+    >>> from pyrogram import Client
+    >>>
+    >>> session = RockServerStorage(hostname=..., port=5332, session_unique_name=..., save_user_peers=...)
+    >>> pyrogram = Client(session_name=session)
+    >>> await pyrogram.connect()
+    >>> ...
+
+    """
+
+    USERNAME_TTL = 8 * 60 * 60 # pyrogram constant
+
+    def __init__(self,
+                 hostname: str,
+                 port: int,
+                 session_unique_name: str,
+                 save_user_peers: bool):
+        """
+        :param hostname: rocksdb hostname
+        :param port: rocksdb port
+        :param session_unique_name: telegram session phone; it is used to build the column names
+        :param save_user_peers: when False, peers of type "user" are not stored and positive ids are never looked up
+        """
+        self._session_col = None
+        self._peer_col = None
+        self._session_id = f'{session_unique_name}'
+        self._session_data = None
+        self._client = None
+        self._hostname = hostname
+        self._port = port
+
+        self._save_user_peers = save_user_peers
+
+        super().__init__(name=self._session_id)
+
+    async def open(self):
+        """ Initialize pyrogram session: connect, create the two columns and load the stored session data """
+        self._client = await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)
+
+        # Column('dc_id', BIGINT, primary_key=True),
+        # Column('api_id', BIGINT),
+        # Column('test_mode', Boolean),
+        # Column('auth_key', BYTEA),
+        # Column('date', BIGINT, nullable=False),
+        # Column('user_id', BIGINT),
+        # Column('is_bot', Boolean),
+        # Column('phone', String(length=50)
+        self._session_col = await self._client.createColumn(name=f'pyrogram_session_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[1], variableTailKeys=[], hasValue=True))
+
+        # Column('id', BIGINT),
+        # Column('access_hash', BIGINT),
+        # Column('type', String, nullable=False),
+        # Column('username', String),
+        # Column('phone_number', String),
+        # Column('last_update_on', BIGINT),
+        self._peer_col = await self._client.createColumn(name=f'peers_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[8], variableTailKeys=[], hasValue=True))
+
+        self._session_data = await fetchone(self._client, self._session_col, SESSION_KEY)
+        if self._session_data is None:
+            self._session_data = {"dc_id": 2, "api_id": None, "test_mode": None, "auth_key": None, "date": 0, "user_id": None, "is_bot": None, "phone": None}
+
+    async def save(self):
+        """ On save we update the date """
+        await self.date(int(time.time()))
+
+    async def close(self):
+        """ Close transport """
+        await self._client.close()
+
+    async def delete(self):
+        """ Delete the session column and the peer column (deleteColumn takes the column id, not its name) """
+        await self._client.deleteColumn(self._session_col)
+        await self._client.deleteColumn(self._peer_col)
+
+    async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
+        """ Copied and adapted from pyrogram sqlite storage: deduplicate the peers and store them with one putMulti """
+        if not peers:
+            return
+
+        now = int(time.time())
+        deduplicated_peers = []
+        seen_ids = set()
+
+        # deduplicate peers to avoid possible `CardinalityViolation` error
+        for peer in peers:
+            if not self._save_user_peers and peer[2] == "user":
+                continue
+            peer_id, *_ = peer
+            if peer_id in seen_ids:
+                continue
+            seen_ids.add(peer_id)
+            # enrich peer with timestamp and append
+            deduplicated_peers.append(tuple(chain(peer, (now,))))
+
+        # construct insert query
+        if deduplicated_peers:
+            keys_multi = []
+            value_multi = []
+            for deduplicated_peer in deduplicated_peers:
+                keys = [deduplicated_peer[0].to_bytes(8, byteorder='big', signed=True)]
+                value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], deduplicated_peer[3],
+                                               deduplicated_peer[4], deduplicated_peer[5])
+                value = bson.dumps(value_tuple)
+                keys_multi.append(keys)
+                value_multi.append(value)
+
+            await self._client.putMulti(0, self._peer_col, keys_multi, value_multi)
+
+    async def get_peer_by_id(self, peer_id: int):
+        """ Return the raw input peer stored for peer_id, raise KeyError when it is not available """
+        if isinstance(peer_id, str) or (not self._save_user_peers and peer_id > 0):
+            raise KeyError(f"ID not found: {peer_id}")
+
+        keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
+        encoded_value = await fetchone(self._client, self._peer_col, keys)
+        value_tuple = decode_peer_info(peer_id, encoded_value)
+        if value_tuple is None:
+            raise KeyError(f"ID not found: {peer_id}")
+
+        return get_input_peer(value_tuple)
+
+    async def get_peer_by_username(self, username: str):
+        """ Not supported: always raises KeyError """
+        raise KeyError("get_peer_by_username is not supported with rocksdb storage")
+
+    async def get_peer_by_phone_number(self, phone_number: str):
+        """ Not supported: always raises KeyError """
+        raise KeyError("get_peer_by_phone_number is not supported with rocksdb storage")
+
+    async def _set(self, column, value: Any):
+        """ Write one session field inside a getForUpdate/put cycle, then update the local copy """
+        update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY)
+        try:
+            session_data = bson.loads(update_begin.previous) if update_begin.previous is not None else self._session_data
+            session_data[column] = value
+            encoded_session_data = bson.dumps(session_data)
+            await self._client.put(update_begin.updateId, self._session_col, SESSION_KEY, encoded_session_data)
+        except Exception:
+            print("Failed to update session in rocksdb, cancelling the update transaction...")
+            try:
+                await self._client.closeFailedUpdate(update_begin.updateId)
+            except Exception:
+                pass  # best effort: the update has already failed
+        self._session_data[column] = value # update local copy
+
+    async def _accessor(self, column, value: Any = object):
+        """ Return the locally cached field when no value is passed (sentinel: object), otherwise store the value """
+        return self._session_data[column] if value is object else await self._set(column, value)
+
+    # Session fields: call without arguments to read the cached value, with a value to store it
+    async def dc_id(self, value: int = object):
+        return await self._accessor('dc_id', value)
+
+    async def api_id(self, value: int = object):
+        return await self._accessor('api_id', value)
+
+    async def test_mode(self, value: bool = object):
+        return await self._accessor('test_mode', value)
+
+    async def auth_key(self, value: bytes = object):
+        return await self._accessor('auth_key', value)
+
+    async def date(self, value: int = object):
+        return await self._accessor('date', value)
+
+    async def user_id(self, value: int = object):
+        return await self._accessor('user_id', value)
+
+    async def is_bot(self, value: bool = object):
+        return await self._accessor('is_bot', value)
+
diff --git a/pyrogram_rockserver_storage/rocksdb.thrift b/pyrogram_rockserver_storage/rocksdb.thrift
new file mode 100644
index 0000000..35ddf5b
--- /dev/null
+++ b/pyrogram_rockserver_storage/rocksdb.thrift
@@ -0,0 +1,87 @@
+namespace java it.cavallium.rockserver.core.common.api
+
+struct ColumnSchema {
+  1: list<i32> fixedKeys,
+  2: list<ColumnHashType> variableTailKeys,
+  3: bool hasValue
+}
+
+enum ColumnHashType {
+  XXHASH32 = 1,
+  XXHASH8 = 2,
+  ALLSAME8 = 3
+}
+
+enum Operation {
+  NOTHING = 1,
+  PREVIOUS = 2,
+  CURRENT = 3,
+  FOR_UPDATE = 4,
+  EXISTS = 5,
+  DELTA = 6,
+  MULTI = 7,
+  CHANGED = 8,
+  PREVIOUS_PRESENCE = 9
+}
+
+struct Delta {
+  1: optional binary previous,
+  2: optional binary current
+}
+
+struct OptionalBinary {
+  1: optional binary value
+}
+
+struct UpdateBegin {
+  1: optional binary previous,
+  2: optional i64 updateId
+}
+
+service RocksDB {
+
+  i64 openTransaction(1: required i64 timeoutMs),
+
+  bool closeTransaction(1: required i64 timeoutMs, 2: required bool commit),
+
+  void closeFailedUpdate(1: required i64 updateId),
+
+  i64 createColumn(1: required string name, 2: required ColumnSchema schema),
+
+  void deleteColumn(1: required i64 columnId),
+
+  i64 getColumnId(1: required string name),
+
+  oneway void putFast(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
+
+  void put(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
+
+  void putMulti(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<list<binary>> keysMulti, 4: required list<binary> valueMulti),
+
+  OptionalBinary putGetPrevious(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
+
+  Delta putGetDelta(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
+
+  bool putGetChanged(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
+
+  bool putGetPreviousPresence(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
+
+  OptionalBinary get(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys),
+
+  UpdateBegin getForUpdate(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys),
+
+  bool exists(1: required i64 transactionOrUpdateId, 3: required i64 columnId, 4: required list<binary> keys),
+
+  i64 openIterator(1: required i64 transactionId, 2: required i64 columnId, 3: required list<binary> startKeysInclusive, 4: list<binary> endKeysExclusive, 5: required bool reverse, 6: required i64 timeoutMs),
+
+  void closeIterator(1: required i64 iteratorId),
+
+  void seekTo(1: required i64 iterationId, 2: required list<binary> keys),
+
+  void subsequent(1: required i64 iterationId, 2: required i64 skipCount, 3: required i64 takeCount),
+
+  bool subsequentExists(1: required i64 iterationId, 2: required i64 skipCount, 3: required i64 takeCount),
+
+  list<OptionalBinary> subsequentMultiGet(1: required i64 iterationId, 2: required i64 skipCount, 3: required i64 takeCount),
+
+}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..6cff944
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+Pyrogram~=1.4.8
+bson~=0.5.10
+thriftpy2~=0.4.20
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..6e803db
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,45 @@
+import pathlib
+import re
+from setuptools import setup
+
+here = pathlib.Path(__file__).parent
+init = here / "pyrogram_rockserver_storage" / "__init__.py"
+readme_path = here / "README.md"
+
+with (here / "requirements.txt").open(encoding="utf-8") as r:
+    requires = [i.strip() for i in r if i.strip()]
+
+with init.open() as fp:
+    try:
+        version = re.findall(r"^__version__ = '([^']+)'$", fp.read(), re.M)[0]
+    except IndexError:
+        raise RuntimeError('Unable to determine version.')
+
+
+with readme_path.open() as f:
+    README = f.read()
+
+setup(
+    name='pyrogram-rockserver-storage',
+    version=version,
+    description='rockserver storage for pyrogram',
+    long_description=README,
+    long_description_content_type='text/markdown',
+    author='Andrea Cavalli',
+    author_email='nospam@warp.ovh',
+    url='https://github.com/cavallium/pyrogram-rockserver-storage',
+    packages=["pyrogram_rockserver_storage", ],
+    package_data={"pyrogram_rockserver_storage": ["rocksdb.thrift"]},
+    classifiers=[
+        "Operating System :: OS Independent",
+        'Intended Audience :: Developers',
+        'Programming Language :: Python',
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+    ],
+    python_requires='>=3.6.0',
+    install_requires=requires
+)