mirror of
https://github.com/nexus-stc/hyperboria
synced 2025-01-18 06:27:34 +01:00
8472f27ec5
GitOrigin-RevId: ddf02e70d2827c048db49b687ebbcdcc67807ca6
143 lines
4.5 KiB
Python
143 lines
4.5 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import (
|
|
List,
|
|
Union,
|
|
)
|
|
|
|
import orjson as json
|
|
from aiokafka import AIOKafkaConsumer
|
|
from aiokafka.errors import (
|
|
CommitFailedError,
|
|
ConsumerStoppedError,
|
|
)
|
|
from aiokit import AioRootThing
|
|
from google.protobuf.json_format import ParseDict
|
|
from nexus.actions.exceptions import (
|
|
ConflictError,
|
|
InterruptProcessing,
|
|
)
|
|
from nexus.pipe.processors.base import Processor
|
|
|
|
|
|
class BaseConsumer(AioRootThing):
    """Kafka consumer that dispatches records to processors one at a time.

    Every record read from the subscribed topics goes through
    :meth:`preprocess`. A truthy result is offered to each processor whose
    ``filter`` accepts it, and offsets are then committed manually, because
    :meth:`create_consumer` disables auto-commit.

    Args:
        processors: processors that receive preprocessed records.
        topic_names: a single topic name or a list of topic names.
        bootstrap_servers: bootstrap servers passed to AIOKafkaConsumer.
        group_id: consumer group id passed to AIOKafkaConsumer.
    """

    def __init__(self, processors: List[Processor],
                 topic_names: Union[str, List[str]], bootstrap_servers: str, group_id: str):
        super().__init__()
        self.processors = processors
        # A bare topic name is accepted for convenience; always store a list.
        self.topic_names = [topic_names] if isinstance(topic_names, str) else topic_names
        self.bootstrap_servers = bootstrap_servers
        self.group_id = group_id
        # The AIOKafkaConsumer is only built inside start().
        self.consumer = None
        # Register processors in AioRootThing's ``starts`` list; presumably
        # this starts them together with this consumer (confirm in aiokit).
        self.starts.extend(self.processors)

    def create_consumer(self):
        """Build an AIOKafkaConsumer for the configured topics and group.

        Uses ``auto_offset_reset='earliest'`` and disables auto-commit, so
        offsets are committed explicitly by :meth:`start`.
        """
        options = dict(
            auto_offset_reset='earliest',
            loop=asyncio.get_event_loop(),
            bootstrap_servers=self.bootstrap_servers,
            group_id=self.group_id,
            enable_auto_commit=False,
        )
        return AIOKafkaConsumer(*self.topic_names, **options)

    def preprocess(self, msg):
        """Transform a raw Kafka record before dispatch (identity here).

        Subclasses override this. A falsy return value makes :meth:`start`
        skip the processors for that record.
        """
        return msg

    async def start(self):
        """Consume records until the consumer stops.

        For each record: preprocess it, pass a truthy result to every
        processor whose ``filter`` accepts it, then commit offsets.
        ``ConflictError`` and ``InterruptProcessing`` raised by a processor
        are logged to the ``statbox`` logger and consumption continues; any
        other processor exception is logged to the ``error`` logger and
        re-raised. A ``CommitFailedError`` is logged and ignored, and
        ``ConsumerStoppedError`` ends consumption quietly.
        """
        statbox = logging.getLogger('statbox')
        statbox.info({
            'action': 'started',
            'group_id': self.group_id,
            'topic_names': self.topic_names,
        })
        self.consumer = self.create_consumer()
        await self.consumer.start()
        try:
            async for record in self.consumer:
                payload = self.preprocess(record)
                if payload:
                    for processor in self.processors:
                        if processor.filter(payload):
                            try:
                                await processor.process(payload)
                            except (ConflictError, InterruptProcessing) as exc:
                                # Logged at info level; does not abort consumption.
                                statbox.info(exc)
                            except Exception as exc:
                                logging.getLogger('error').error(exc)
                                raise
                # Manual commit: auto-commit is disabled in create_consumer().
                try:
                    await self.consumer.commit()
                except CommitFailedError as exc:
                    logging.getLogger('error').error(exc)
        except ConsumerStoppedError:
            # Presumably raised once stop() has closed the consumer.
            pass

    async def stop(self):
        """Stop the underlying Kafka consumer if start() has created one."""
        if self.consumer:
            await self.consumer.stop()
