153 lines
4.8 KiB
Python
Raw Normal View History

from __future__ import annotations
import asyncio
import logging
from typing import (
List,
Union,
)
import orjson as json
from aiokafka import AIOKafkaConsumer
from aiokafka.errors import (
CommitFailedError,
ConsumerStoppedError,
)
from aiokit import AioRootThing
from google.protobuf.json_format import ParseDict
from nexus.actions.exceptions import (
ConflictError,
InterruptProcessing,
)
from nexus.pipe.processors.base import Processor
class BaseConsumer(AioRootThing):
    """Kafka consumer that pushes every received message through a list of processors.

    Messages are read one by one from ``topic_names``. Each message goes through
    :meth:`preprocess`; a truthy result is offered to every processor whose
    ``filter`` accepts it. Auto-commit is disabled, so offsets are committed
    explicitly after each message has been handled.
    """

    def __init__(self, processors: List[Processor],
                 topic_names: Union[str, List[str]], bootstrap_servers: str, group_id: str):
        """Store the configuration; the Kafka consumer itself is created in ``start``.

        Args:
            processors: processors that receive the preprocessed messages.
            topic_names: a single topic name or a list of topic names.
            bootstrap_servers: passed through to ``AIOKafkaConsumer``.
            group_id: consumer group id, passed through to ``AIOKafkaConsumer``.
        """
        super().__init__()
        self.processors = processors
        # Normalise a single topic name to a list so it can be star-unpacked later.
        if isinstance(topic_names, str):
            topic_names = [topic_names]
        self.topic_names = topic_names
        self.bootstrap_servers = bootstrap_servers
        self.group_id = group_id
        # Created lazily in start(); stop() treats None as "never started".
        self.consumer = None
        # `starts` is inherited from AioRootThing.
        # NOTE(review): presumably this makes the processors start together
        # with the consumer - confirm against aiokit.
        self.starts.extend(self.processors)

    def create_consumer(self):
        """Build the ``AIOKafkaConsumer`` subscribed to ``self.topic_names``.

        Offsets start from the earliest record when the group has no committed
        position, and auto-commit is turned off because ``start`` commits
        manually.
        """
        return AIOKafkaConsumer(
            *self.topic_names,
            auto_offset_reset='earliest',
            loop=asyncio.get_event_loop(),
            bootstrap_servers=self.bootstrap_servers,
            group_id=self.group_id,
            enable_auto_commit=False,
        )

    def preprocess(self, msg):
        """Hook for subclasses: turn a raw Kafka record into the object processors expect.

        The base implementation returns ``msg`` unchanged. A falsy return value
        makes ``start`` skip the processors for that message.
        """
        return msg

    async def start(self):
        """Start the Kafka consumer and run the consume/process/commit loop.

        The loop ends silently when the consumer is stopped
        (``ConsumerStoppedError``).

        Raises:
            Exception: any error raised by a processor other than
                ``ConflictError`` / ``InterruptProcessing`` is logged with its
                traceback and re-raised; the current message is not committed.
        """
        statbox_logger = logging.getLogger('statbox')
        error_logger = logging.getLogger('error')

        statbox_logger.info({
            'action': 'starting',
            'group_id': self.group_id,
            'topic_names': self.topic_names,
        })
        self.consumer = self.create_consumer()
        await self.consumer.start()
        statbox_logger.info({
            'action': 'started',
            'group_id': self.group_id,
            'topic_names': self.topic_names,
        })

        try:
            async for msg in self.consumer:
                preprocessed_msg = self.preprocess(msg)
                if preprocessed_msg:
                    for processor in self.processors:
                        if not processor.filter(preprocessed_msg):
                            continue
                        try:
                            await processor.process(preprocessed_msg)
                        except (ConflictError, InterruptProcessing) as e:
                            # Expected outcomes: note them and move on to the next processor.
                            statbox_logger.info(e)
                        except Exception as e:
                            # Log with traceback: str(e) alone loses the exception
                            # type and the place where it was raised.
                            error_logger.exception(e)
                            raise
                # Commit after every message, including ones skipped by preprocess.
                try:
                    await self.consumer.commit()
                except CommitFailedError as e:
                    # A failed commit is logged and consumption continues.
                    error_logger.error(e)
        except ConsumerStoppedError:
            # The consumer was stopped while iterating; leave the loop quietly.
            pass

    async def stop(self):
        """Stop the Kafka consumer if ``start`` has created one."""
        if not self.consumer:
            return
        await self.consumer.stop()
