Add parallel async client
This commit is contained in:
parent
938fa7e356
commit
f7f0fd99d2
105
pyrogram_rockserver_storage/TParallelAsyncClient.py
Normal file
105
pyrogram_rockserver_storage/TParallelAsyncClient.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from thriftpy2.contrib.aio.client import TAsyncClient
|
||||||
|
from thriftpy2.thrift import TApplicationException, TMessageType, args_to_kwargs
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TParallelAsyncClient(TAsyncClient):
|
||||||
|
def __init__(self, service, iprot, oprot=None):
|
||||||
|
super().__init__(service, iprot, oprot)
|
||||||
|
self._open_requests: dict[int, asyncio.Future] = {}
|
||||||
|
self._message_processor: asyncio.Task | None = None
|
||||||
|
|
||||||
|
async def _req(self, _api, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
service_args = getattr(self._service, _api + "_args")
|
||||||
|
kwargs = args_to_kwargs(service_args.thrift_spec, *args, **kwargs)
|
||||||
|
except ValueError as e:
|
||||||
|
raise TApplicationException(
|
||||||
|
TApplicationException.UNKNOWN_METHOD,
|
||||||
|
"missing required argument {arg} for {service}.{api}".format(
|
||||||
|
arg=e.args[0], service=self._service.__name__, api=_api
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
fut = await self._send(_api, **kwargs)
|
||||||
|
if fut is not None:
|
||||||
|
self._process_messages()
|
||||||
|
return await fut
|
||||||
|
|
||||||
|
async def _send(self, _api, **kwargs) -> asyncio.Future | None:
|
||||||
|
oneway = getattr(getattr(self._service, _api + "_result"), "oneway")
|
||||||
|
msg_type = TMessageType.ONEWAY if oneway else TMessageType.CALL
|
||||||
|
seq_id = self._get_seqid()
|
||||||
|
self._oprot.write_message_begin(_api, msg_type, seq_id)
|
||||||
|
args = getattr(self._service, _api + "_args")()
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(args, k, v)
|
||||||
|
self._oprot.write_struct(args)
|
||||||
|
self._oprot.write_message_end()
|
||||||
|
await self._oprot.trans.flush()
|
||||||
|
log.debug("Sent seqid %d: %s", seq_id, _api)
|
||||||
|
|
||||||
|
if oneway:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
self._open_requests[seq_id] = asyncio.Future()
|
||||||
|
return self._open_requests[seq_id]
|
||||||
|
|
||||||
|
def _process_messages(self):
|
||||||
|
if self._message_processor is None or self._message_processor.done():
|
||||||
|
self._message_processor = asyncio.create_task(self._message_handler())
|
||||||
|
|
||||||
|
async def _message_handler(self):
|
||||||
|
while self._open_requests:
|
||||||
|
fname, mtype, rseqid = await self._iprot.read_message_begin()
|
||||||
|
log.debug("Reply for seqid %d: %s %s", rseqid, fname, mtype)
|
||||||
|
fut = self._open_requests.pop(rseqid, None)
|
||||||
|
if fut is None:
|
||||||
|
log.error("Received message with unknown seqid %d", rseqid)
|
||||||
|
|
||||||
|
try:
|
||||||
|
fut.set_result(await self._process_message(fname, mtype))
|
||||||
|
except Exception as e:
|
||||||
|
fut.set_exception(e)
|
||||||
|
|
||||||
|
async def _process_message(self, fname, mtype):
|
||||||
|
"""process a single message"""
|
||||||
|
if mtype == TMessageType.EXCEPTION:
|
||||||
|
x = TApplicationException()
|
||||||
|
await self._iprot.read_struct(x)
|
||||||
|
await self._iprot.read_message_end()
|
||||||
|
raise x
|
||||||
|
result = getattr(self._service, fname + "_result")()
|
||||||
|
await self._iprot.read_struct(result)
|
||||||
|
await self._iprot.read_message_end()
|
||||||
|
|
||||||
|
if hasattr(result, "success") and result.success is not None:
|
||||||
|
return result.success
|
||||||
|
|
||||||
|
# void api without throws
|
||||||
|
if len(result.thrift_spec) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# check throws
|
||||||
|
for k, v in result.__dict__.items():
|
||||||
|
if k != "success" and v:
|
||||||
|
raise v
|
||||||
|
|
||||||
|
# no throws & not void api
|
||||||
|
if hasattr(result, "success"):
|
||||||
|
raise TApplicationException(TApplicationException.MISSING_RESULT)
|
||||||
|
|
||||||
|
def _get_seqid(self) -> int:
|
||||||
|
seq_id = self._seqid
|
||||||
|
self._seqid += 1
|
||||||
|
return seq_id
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self._message_processor is not None and not self._message_processor.done():
|
||||||
|
self._message_processor.cancel()
|
||||||
|
self._message_processor = None
|
||||||
|
super().close()
|
@ -21,6 +21,8 @@ import thriftpy2
|
|||||||
import bson
|
import bson
|
||||||
from thriftpy2.contrib.aio.client import TAsyncClient
|
from thriftpy2.contrib.aio.client import TAsyncClient
|
||||||
|
|
||||||
|
from pyrogram_rockserver_storage.TParallelAsyncClient import TParallelAsyncClient
|
||||||
|
|
||||||
SESSION_KEY = [bytes([0])]
|
SESSION_KEY = [bytes([0])]
|
||||||
DIGITS = set(digits)
|
DIGITS = set(digits)
|
||||||
TRANSPORT_FACTORY = TAsyncFramedTransportFactory()
|
TRANSPORT_FACTORY = TAsyncFramedTransportFactory()
|
||||||
@ -112,7 +114,7 @@ class RockServerStorage(Storage):
|
|||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
""" Initialize pyrogram session"""
|
""" Initialize pyrogram session"""
|
||||||
self._client = await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)
|
self._client = TParallelAsyncClient(await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000))
|
||||||
|
|
||||||
# Column('dc_id', BIGINT, primary_key=True),
|
# Column('dc_id', BIGINT, primary_key=True),
|
||||||
# Column('api_id', BIGINT),
|
# Column('api_id', BIGINT),
|
||||||
|
Loading…
Reference in New Issue
Block a user