initial commit
This commit is contained in:
commit
9563e40ba6
62
.gitignore
vendored
Normal file
62
.gitignore
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
# Created by .ignore support plugin (hsz.mobi)
|
||||
### JetBrains template
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
.idea
|
||||
|
||||
# User-specific stuff
|
||||
.idea/**/tasks.xml
|
||||
.idea/**/dictionaries
|
||||
.idea/**/shelf
|
||||
|
||||
# Sensitive or high-churn files
|
||||
.idea/**/dataSources/
|
||||
.idea/**/dataSources.ids
|
||||
.idea/**/dataSources.local.xml
|
||||
.idea/**/sqlDataSources.xml
|
||||
.idea/**/dynamic.xml
|
||||
.idea/**/uiDesigner.xml
|
||||
.idea/**/dbnavigator.xml
|
||||
|
||||
# Gradle
|
||||
.idea/**/gradle.xml
|
||||
.idea/**/libraries
|
||||
|
||||
# CMake
|
||||
cmake-build-debug/
|
||||
cmake-build-release/
|
||||
|
||||
# Mongo Explorer plugin
|
||||
.idea/**/mongoSettings.xml
|
||||
|
||||
# File-based project format
|
||||
*.iws
|
||||
|
||||
# IntelliJ
|
||||
out/
|
||||
|
||||
# mpeltonen/sbt-idea plugin
|
||||
.idea_modules/
|
||||
|
||||
# JIRA plugin
|
||||
atlassian-ide-plugin.xml
|
||||
|
||||
# Cursive Clojure plugin
|
||||
.idea/replstate.xml
|
||||
|
||||
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||
com_crashlytics_export_strings.xml
|
||||
crashlytics.properties
|
||||
crashlytics-build.properties
|
||||
fabric.properties
|
||||
|
||||
# Editor-based Rest Client
|
||||
.idea/httpRequests
|
||||
*.pyc
|
||||
dist
|
||||
*.egg-info
|
||||
MANIFEST
|
||||
|
||||
.venv
|
||||
|
||||
.vscode/
|
16
README.md
Normal file
16
README.md
Normal file
@ -0,0 +1,16 @@
|
||||
Pyrogram RockServer storage
|
||||
======================
|
||||
|
||||
Usage
|
||||
-----
|
||||
|
||||
```python
|
||||
|
||||
from pyrogram import Client
|
||||
from pyrogram_rockserver_storage import RockServerStorage
|
||||
|
||||
session = RockServerStorage(hostname=..., port=..., session_unique_name=..., save_user_peers=...)
|
||||
pyrogram = Client(session)
|
||||
await pyrogram.connect()
|
||||
|
||||
```
|
18
main.py
Normal file
18
main.py
Normal file
@ -0,0 +1,18 @@
|
||||
from pyrogram_rockserver_storage import RockServerStorage
|
||||
import asyncio
|
||||
|
||||
|
||||
async def main():
|
||||
await storage.open()
|
||||
await storage.update_peers([(-1001946993950, 2, "channel", "4", "5"), (-6846993950, 2, "group", "4", "5")])
|
||||
print("Peer channel", await storage.get_peer_by_id(-1001946993950))
|
||||
print("Peer group", await storage.get_peer_by_id(-6846993950))
|
||||
print("dc_id", await storage.dc_id())
|
||||
await storage.dc_id(4)
|
||||
|
||||
print("ciao")
|
||||
storage = RockServerStorage(save_user_peers=False, hostname="127.0.0.1", port=5332, session_unique_name="test_99")
|
||||
print("ciao2")
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
print("ciao3")
|
242
pyrogram_rockserver_storage/__init__.py
Normal file
242
pyrogram_rockserver_storage/__init__.py
Normal file
@ -0,0 +1,242 @@
|
||||
__author__ = 'Andrea Cavalli'
|
||||
__version__ = '0.1'
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
import time
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from string import digits
|
||||
from typing import Any, List, Tuple, Dict, Optional
|
||||
|
||||
from pyrogram import raw, utils
|
||||
from pyrogram.storage import Storage
|
||||
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory
|
||||
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
|
||||
|
||||
from thriftpy2.transport.framed import TFramedTransportFactory
|
||||
|
||||
import thriftpy2
|
||||
|
||||
import bson
|
||||
from thriftpy2.contrib.aio.client import TAsyncClient
|
||||
|
||||
SESSION_KEY = [bytes([0])]
|
||||
DIGITS = set(digits)
|
||||
TRANSPORT_FACTORY = TAsyncFramedTransportFactory()
|
||||
PROTOCOL_FACTORY = TAsyncBinaryProtocolFactory()
|
||||
|
||||
rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift")
|
||||
|
||||
from thriftpy2.rpc import make_aio_client
|
||||
|
||||
|
||||
class PeerType(Enum):
|
||||
""" Pyrogram peer types """
|
||||
USER = 'user'
|
||||
BOT = 'bot'
|
||||
GROUP = 'group'
|
||||
CHANNEL = 'channel'
|
||||
SUPERGROUP = 'supergroup'
|
||||
|
||||
def encode_peer_info(access_hash: int, peer_type: str, username: str, phone_number: str, last_update_on: int):
|
||||
return {"access_hash": access_hash, "peer_type": peer_type, "username": username, "phone_number": phone_number, "last_update_on": last_update_on}
|
||||
|
||||
def decode_peer_info(peer_id: int, value):
|
||||
return {"id": peer_id, "access_hash": value["access_hash"], "peer_type": value["peer_type"]} if value is not None else None
|
||||
|
||||
def get_input_peer(peer):
|
||||
""" This function is almost blindly copied from pyrogram sqlite storage"""
|
||||
peer_id, peer_type, access_hash = peer['id'], peer['peer_type'], peer['access_hash']
|
||||
|
||||
if peer_type in {PeerType.USER.value, PeerType.BOT.value}:
|
||||
return raw.types.InputPeerUser(user_id=peer_id, access_hash=access_hash)
|
||||
|
||||
if peer_type == PeerType.GROUP.value:
|
||||
return raw.types.InputPeerChat(chat_id=-peer_id)
|
||||
|
||||
if peer_type in {PeerType.CHANNEL.value, PeerType.SUPERGROUP.value}:
|
||||
return raw.types.InputPeerChannel(
|
||||
channel_id=utils.get_channel_id(peer_id),
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
raise ValueError(f"Invalid peer type: {peer['type']}")
|
||||
|
||||
|
||||
async def fetchone(client: TAsyncClient, column: int, keys: Any) -> Optional[Dict]:
|
||||
""" Small helper - fetches a single row from provided query """
|
||||
value = (await client.get(0, column, keys)).value
|
||||
value = bson.loads(value) if value else None
|
||||
return dict(value) if value else None
|
||||
|
||||
|
||||
class RockServerStorage(Storage):
|
||||
"""
|
||||
Implementation of RockServer storage.
|
||||
|
||||
Example usage:
|
||||
|
||||
>>> from pyrogram import Client
|
||||
>>>
|
||||
>>> session = RockServerStorage(hostname=..., port=5332, user_id=..., session_unique_name=..., save_user_peers=...)
|
||||
>>> pyrogram = Client(session_name=session)
|
||||
>>> await pyrogram.connect()
|
||||
>>> ...
|
||||
|
||||
"""
|
||||
|
||||
USERNAME_TTL = 8 * 60 * 60 # pyrogram constant
|
||||
|
||||
def __init__(self,
|
||||
hostname: str,
|
||||
port: int,
|
||||
session_unique_name: str,
|
||||
save_user_peers: bool):
|
||||
"""
|
||||
:param hostname: rocksdb hostname
|
||||
:param port: rocksdb port
|
||||
:param session_unique_name: telegram session phone
|
||||
"""
|
||||
self._session_col = None
|
||||
self._peer_col = None
|
||||
self._session_id = f'{session_unique_name}'
|
||||
self._session_data = None
|
||||
self._client = None
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
|
||||
self._save_user_peers = save_user_peers
|
||||
|
||||
super().__init__(name=self._session_id)
|
||||
|
||||
async def open(self):
|
||||
""" Initialize pyrogram session"""
|
||||
self._client = await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)
|
||||
|
||||
# Column('dc_id', BIGINT, primary_key=True),
|
||||
# Column('api_id', BIGINT),
|
||||
# Column('test_mode', Boolean),
|
||||
# Column('auth_key', BYTEA),
|
||||
# Column('date', BIGINT, nullable=False),
|
||||
# Column('user_id', BIGINT),
|
||||
# Column('is_bot', Boolean),
|
||||
# Column('phone', String(length=50)
|
||||
self._session_col = await self._client.createColumn(name=f'pyrogram_session_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[1], variableTailKeys=[], hasValue=True))
|
||||
|
||||
# Column('id', BIGINT),
|
||||
# Column('access_hash', BIGINT),
|
||||
# Column('type', String, nullable=False),
|
||||
# Column('username', String),
|
||||
# Column('phone_number', String),
|
||||
# Column('last_update_on', BIGINT),
|
||||
self._peer_col = await self._client.createColumn(name=f'peers_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[8], variableTailKeys=[], hasValue=True))
|
||||
|
||||
self._session_data = await fetchone(self._client, self._session_col, SESSION_KEY)
|
||||
if self._session_data is None:
|
||||
self._session_data = {"dc_id": 2, "api_id": None, "test_mode": None, "auth_key": None, "date": 0, "user_id": None, "is_bot": None, "phone": None}
|
||||
|
||||
async def save(self):
|
||||
""" On save we update the date """
|
||||
await self.date(int(time.time()))
|
||||
|
||||
async def close(self):
|
||||
""" Close transport """
|
||||
await self._client.close()
|
||||
|
||||
async def delete(self):
|
||||
""" Delete all the tables and indexes """
|
||||
await self._client.deleteColumn(self._session_id)
|
||||
await self._client.deleteColumn(self._peer_col)
|
||||
|
||||
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
|
||||
""" Copied and adopted from pyro sqlite storage"""
|
||||
if not peers:
|
||||
return
|
||||
|
||||
now = int(time.time())
|
||||
deduplicated_peers = []
|
||||
seen_ids = set()
|
||||
|
||||
# deduplicate peers to avoid possible `CardinalityViolation` error
|
||||
for peer in peers:
|
||||
if not self._save_user_peers and peer[2] == "user":
|
||||
continue
|
||||
peer_id, *_ = peer
|
||||
if peer_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(peer_id)
|
||||
# enrich peer with timestamp and append
|
||||
deduplicated_peers.append(tuple(chain(peer, (now,))))
|
||||
|
||||
# construct insert query
|
||||
if deduplicated_peers:
|
||||
keys_multi = []
|
||||
value_multi = []
|
||||
for deduplicated_peer in deduplicated_peers:
|
||||
keys = [deduplicated_peer[0].to_bytes(8, byteorder='big', signed=True)]
|
||||
value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], deduplicated_peer[3],
|
||||
deduplicated_peer[4], deduplicated_peer[5])
|
||||
value = bson.dumps(value_tuple)
|
||||
keys_multi.append(keys)
|
||||
value_multi.append(value)
|
||||
|
||||
await self._client.putMulti(0, self._peer_col, keys_multi, value_multi)
|
||||
|
||||
async def get_peer_by_id(self, peer_id: int):
|
||||
if isinstance(peer_id, str) or (not self._save_user_peers and peer_id > 0):
|
||||
raise KeyError(f"ID not found: {peer_id}")
|
||||
|
||||
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
|
||||
encoded_value = await fetchone(self._client, self._peer_col, keys)
|
||||
value_tuple = decode_peer_info(peer_id, encoded_value)
|
||||
if value_tuple is None:
|
||||
raise KeyError(f"ID not found: {peer_id}")
|
||||
|
||||
return get_input_peer(value_tuple)
|
||||
|
||||
async def get_peer_by_username(self, username: str):
|
||||
raise KeyError("get_peer_by_username is not supported with rocksdb storage")
|
||||
|
||||
async def get_peer_by_phone_number(self, phone_number: str):
|
||||
raise KeyError("get_peer_by_username is not supported with rocksdb storage")
|
||||
|
||||
async def _set(self, column, value: Any):
|
||||
update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY)
|
||||
try:
|
||||
session_data = bson.loads(update_begin.previous) if update_begin.previous is not None else self._session_data
|
||||
session_data[column] = value
|
||||
encoded_session_data = bson.dumps(session_data)
|
||||
await self._client.put(update_begin.updateId, self._session_col, SESSION_KEY, encoded_session_data)
|
||||
except:
|
||||
print("Failed to update session in rocksdb, cancelling the update transaction...")
|
||||
try:
|
||||
await self._client.closeFailedUpdate(update_begin.updateId)
|
||||
except:
|
||||
pass
|
||||
self._session_data[column] = value # update local copy
|
||||
|
||||
async def _accessor(self, column, value: Any = object):
|
||||
return self._session_data[column] if value == object else await self._set(column, value)
|
||||
|
||||
async def dc_id(self, value: int = object):
|
||||
return await self._accessor('dc_id', value)
|
||||
|
||||
async def api_id(self, value: int = object):
|
||||
return await self._accessor('api_id', value)
|
||||
|
||||
async def test_mode(self, value: bool = object):
|
||||
return await self._accessor('test_mode', value)
|
||||
|
||||
async def auth_key(self, value: bytes = object):
|
||||
return await self._accessor('auth_key', value)
|
||||
|
||||
async def date(self, value: int = object):
|
||||
return await self._accessor('date', value)
|
||||
|
||||
async def user_id(self, value: int = object):
|
||||
return await self._accessor('user_id', value)
|
||||
|
||||
async def is_bot(self, value: bool = object):
|
||||
return await self._accessor('is_bot', value)
|
||||
|
87
pyrogram_rockserver_storage/rocksdb.thrift
Normal file
87
pyrogram_rockserver_storage/rocksdb.thrift
Normal file
@ -0,0 +1,87 @@
|
||||
namespace java it.cavallium.rockserver.core.common.api
|
||||
|
||||
struct ColumnSchema {
|
||||
1: list<i32> fixedKeys,
|
||||
2: list<ColumnHashType> variableTailKeys,
|
||||
3: bool hasValue
|
||||
}
|
||||
|
||||
enum ColumnHashType {
|
||||
XXHASH32 = 1,
|
||||
XXHASH8 = 2,
|
||||
ALLSAME8 = 3
|
||||
}
|
||||
|
||||
enum Operation {
|
||||
NOTHING = 1,
|
||||
PREVIOUS = 2,
|
||||
CURRENT = 3,
|
||||
FOR_UPDATE = 4,
|
||||
EXISTS = 5,
|
||||
DELTA = 6,
|
||||
MULTI = 7,
|
||||
CHANGED = 8,
|
||||
PREVIOUS_PRESENCE = 9
|
||||
}
|
||||
|
||||
struct Delta {
|
||||
1: optional binary previous,
|
||||
2: optional binary current
|
||||
}
|
||||
|
||||
struct OptionalBinary {
|
||||
1: optional binary value
|
||||
}
|
||||
|
||||
struct UpdateBegin {
|
||||
1: optional binary previous,
|
||||
2: optional i64 updateId
|
||||
}
|
||||
|
||||
service RocksDB {
|
||||
|
||||
i64 openTransaction(1: required i64 timeoutMs),
|
||||
|
||||
bool closeTransaction(1: required i64 timeoutMs, 2: required bool commit),
|
||||
|
||||
void closeFailedUpdate(1: required i64 updateId),
|
||||
|
||||
i64 createColumn(1: required string name, 2: required ColumnSchema schema),
|
||||
|
||||
void deleteColumn(1: required i64 columnId),
|
||||
|
||||
i64 getColumnId(1: required string name),
|
||||
|
||||
oneway void putFast(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
|
||||
|
||||
void put(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
|
||||
|
||||
void putMulti(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<list<binary>> keysMulti, 4: required list<binary> valueMulti),
|
||||
|
||||
OptionalBinary putGetPrevious(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
|
||||
|
||||
Delta putGetDelta(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
|
||||
|
||||
bool putGetChanged(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
|
||||
|
||||
bool putGetPreviousPresence(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys, 4: required binary value),
|
||||
|
||||
OptionalBinary get(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys),
|
||||
|
||||
UpdateBegin getForUpdate(1: required i64 transactionOrUpdateId, 2: required i64 columnId, 3: required list<binary> keys),
|
||||
|
||||
bool exists(1: required i64 transactionOrUpdateId, 3: required i64 columnId, 4: required list<binary> keys),
|
||||
|
||||
i64 openIterator(1: required i64 transactionId, 2: required i64 columnId, 3: required list<binary> startKeysInclusive, 4: list<binary> endKeysExclusive, 5: required bool reverse, 6: required i64 timeoutMs),
|
||||
|
||||
void closeIterator(1: required i64 iteratorId),
|
||||
|
||||
void seekTo(1: required i64 iterationId, 2: required list<binary> keys),
|
||||
|
||||
void subsequent(1: required i64 iterationId, 2: required i64 skipCount, 3: required i64 takeCount),
|
||||
|
||||
bool subsequentExists(1: required i64 iterationId, 2: required i64 skipCount, 3: required i64 takeCount),
|
||||
|
||||
list<OptionalBinary> subsequentMultiGet(1: required i64 iterationId, 2: required i64 skipCount, 3: required i64 takeCount),
|
||||
|
||||
}
|
3
requirements.txt
Normal file
3
requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
Pyrogram~=1.4.8
|
||||
bson~=0.5.10
|
||||
thriftpy2~=0.4.20
|
44
setup.py
Normal file
44
setup.py
Normal file
@ -0,0 +1,44 @@
|
||||
import pathlib
|
||||
import re
|
||||
from setuptools import setup
|
||||
|
||||
here = pathlib.Path(__file__).parent
|
||||
init = here / "pyrogram_rockserver_storage" / "__init__.py"
|
||||
readme_path = here / "README.md"
|
||||
|
||||
with open("requirements.txt", encoding="utf-8") as r:
|
||||
requires = [i.strip() for i in r]
|
||||
|
||||
with init.open() as fp:
|
||||
try:
|
||||
version = re.findall(r"^__version__ = '([^']+)'$", fp.read(), re.M)[0]
|
||||
except IndexError:
|
||||
raise RuntimeError('Unable to determine version.')
|
||||
|
||||
|
||||
with readme_path.open() as f:
|
||||
README = f.read()
|
||||
|
||||
setup(
|
||||
name='pyrogram-rockserver-storage',
|
||||
version=version,
|
||||
description='rockserver storage for pyrogram',
|
||||
long_description=README,
|
||||
long_description_content_type='text/markdown',
|
||||
author='Andrea Cavalli',
|
||||
author_email='nospam@warp.ovh',
|
||||
url='https://github.com/cavallium/pyrogram-rockserver-storage',
|
||||
packages=["pyrogram_rockserver_storage", ],
|
||||
classifiers=[
|
||||
"Operating System :: OS Independent",
|
||||
'Intended Audience :: Developers',
|
||||
'Programming Language :: Python',
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
],
|
||||
python_requires='>=3.6.0',
|
||||
install_requires=requires
|
||||
)
|
Loading…
Reference in New Issue
Block a user