Add username and phone cache

This commit is contained in:
Andrea Cavalli 2024-03-30 22:12:23 +01:00
parent e4a6ce26c1
commit 1c39d50386
3 changed files with 60 additions and 10 deletions

19
main.py Normal file
View File

@ -0,0 +1,19 @@
from pyrogram_rockserver_storage import RockServerStorage
import asyncio
async def main():
await storage.open()
await storage.update_peers([(-1001946993950, 2, "channel", "4", "5"), (-6846993950, 2, "group", "4", "5")])
print("Peer channel", await storage.get_peer_by_id(-1001946993950))
print("Peer group", await storage.get_peer_by_id(-6846993950))
print("dc_id", await storage.dc_id())
await storage.dc_id(8)
print("dc_id", await storage.dc_id())
print("ciao")
storage = RockServerStorage(save_user_peers=False, hostname="127.0.0.1", port=5332, session_unique_name="test_99")
print("ciao2")
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
print("ciao3")

View File

@ -17,7 +17,7 @@ from thriftpy2.contrib.aio.socket import TAsyncSocket
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
from thriftpy2.transport.framed import TFramedTransportFactory
from lru import LRU
import thriftpy2
@ -33,9 +33,6 @@ PROTOCOL_FACTORY = TAsyncBinaryProtocolFactory()
rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift")
from thriftpy2.rpc import make_aio_client
class PeerType(Enum):
""" Pyrogram peer types """
USER = 'user'
@ -113,6 +110,9 @@ class RockServerStorage(Storage):
self._save_user_peers = save_user_peers
self._username_to_id = LRU(100_000)
self._phone_to_id = LRU(100_000)
super().__init__(name=self._session_id)
async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None,
@ -218,13 +218,22 @@ class RockServerStorage(Storage):
keys_multi = []
value_multi = []
for deduplicated_peer in deduplicated_peers:
keys = [deduplicated_peer[0].to_bytes(8, byteorder='big', signed=True)]
value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], deduplicated_peer[3],
deduplicated_peer[4], deduplicated_peer[5])
peer_id = deduplicated_peer[0]
username = deduplicated_peer[3]
phone_number = deduplicated_peer[4]
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], username,
phone_number, deduplicated_peer[5])
value = bson.dumps(value_tuple)
keys_multi.append(keys)
value_multi.append(value)
if username is not None:
self._username_to_id[username] = peer_id
if phone_number is not None:
self._phone_to_id[phone_number] = peer_id
await self._client.putMulti(0, self._peer_col, keys_multi, value_multi)
async def get_peer_by_id(self, peer_id: int):
@ -240,10 +249,31 @@ class RockServerStorage(Storage):
return get_input_peer(value_tuple)
async def get_peer_by_username(self, username: str):
raise KeyError("get_peer_by_username is not supported with rocksdb storage")
peer_id = self._username_to_id.get(username)
if peer_id is None:
raise KeyError(f"Username not found: {username}")
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
encoded_value = await fetchone(self._client, self._peer_col, keys)
value_tuple = decode_peer_info(peer_id, encoded_value)
if int(time.time() - value_tuple['last_update_on']) > self.USERNAME_TTL:
raise KeyError(f"Username expired: {username}")
return get_input_peer(value_tuple)
async def get_peer_by_phone_number(self, phone_number: str):
raise KeyError("get_peer_by_username is not supported with rocksdb storage")
peer_id = self._phone_to_id.get(phone_number)
if peer_id is None:
raise KeyError(f"Phone number not found: {phone_number}")
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
encoded_value = await fetchone(self._client, self._peer_col, keys)
value_tuple = decode_peer_info(peer_id, encoded_value)
return get_input_peer(value_tuple)
async def _set(self, column, value: Any):
update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY)

View File

@ -1,3 +1,4 @@
Pyrogram~=1.4.8
bson~=0.5.10
thriftpy2~=0.4.20
lru-dict~=1.3.0