mirror of
https://github.com/nexus-stc/hyperboria
synced 2024-12-24 10:35:49 +01:00
aa819fbc1a
1 internal commit(s) GitOrigin-RevId: e4e34d0a00e117151a1fde75897567cbfc732d5f
119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
import logging
|
|
from functools import wraps
|
|
|
|
import yaml
|
|
from aiokit import (
|
|
AioRootThing,
|
|
AioThing,
|
|
)
|
|
from google.protobuf.json_format import MessageToDict
|
|
from grpc import aio
|
|
from izihawa_utils.text import camel_to_snake
|
|
from library.logging import error_log
|
|
|
|
|
|
class AioGrpcServer(AioRootThing):
|
|
def __init__(self, address, port):
|
|
super().__init__()
|
|
self.address = address
|
|
self.port = port
|
|
self.server = aio.server()
|
|
self.server.add_insecure_port(f'{address}:{port}')
|
|
|
|
async def start(self):
|
|
logging.getLogger('debug').info({
|
|
'action': 'starting',
|
|
'address': self.address,
|
|
'mode': 'grpc',
|
|
'port': self.port,
|
|
})
|
|
await self.server.start()
|
|
await self.server.wait_for_termination()
|
|
|
|
async def stop(self):
|
|
logging.getLogger('debug').info({
|
|
'action': 'stopping',
|
|
'mode': 'grpc',
|
|
})
|
|
await self.server.stop(None)
|
|
|
|
def log_config(self, config):
|
|
logging.getLogger('debug').info(
|
|
'\n' + yaml.safe_dump(config.get_files()),
|
|
)
|
|
|
|
|
|
class BaseService(AioThing):
|
|
error_mapping = {}
|
|
|
|
def __init__(self, service_name):
|
|
super().__init__()
|
|
self.service_name = service_name
|
|
self.class_name = camel_to_snake(self.__class__.__name__)
|
|
|
|
def get_default_service_fields(self):
|
|
return {'service_name': self.service_name, 'view': self.class_name}
|
|
|
|
def statbox(self, **kwargs):
|
|
logging.getLogger('statbox').info(self.get_default_service_fields() | kwargs)
|
|
|
|
|
|
def aiogrpc_request_wrapper(log=True):
|
|
def _aiogrpc_request_wrapper(func):
|
|
@wraps(func)
|
|
async def wrapped(self, request, context):
|
|
metadata = dict(context.invocation_metadata())
|
|
try:
|
|
if log:
|
|
self.statbox(
|
|
action='enter',
|
|
mode=func.__name__,
|
|
request_id=metadata.get('request-id'),
|
|
)
|
|
r = await func(self, request, context, metadata)
|
|
if log:
|
|
self.statbox(
|
|
action='exit',
|
|
mode=func.__name__,
|
|
request_id=metadata.get('request-id'),
|
|
)
|
|
return r
|
|
except aio.AbortError:
|
|
raise
|
|
except Exception as e:
|
|
serialized_request = MessageToDict(request, preserving_proto_field_name=True)
|
|
error_log(e, request=serialized_request, request_id=metadata.get('request-id'))
|
|
if e.__class__ in self.error_mapping:
|
|
await context.abort(*self.error_mapping[e.__class__])
|
|
raise e
|
|
return wrapped
|
|
return _aiogrpc_request_wrapper
|
|
|
|
|
|
def aiogrpc_streaming_request_wrapper(func):
|
|
@wraps(func)
|
|
async def wrapped(self, request, context):
|
|
metadata = dict(context.invocation_metadata())
|
|
try:
|
|
self.statbox(
|
|
action='enter',
|
|
mode=func.__name__,
|
|
request_id=metadata.get('request-id'),
|
|
)
|
|
async for item in func(self, request, context, metadata):
|
|
yield item
|
|
self.statbox(
|
|
action='exit',
|
|
mode=func.__name__,
|
|
request_id=metadata.get('request-id'),
|
|
)
|
|
except aio.AbortError:
|
|
raise
|
|
except Exception as e:
|
|
serialized_request = MessageToDict(request, preserving_proto_field_name=True)
|
|
error_log(e, request=serialized_request, request_id=metadata.get('request-id'))
|
|
if e.__class__ in self.error_mapping:
|
|
await context.abort(*self.error_mapping[e.__class__])
|
|
raise e
|
|
return wrapped
|