Add username and phone cache
This commit is contained in:
parent
e4a6ce26c1
commit
1c39d50386
19
main.py
Normal file
19
main.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from pyrogram_rockserver_storage import RockServerStorage
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
await storage.open()
|
||||||
|
await storage.update_peers([(-1001946993950, 2, "channel", "4", "5"), (-6846993950, 2, "group", "4", "5")])
|
||||||
|
print("Peer channel", await storage.get_peer_by_id(-1001946993950))
|
||||||
|
print("Peer group", await storage.get_peer_by_id(-6846993950))
|
||||||
|
print("dc_id", await storage.dc_id())
|
||||||
|
await storage.dc_id(8)
|
||||||
|
print("dc_id", await storage.dc_id())
|
||||||
|
|
||||||
|
print("ciao")
|
||||||
|
storage = RockServerStorage(save_user_peers=False, hostname="127.0.0.1", port=5332, session_unique_name="test_99")
|
||||||
|
print("ciao2")
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(main())
|
||||||
|
print("ciao3")
|
@ -17,7 +17,7 @@ from thriftpy2.contrib.aio.socket import TAsyncSocket
|
|||||||
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory
|
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory
|
||||||
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
|
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
|
||||||
|
|
||||||
from thriftpy2.transport.framed import TFramedTransportFactory
|
from lru import LRU
|
||||||
|
|
||||||
import thriftpy2
|
import thriftpy2
|
||||||
|
|
||||||
@ -33,9 +33,6 @@ PROTOCOL_FACTORY = TAsyncBinaryProtocolFactory()
|
|||||||
|
|
||||||
rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift")
|
rocksdb_thrift = thriftpy2.load(str((pathlib.Path(__file__).parent / pathlib.Path("rocksdb.thrift")).resolve(strict=False)), module_name="rocksdb_thrift")
|
||||||
|
|
||||||
from thriftpy2.rpc import make_aio_client
|
|
||||||
|
|
||||||
|
|
||||||
class PeerType(Enum):
|
class PeerType(Enum):
|
||||||
""" Pyrogram peer types """
|
""" Pyrogram peer types """
|
||||||
USER = 'user'
|
USER = 'user'
|
||||||
@ -113,6 +110,9 @@ class RockServerStorage(Storage):
|
|||||||
|
|
||||||
self._save_user_peers = save_user_peers
|
self._save_user_peers = save_user_peers
|
||||||
|
|
||||||
|
self._username_to_id = LRU(100_000)
|
||||||
|
self._phone_to_id = LRU(100_000)
|
||||||
|
|
||||||
super().__init__(name=self._session_id)
|
super().__init__(name=self._session_id)
|
||||||
|
|
||||||
async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None,
|
async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None,
|
||||||
@ -218,13 +218,22 @@ class RockServerStorage(Storage):
|
|||||||
keys_multi = []
|
keys_multi = []
|
||||||
value_multi = []
|
value_multi = []
|
||||||
for deduplicated_peer in deduplicated_peers:
|
for deduplicated_peer in deduplicated_peers:
|
||||||
keys = [deduplicated_peer[0].to_bytes(8, byteorder='big', signed=True)]
|
peer_id = deduplicated_peer[0]
|
||||||
value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], deduplicated_peer[3],
|
username = deduplicated_peer[3]
|
||||||
deduplicated_peer[4], deduplicated_peer[5])
|
phone_number = deduplicated_peer[4]
|
||||||
|
|
||||||
|
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
|
||||||
|
value_tuple = encode_peer_info(deduplicated_peer[1], deduplicated_peer[2], username,
|
||||||
|
phone_number, deduplicated_peer[5])
|
||||||
value = bson.dumps(value_tuple)
|
value = bson.dumps(value_tuple)
|
||||||
keys_multi.append(keys)
|
keys_multi.append(keys)
|
||||||
value_multi.append(value)
|
value_multi.append(value)
|
||||||
|
|
||||||
|
if username is not None:
|
||||||
|
self._username_to_id[username] = peer_id
|
||||||
|
if phone_number is not None:
|
||||||
|
self._phone_to_id[phone_number] = peer_id
|
||||||
|
|
||||||
await self._client.putMulti(0, self._peer_col, keys_multi, value_multi)
|
await self._client.putMulti(0, self._peer_col, keys_multi, value_multi)
|
||||||
|
|
||||||
async def get_peer_by_id(self, peer_id: int):
|
async def get_peer_by_id(self, peer_id: int):
|
||||||
@ -240,10 +249,31 @@ class RockServerStorage(Storage):
|
|||||||
return get_input_peer(value_tuple)
|
return get_input_peer(value_tuple)
|
||||||
|
|
||||||
async def get_peer_by_username(self, username: str):
|
async def get_peer_by_username(self, username: str):
|
||||||
raise KeyError("get_peer_by_username is not supported with rocksdb storage")
|
peer_id = self._username_to_id.get(username)
|
||||||
|
|
||||||
|
if peer_id is None:
|
||||||
|
raise KeyError(f"Username not found: {username}")
|
||||||
|
|
||||||
|
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
|
||||||
|
encoded_value = await fetchone(self._client, self._peer_col, keys)
|
||||||
|
value_tuple = decode_peer_info(peer_id, encoded_value)
|
||||||
|
|
||||||
|
if int(time.time() - value_tuple['last_update_on']) > self.USERNAME_TTL:
|
||||||
|
raise KeyError(f"Username expired: {username}")
|
||||||
|
|
||||||
|
return get_input_peer(value_tuple)
|
||||||
|
|
||||||
async def get_peer_by_phone_number(self, phone_number: str):
|
async def get_peer_by_phone_number(self, phone_number: str):
|
||||||
raise KeyError("get_peer_by_username is not supported with rocksdb storage")
|
peer_id = self._phone_to_id.get(phone_number)
|
||||||
|
|
||||||
|
if peer_id is None:
|
||||||
|
raise KeyError(f"Phone number not found: {phone_number}")
|
||||||
|
|
||||||
|
keys = [peer_id.to_bytes(8, byteorder='big', signed=True)]
|
||||||
|
encoded_value = await fetchone(self._client, self._peer_col, keys)
|
||||||
|
value_tuple = decode_peer_info(peer_id, encoded_value)
|
||||||
|
|
||||||
|
return get_input_peer(value_tuple)
|
||||||
|
|
||||||
async def _set(self, column, value: Any):
|
async def _set(self, column, value: Any):
|
||||||
update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY)
|
update_begin = await self._client.getForUpdate(0, self._session_col, SESSION_KEY)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
Pyrogram~=1.4.8
|
Pyrogram~=1.4.8
|
||||||
bson~=0.5.10
|
bson~=0.5.10
|
||||||
thriftpy2~=0.4.20
|
thriftpy2~=0.4.20
|
||||||
|
lru-dict~=1.3.0
|
||||||
|
Loading…
Reference in New Issue
Block a user