class BasePbConsumer(BaseConsumer):
    """Consumer whose records carry a binary-serialized protobuf message.

    ``pb_class`` is ``None`` here and is called in ``preprocess``, so a concrete
    message class has to be assigned to it before any record is consumed.
    """

    # Message class instantiated for every incoming record.
    pb_class = None

    def preprocess(self, msg) -> pb_class:
        """Return a fresh ``pb_class`` instance parsed from ``msg.value``."""
        decoded = self.pb_class()
        decoded.ParseFromString(msg.value)
        return decoded
class BaseJsonConsumer(BaseConsumer):
    """Consumer whose records carry a JSON document that is mapped onto a protobuf.

    ``pb_class`` is ``None`` here and is called in ``preprocess``, so a concrete
    message class has to be assigned to it before any record is consumed.
    """

    # Message class instantiated for every incoming record.
    pb_class = None

    def preprocess(self, msg) -> pb_class:
        """Return a fresh ``pb_class`` instance filled from the JSON in ``msg.value``.

        Keys in the JSON document that the message does not define are ignored.
        """
        # Instantiate first, then decode, keeping the original evaluation order.
        target = self.pb_class()
        payload = json.loads(msg.value)
        # ParseDict fills `target` in place and hands the same object back.
        return ParseDict(payload, target, ignore_unknown_fields=True)
class BaseBulkConsumer(BaseConsumer):
    """Consumer that polls Kafka in batches and hands each batch to the processors.

    Every poll requests up to ``bulk_size`` records and waits at most ``timeout``
    seconds. The preprocessed records of one poll are offered to each processor
    via ``process_bulk`` and the offsets are committed afterwards.
    """

    # Maximum number of records requested per poll (``max_records``).
    bulk_size = 20
    # Poll timeout in seconds; multiplied by 1000 for ``timeout_ms``.
    timeout = 1

    async def start(self):
        """Start the Kafka consumer and run the poll/process/commit loop.

        The loop runs while ``self.started`` is truthy and ends when the
        consumer is stopped (``ConsumerStoppedError``).

        Raises:
            Exception: any error raised by ``process_bulk`` other than
                ``InterruptProcessing`` is logged with its traceback and
                re-raised; the current batch is not committed.
        """
        statbox_logger = logging.getLogger('statbox')
        error_logger = logging.getLogger('error')

        statbox_logger.info({
            'action': 'starting',
            'group_id': self.group_id,
            'topic_names': self.topic_names,
        })
        self.consumer = self.create_consumer()
        await self.consumer.start()
        statbox_logger.info({
            'action': 'started',
            'group_id': self.group_id,
            'topic_names': self.topic_names,
        })

        # NOTE(review): `started` is not set in this file - presumably a flag
        # maintained by AioRootThing; confirm.
        while self.started:
            try:
                result = await self.consumer.getmany(
                    timeout_ms=self.timeout * 1000,
                    max_records=self.bulk_size,
                )
            except ConsumerStoppedError:
                break

            # Flatten the per-partition lists into one batch, dropping records
            # for which preprocess() returns a falsy value.
            collector = []
            for messages in result.values():
                for message in messages:
                    preprocessed_msg = self.preprocess(message)
                    if preprocessed_msg:
                        collector.append(preprocessed_msg)

            # process_bulk is invoked on every poll, even when the batch is empty.
            for processor in self.processors:
                # Lazy per-processor view of the batch; each processor gets its own iterator.
                filtered = filter(processor.filter, collector)
                try:
                    await processor.process_bulk(filtered)
                except InterruptProcessing as e:
                    # Expected outcome: note it and move on to the next processor.
                    statbox_logger.info(e)
                except Exception as e:
                    # Log with traceback: str(e) alone loses the exception
                    # type and the place where it was raised.
                    error_logger.exception(e)
                    raise

            try:
                await self.consumer.commit()
            except CommitFailedError as e:
                # A failed commit is logged and polling continues.
                error_logger.error(e)