Add parallel async client

This commit is contained in:
Andrea Cavalli 2024-03-30 19:48:25 +01:00
parent 938fa7e356
commit f7f0fd99d2
2 changed files with 108 additions and 1 deletions

View File

@ -0,0 +1,105 @@
import asyncio
import logging
from thriftpy2.contrib.aio.client import TAsyncClient
from thriftpy2.thrift import TApplicationException, TMessageType, args_to_kwargs
log = logging.getLogger(__name__)
class TParallelAsyncClient(TAsyncClient):
def __init__(self, service, iprot, oprot=None):
super().__init__(service, iprot, oprot)
self._open_requests: dict[int, asyncio.Future] = {}
self._message_processor: asyncio.Task | None = None
async def _req(self, _api, *args, **kwargs):
try:
service_args = getattr(self._service, _api + "_args")
kwargs = args_to_kwargs(service_args.thrift_spec, *args, **kwargs)
except ValueError as e:
raise TApplicationException(
TApplicationException.UNKNOWN_METHOD,
"missing required argument {arg} for {service}.{api}".format(
arg=e.args[0], service=self._service.__name__, api=_api
),
)
fut = await self._send(_api, **kwargs)
if fut is not None:
self._process_messages()
return await fut
async def _send(self, _api, **kwargs) -> asyncio.Future | None:
oneway = getattr(getattr(self._service, _api + "_result"), "oneway")
msg_type = TMessageType.ONEWAY if oneway else TMessageType.CALL
seq_id = self._get_seqid()
self._oprot.write_message_begin(_api, msg_type, seq_id)
args = getattr(self._service, _api + "_args")()
for k, v in kwargs.items():
setattr(args, k, v)
self._oprot.write_struct(args)
self._oprot.write_message_end()
await self._oprot.trans.flush()
log.debug("Sent seqid %d: %s", seq_id, _api)
if oneway:
return None
else:
self._open_requests[seq_id] = asyncio.Future()
return self._open_requests[seq_id]
def _process_messages(self):
if self._message_processor is None or self._message_processor.done():
self._message_processor = asyncio.create_task(self._message_handler())
async def _message_handler(self):
while self._open_requests:
fname, mtype, rseqid = await self._iprot.read_message_begin()
log.debug("Reply for seqid %d: %s %s", rseqid, fname, mtype)
fut = self._open_requests.pop(rseqid, None)
if fut is None:
log.error("Received message with unknown seqid %d", rseqid)
try:
fut.set_result(await self._process_message(fname, mtype))
except Exception as e:
fut.set_exception(e)
async def _process_message(self, fname, mtype):
"""process a single message"""
if mtype == TMessageType.EXCEPTION:
x = TApplicationException()
await self._iprot.read_struct(x)
await self._iprot.read_message_end()
raise x
result = getattr(self._service, fname + "_result")()
await self._iprot.read_struct(result)
await self._iprot.read_message_end()
if hasattr(result, "success") and result.success is not None:
return result.success
# void api without throws
if len(result.thrift_spec) == 0:
return
# check throws
for k, v in result.__dict__.items():
if k != "success" and v:
raise v
# no throws & not void api
if hasattr(result, "success"):
raise TApplicationException(TApplicationException.MISSING_RESULT)
def _get_seqid(self) -> int:
seq_id = self._seqid
self._seqid += 1
return seq_id
def close(self):
if self._message_processor is not None and not self._message_processor.done():
self._message_processor.cancel()
self._message_processor = None
super().close()

View File

@ -21,6 +21,8 @@ import thriftpy2
import bson
from thriftpy2.contrib.aio.client import TAsyncClient
from pyrogram_rockserver_storage.TParallelAsyncClient import TParallelAsyncClient
SESSION_KEY = [bytes([0])]
DIGITS = set(digits)
TRANSPORT_FACTORY = TAsyncFramedTransportFactory()
@ -112,7 +114,7 @@ class RockServerStorage(Storage):
async def open(self):
""" Initialize pyrogram session"""
self._client = await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)
self._client = TParallelAsyncClient(await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000))
# Column('dc_id', BIGINT, primary_key=True),
# Column('api_id', BIGINT),