|
|
|
|
|
|
class BasePbConsumer(BaseConsumer):
    """Consumer whose record values are binary-serialized protobuf messages.

    Subclasses must set ``pb_class`` to the message class to decode into.
    """

    # Message class instantiated per record; left as None here, so it must be
    # overridden before preprocess() is called.
    pb_class = None

    def preprocess(self, msg) -> pb_class:
        """Decode ``msg.value`` into a new ``pb_class`` instance and return it."""
        decoded = self.pb_class()
        decoded.ParseFromString(msg.value)
        return decoded
|
|
|
|
|
|
class BaseJsonConsumer(BaseConsumer):
    """Consumer whose record values are JSON documents mapped onto protobuf.

    Subclasses must set ``pb_class`` to the message class to populate.
    """

    # Message class instantiated per record; left as None here, so it must be
    # overridden before preprocess() is called.
    pb_class = None

    def preprocess(self, msg) -> pb_class:
        """Parse ``msg.value`` as JSON and merge it into a new ``pb_class``.

        Fields unknown to the message class are ignored. Returns the
        populated message (ParseDict returns the message it was given).
        """
        target = self.pb_class()
        return ParseDict(json.loads(msg.value), target, ignore_unknown_fields=True)
|
|
|
|
|
|
class BaseBulkConsumer(BaseConsumer):
    """Consumer that dispatches records to processors in batches.

    Each round polls Kafka for up to ``bulk_size`` records, waiting at most
    ``timeout`` seconds, and preprocesses them. Every processor then receives
    the subset its ``filter`` accepts via ``process_bulk``, and offsets are
    committed manually.
    """

    # Upper bound on records fetched per poll (``max_records`` of getmany).
    bulk_size = 20
    # Poll timeout in seconds; converted to milliseconds for getmany.
    timeout = 1

    async def start(self):
        """Poll and process batches while ``self.started`` is truthy.

        ``InterruptProcessing`` raised by a processor is logged to the
        ``statbox`` logger and the next processor still runs. Any other
        exception is logged to the ``error`` logger and re-raised. A
        ``CommitFailedError`` is logged and ignored. ``ConsumerStoppedError``
        from the poll ends the loop.
        """
        logging.getLogger('statbox').info({
            'action': 'started',
            'group_id': self.group_id,
            'topic_names': self.topic_names,
        })
        self.consumer = self.create_consumer()
        await self.consumer.start()
        while self.started:
            try:
                result = await self.consumer.getmany(timeout_ms=self.timeout * 1000, max_records=self.bulk_size)
            except ConsumerStoppedError:
                break
            # getmany returns records grouped per partition; flatten them into
            # one batch, dropping records whose preprocessed form is falsy.
            collector = []
            for messages in result.values():
                for message in messages:
                    preprocessed_msg = self.preprocess(message)
                    if preprocessed_msg:
                        collector.append(preprocessed_msg)
            # Note: processors are invoked even when the batch is empty.
            for processor in self.processors:
                try:
                    # Materialize the batch as a list. A lazy ``filter`` object
                    # is single-pass and always truthy, so a processor could not
                    # re-iterate it, take its len(), or detect an empty batch.
                    # Filtering stays inside the try so predicate errors are
                    # logged like processing errors.
                    filtered = [msg for msg in collector if processor.filter(msg)]
                    await processor.process_bulk(filtered)
                except InterruptProcessing as e:
                    logging.getLogger('statbox').info(e)
                except Exception as e:
                    logging.getLogger('error').error(e)
                    raise
            # Manual commit once per batch: auto-commit is disabled in create_consumer().
            try:
                await self.consumer.commit()
            except CommitFailedError as e:
                logging.getLogger('error').error(e)
|