From f7f0fd99d28883be84c33f82d88728cf7694911e Mon Sep 17 00:00:00 2001 From: Andrea Cavalli Date: Sat, 30 Mar 2024 19:48:25 +0100 Subject: [PATCH] Add parallel async client --- .../TParallelAsyncClient.py | 105 ++++++++++++++++++ pyrogram_rockserver_storage/__init__.py | 4 +- 2 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 pyrogram_rockserver_storage/TParallelAsyncClient.py diff --git a/pyrogram_rockserver_storage/TParallelAsyncClient.py b/pyrogram_rockserver_storage/TParallelAsyncClient.py new file mode 100644 index 0000000..67ad9db --- /dev/null +++ b/pyrogram_rockserver_storage/TParallelAsyncClient.py @@ -0,0 +1,105 @@ +import asyncio +import logging + +from thriftpy2.contrib.aio.client import TAsyncClient +from thriftpy2.thrift import TApplicationException, TMessageType, args_to_kwargs + +log = logging.getLogger(__name__) + + +class TParallelAsyncClient(TAsyncClient): + def __init__(self, service, iprot, oprot=None): + super().__init__(service, iprot, oprot) + self._open_requests: dict[int, asyncio.Future] = {} + self._message_processor: asyncio.Task | None = None + + async def _req(self, _api, *args, **kwargs): + try: + service_args = getattr(self._service, _api + "_args") + kwargs = args_to_kwargs(service_args.thrift_spec, *args, **kwargs) + except ValueError as e: + raise TApplicationException( + TApplicationException.UNKNOWN_METHOD, + "missing required argument {arg} for {service}.{api}".format( + arg=e.args[0], service=self._service.__name__, api=_api + ), + ) + + fut = await self._send(_api, **kwargs) + if fut is not None: + self._process_messages() + return await fut + + async def _send(self, _api, **kwargs) -> asyncio.Future | None: + oneway = getattr(getattr(self._service, _api + "_result"), "oneway") + msg_type = TMessageType.ONEWAY if oneway else TMessageType.CALL + seq_id = self._get_seqid() + self._oprot.write_message_begin(_api, msg_type, seq_id) + args = getattr(self._service, _api + "_args")() + for k, v in kwargs.items(): + setattr(args, k, v) + self._oprot.write_struct(args) + self._oprot.write_message_end() + await self._oprot.trans.flush() + log.debug("Sent seqid %d: %s", seq_id, _api) + + if oneway: + return None + else: + self._open_requests[seq_id] = asyncio.Future() + return self._open_requests[seq_id] + + def _process_messages(self): + if self._message_processor is None or self._message_processor.done(): + self._message_processor = asyncio.create_task(self._message_handler()) + + async def _message_handler(self): + while self._open_requests: + fname, mtype, rseqid = await self._iprot.read_message_begin() + log.debug("Reply for seqid %d: %s %s", rseqid, fname, mtype) + fut = self._open_requests.pop(rseqid, None) + if fut is None: + log.error("Received message with unknown seqid %d", rseqid) + + try: + fut.set_result(await self._process_message(fname, mtype)) + except Exception as e: + fut.set_exception(e) + + async def _process_message(self, fname, mtype): + """process a single message""" + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + await self._iprot.read_struct(x) + await self._iprot.read_message_end() + raise x + result = getattr(self._service, fname + "_result")() + await self._iprot.read_struct(result) + await self._iprot.read_message_end() + + if hasattr(result, "success") and result.success is not None: + return result.success + + # void api without throws + if len(result.thrift_spec) == 0: + return + + # check throws + for k, v in result.__dict__.items(): + if k != "success" and v: + raise v + + # no throws & not void api + if hasattr(result, "success"): + raise TApplicationException(TApplicationException.MISSING_RESULT) + + def _get_seqid(self) -> int: + seq_id = self._seqid + self._seqid += 1 + return seq_id + + def close(self): + if self._message_processor is not None and not self._message_processor.done(): + self._message_processor.cancel() + self._message_processor = None + super().close() diff --git a/pyrogram_rockserver_storage/__init__.py b/pyrogram_rockserver_storage/__init__.py index 5700cfb..549aeca 100644 --- a/pyrogram_rockserver_storage/__init__.py +++ b/pyrogram_rockserver_storage/__init__.py @@ -21,6 +21,8 @@ import thriftpy2 import bson from thriftpy2.contrib.aio.client import TAsyncClient +from pyrogram_rockserver_storage.TParallelAsyncClient import TParallelAsyncClient + SESSION_KEY = [bytes([0])] DIGITS = set(digits) TRANSPORT_FACTORY = TAsyncFramedTransportFactory() @@ -112,7 +114,7 @@ class RockServerStorage(Storage): async def open(self): """ Initialize pyrogram session""" - self._client = await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000) + self._client = TParallelAsyncClient(await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)) # Column('dc_id', BIGINT, primary_key=True), # Column('api_id', BIGINT),