106 lines
3.8 KiB
Python
106 lines
3.8 KiB
Python
import asyncio
|
|
import logging
|
|
|
|
from thriftpy2.contrib.aio.client import TAsyncClient
|
|
from thriftpy2.thrift import TApplicationException, TMessageType, args_to_kwargs
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class TParallelAsyncClient(TAsyncClient):
|
|
def __init__(self, service, iprot, oprot=None):
|
|
super().__init__(service, iprot, oprot)
|
|
self._open_requests: dict[int, asyncio.Future] = {}
|
|
self._message_processor: asyncio.Task | None = None
|
|
|
|
async def _req(self, _api, *args, **kwargs):
|
|
try:
|
|
service_args = getattr(self._service, _api + "_args")
|
|
kwargs = args_to_kwargs(service_args.thrift_spec, *args, **kwargs)
|
|
except ValueError as e:
|
|
raise TApplicationException(
|
|
TApplicationException.UNKNOWN_METHOD,
|
|
"missing required argument {arg} for {service}.{api}".format(
|
|
arg=e.args[0], service=self._service.__name__, api=_api
|
|
),
|
|
)
|
|
|
|
fut = await self._send(_api, **kwargs)
|
|
if fut is not None:
|
|
self._process_messages()
|
|
return await fut
|
|
|
|
async def _send(self, _api, **kwargs) -> asyncio.Future | None:
|
|
oneway = getattr(getattr(self._service, _api + "_result"), "oneway")
|
|
msg_type = TMessageType.ONEWAY if oneway else TMessageType.CALL
|
|
seq_id = self._get_seqid()
|
|
self._oprot.write_message_begin(_api, msg_type, seq_id)
|
|
args = getattr(self._service, _api + "_args")()
|
|
for k, v in kwargs.items():
|
|
setattr(args, k, v)
|
|
self._oprot.write_struct(args)
|
|
self._oprot.write_message_end()
|
|
await self._oprot.trans.flush()
|
|
log.debug("Sent seqid %d: %s", seq_id, _api)
|
|
|
|
if oneway:
|
|
return None
|
|
else:
|
|
self._open_requests[seq_id] = asyncio.Future()
|
|
return self._open_requests[seq_id]
|
|
|
|
def _process_messages(self):
|
|
if self._message_processor is None or self._message_processor.done():
|
|
self._message_processor = asyncio.create_task(self._message_handler())
|
|
|
|
async def _message_handler(self):
|
|
while self._open_requests:
|
|
fname, mtype, rseqid = await self._iprot.read_message_begin()
|
|
log.debug("Reply for seqid %d: %s %s", rseqid, fname, mtype)
|
|
fut = self._open_requests.pop(rseqid, None)
|
|
if fut is None:
|
|
log.error("Received message with unknown seqid %d", rseqid)
|
|
|
|
try:
|
|
fut.set_result(await self._process_message(fname, mtype))
|
|
except Exception as e:
|
|
fut.set_exception(e)
|
|
|
|
async def _process_message(self, fname, mtype):
|
|
"""process a single message"""
|
|
if mtype == TMessageType.EXCEPTION:
|
|
x = TApplicationException()
|
|
await self._iprot.read_struct(x)
|
|
await self._iprot.read_message_end()
|
|
raise x
|
|
result = getattr(self._service, fname + "_result")()
|
|
await self._iprot.read_struct(result)
|
|
await self._iprot.read_message_end()
|
|
|
|
if hasattr(result, "success") and result.success is not None:
|
|
return result.success
|
|
|
|
# void api without throws
|
|
if len(result.thrift_spec) == 0:
|
|
return
|
|
|
|
# check throws
|
|
for k, v in result.__dict__.items():
|
|
if k != "success" and v:
|
|
raise v
|
|
|
|
# no throws & not void api
|
|
if hasattr(result, "success"):
|
|
raise TApplicationException(TApplicationException.MISSING_RESULT)
|
|
|
|
def _get_seqid(self) -> int:
|
|
seq_id = self._seqid
|
|
self._seqid += 1
|
|
return seq_id
|
|
|
|
def close(self):
|
|
if self._message_processor is not None and not self._message_processor.done():
|
|
self._message_processor.cancel()
|
|
self._message_processor = None
|
|
super().close()
|