Add username and phone cache

This commit is contained in:
Andrea Cavalli 2024-03-30 22:12:23 +01:00
parent e4a6ce26c1
commit 1c39d50386
3 changed files with 60 additions and 10 deletions

19
main.py Normal file
View File

@ -0,0 +1,19 @@
from pyrogram_rockserver_storage import RockServerStorage
import asyncio
async def main():
await storage.open()
await storage.update_peers([(-1001946993950, 2, "channel", "4", "5"), (-6846993950, 2, "group", "4", "5")])
print("Peer channel", await storage.get_peer_by_id(-1001946993950))
print("Peer group", await storage.get_peer_by_id(-6846993950))
print("dc_id", await storage.dc_id())
await storage.dc_id(8)
print("dc_id", await storage.dc_id())
print("ciao")
storage = RockServerStorage(save_user_peers=False, hostname="127.0.0.1", port=5332, session_unique_name="test_99")
print("ciao2")
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
print("ciao3")

View File

@ -17,7 +17,7 @@ from thriftpy2.contrib.aio.socket import TAsyncSocket
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
from thriftpy2.transport.framed import TFramedTransportFactory from lru import LRU
import thriftpy2 import thriftpy2
@ -33,9 +33,6 @@ PROTOCOL_FACTORY = TAsyncBinaryProtocolFactory()
rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift") rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift")
from thriftpy2.rpc import make_aio_client
class PeerType(Enum): class PeerType(Enum):
""" Pyrogram peer types """ """ Pyrogram peer types """
USER = 'user' USER = 'user'
@ -113,6 +110,9 @@ class RockServerStorage(Storage):
self._save_user_peers = save_user_peers self._save_user_peers = save_user_peers
self._username_to_id = LRU(100_000)
self._phone_to_id = LRU(100_000)
super().__init__(name=self._session_id) super().__init__(name=self._session_id)
async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None, async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None,
@ -218,13 +218,22 @@ class RockServerStorage(Storage):
keys_multi = [] keys_multi = []
value_multi = [] value_multi = []
for deduplicated_peer in deduplicated_peers: for deduplicated_peer in deduplicated_peers:
keys = [deduplicated_peer[0].to_bytes(8, byteorder='big', signed=True)] peer_id = deduplicated_peer[0]
value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], deduplicated_peer[3], username = deduplicated_peer[3]
deduplicated_peer[4], deduplicated_peer[5]) phone_number = deduplicated_peer[4]
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], username,
phone_number, deduplicated_peer[5])
value = bson.dumps(value_tuple) value = bson.dumps(value_tuple)
keys_multi.append(keys) keys_multi.append(keys)
value_multi.append(value) value_multi.append(value)
if username is not None:
self._username_to_id[username] = peer_id
if phone_number is not None:
self._phone_to_id[phone_number] = peer_id
await self._client.putMulti(0, self._peer_col, keys_multi, value_multi) await self._client.putMulti(0, self._peer_col, keys_multi, value_multi)
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
@ -240,10 +249,31 @@ class RockServerStorage(Storage):
return get_input_peer(value_tuple) return get_input_peer(value_tuple)
async def get_peer_by_username(self, username: str): async def get_peer_by_username(self, username: str):
raise KeyError("get_peer_by_username is not supported with rocksdb storage") peer_id = self._username_to_id.get(username)
if peer_id is None:
raise KeyError(f"Username not found: {username}")
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
encoded_value = await fetchone(self._client, self._peer_col, keys)
value_tuple = decode_peer_info(peer_id, encoded_value)
if int(time.time() - value_tuple['last_update_on']) > self.USERNAME_TTL:
raise KeyError(f"Username expired: {username}")
return get_input_peer(value_tuple)
async def get_peer_by_phone_number(self, phone_number: str): async def get_peer_by_phone_number(self, phone_number: str):
raise KeyError("get_peer_by_username is not supported with rocksdb storage") peer_id = self._phone_to_id.get(phone_number)
if peer_id is None:
raise KeyError(f"Phone number not found: {phone_number}")
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
encoded_value = await fetchone(self._client, self._peer_col, keys)
value_tuple = decode_peer_info(peer_id, encoded_value)
return get_input_peer(value_tuple)
async def _set(self, column, value: Any): async def _set(self, column, value: Any):
update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY) update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY)

View File

@ -1,3 +1,4 @@
Pyrogram~=1.4.8 Pyrogram~=1.4.8
bson~=0.5.10 bson~=0.5.10
thriftpy2~=0.4.20 thriftpy2~=0.4.20
lru-dict~=1.3.0