diff --git a/main.py b/main.py new file mode 100644 index 0000000..5399015 --- /dev/null +++ b/main.py @@ -0,0 +1,19 @@ +from pyrogram_rockserver_storage import RockServerStorage +import asyncio + + +async def main(): + await storage.open() + await storage.update_peers([(-1001946993950, 2, "channel", "4", "5"), (-6846993950, 2, "group", "4", "5")]) + print("Peer channel", await storage.get_peer_by_id(-1001946993950)) + print("Peer group", await storage.get_peer_by_id(-6846993950)) + print("dc_id", await storage.dc_id()) + await storage.dc_id(8) + print("dc_id", await storage.dc_id()) + +print("ciao") +storage = RockServerStorage(save_user_peers=False, hostname="127.0.0.1", port=5332, session_unique_name="test_99") +print("ciao2") +loop = asyncio.get_event_loop() +loop.run_until_complete(main()) +print("ciao3") diff --git a/pyrogram_rockserver_storage/__init__.py b/pyrogram_rockserver_storage/__init__.py index f39bb00..c51f3cb 100644 --- a/pyrogram_rockserver_storage/__init__.py +++ b/pyrogram_rockserver_storage/__init__.py @@ -17,7 +17,7 @@ from thriftpy2.contrib.aio.socket import TAsyncSocket from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory -from thriftpy2.transport.framed import TFramedTransportFactory +from lru import LRU import thriftpy2 @@ -33,9 +33,6 @@ PROTOCOL_FACTORY = TAsyncBinaryProtocolFactory() rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift") -from thriftpy2.rpc import make_aio_client - - class PeerType(Enum): """ Pyrogram peer types """ USER = 'user' @@ -113,6 +110,9 @@ class RockServerStorage(Storage): self._save_user_peers = save_user_peers + self._username_to_id = LRU(100_000) + self._phone_to_id = LRU(100_000) + super().__init__(name=self._session_id) async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None, @@ -218,13 +218,22 @@ class RockServerStorage(Storage): keys_multi = [] value_multi = [] for deduplicated_peer in deduplicated_peers: - keys = [deduplicated_peer[0].to_bytes(8, byteorder='big', signed=True)] - value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], deduplicated_peer[3], - deduplicated_peer[4], deduplicated_peer[5]) + peer_id = deduplicated_peer[0] + username = deduplicated_peer[3] + phone_number = deduplicated_peer[4] + + keys = [peer_id.to_bytes(8, byteorder='big', signed=True)] + value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], username, + phone_number, deduplicated_peer[5]) value = bson.dumps(value_tuple) keys_multi.append(keys) value_multi.append(value) + if username is not None: + self._username_to_id[username] = peer_id + if phone_number is not None: + self._phone_to_id[phone_number] = peer_id + await self._client.putMulti(0, self._peer_col, keys_multi, value_multi) async def get_peer_by_id(self, peer_id: int): @@ -240,10 +249,31 @@ class RockServerStorage(Storage): return get_input_peer(value_tuple) async def get_peer_by_username(self, username: str): - raise KeyError("get_peer_by_username is not supported with rocksdb storage") + peer_id = self._username_to_id.get(username) + + if peer_id is None: + raise KeyError(f"Username not found: {username}") + + keys = [peer_id.to_bytes(8, byteorder='big', signed=True)] + encoded_value = await fetchone(self._client, self._peer_col, keys) + value_tuple = decode_peer_info(peer_id, encoded_value) + + if int(time.time() - value_tuple['last_update_on']) > self.USERNAME_TTL: + raise KeyError(f"Username expired: {username}") + + return get_input_peer(value_tuple) async def get_peer_by_phone_number(self, phone_number: str): - raise KeyError("get_peer_by_username is not supported with rocksdb storage") + peer_id = self._phone_to_id.get(phone_number) + + if peer_id is None: + raise KeyError(f"Phone number not found: {phone_number}") + + keys = [peer_id.to_bytes(8, byteorder='big', signed=True)] + encoded_value = await fetchone(self._client, self._peer_col, keys) + value_tuple = decode_peer_info(peer_id, encoded_value) + + return get_input_peer(value_tuple) async def _set(self, column, value: Any): update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY) diff --git a/requirements.txt b/requirements.txt index 6cff944..fcafc4b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ Pyrogram~=1.4.8 bson~=0.5.10 -thriftpy2~=0.4.20 \ No newline at end of file +thriftpy2~=0.4.20 +lru-dict~=1.3.0