- feat(nexus): Slightly refactor bot

- feat(nexus): Add vote to logging
1 internal commit(s)

GitOrigin-RevId: 8686e431de9a7af6c6763fba2b50d62d6275667b
This commit is contained in:
the-superpirate 2021-04-13 12:03:44 +03:00
parent b4a624fa12
commit 422959914c
6 changed files with 177 additions and 97 deletions

View File

@ -1,6 +1,21 @@
import logging
from izihawa_utils.exceptions import BaseError from izihawa_utils.exceptions import BaseError
class BannedUserError(BaseError):
level = logging.WARNING
code = 'banned_user_error'
def __init__(self, ban_timeout: int):
self.ban_timeout = ban_timeout
class MessageHasBeenDeletedError(BaseError):
level = logging.WARNING
code = 'message_has_been_deleted_error'
class UnknownFileFormatError(BaseError): class UnknownFileFormatError(BaseError):
code = 'unknown_file_format_error' code = 'unknown_file_format_error'

View File

@ -15,12 +15,17 @@ from .admin import BaseAdminHandler
class BanHandler(BaseAdminHandler): class BanHandler(BaseAdminHandler):
filter = events.NewMessage(incoming=True, pattern='^/ban ([0-9]+) ([A-Za-z0-9]+)\\s?(.*)?$') filter = events.NewMessage(incoming=True, pattern='^/ban ([0-9]+) ([A-Za-z0-9]+)\\s?(.*)?$')
async def handler(self, event: events.ChatAction, request_context: RequestContext): def parse_pattern(self, event: events.ChatAction):
chat_id = int(event.pattern_match.group(1)) chat_id = int(event.pattern_match.group(1))
ban_duration = event.pattern_match.group(2) ban_duration = event.pattern_match.group(2)
ban_message = event.pattern_match.group(3) ban_message = event.pattern_match.group(3)
ban_end_date = datetime.utcnow() + timedelta(seconds=timeparse(ban_duration)) ban_end_date = datetime.utcnow() + timedelta(seconds=timeparse(ban_duration))
return chat_id, ban_duration, ban_message, ban_end_date
async def handler(self, event: events.ChatAction, request_context: RequestContext):
chat_id, ban_duration, ban_message, ban_end_date = self.parse_pattern(event)
try: try:
await self.application.idm_client.update_chat( await self.application.idm_client.update_chat(
chat_id=chat_id, chat_id=chat_id,

View File

@ -12,13 +12,18 @@ class DownloadHandler(BaseCallbackQueryHandler):
filter = events.CallbackQuery(pattern='^/dl([abcm])_([A-Za-z0-9]+)_([0-9]+)_([0-9]+)$') filter = events.CallbackQuery(pattern='^/dl([abcm])_([A-Za-z0-9]+)_([0-9]+)_([0-9]+)$')
is_group_handler = True is_group_handler = True
async def handler(self, event: events.ChatAction, request_context: RequestContext): def parse_pattern(self, event: events.ChatAction):
short_schema = event.pattern_match.group(1).decode() short_schema = event.pattern_match.group(1).decode()
schema = self.short_schema_to_schema(short_schema) schema = self.short_schema_to_schema(short_schema)
session_id = event.pattern_match.group(2).decode() session_id = event.pattern_match.group(2).decode()
document_id = int(event.pattern_match.group(3)) document_id = int(event.pattern_match.group(3))
position = int(event.pattern_match.group(4).decode()) position = int(event.pattern_match.group(4).decode())
return short_schema, schema, session_id, document_id, position
async def handler(self, event: events.ChatAction, request_context: RequestContext):
short_schema, schema, session_id, document_id, position = self.parse_pattern(event)
self.application.user_manager.last_widget[request_context.chat.chat_id] = None self.application.user_manager.last_widget[request_context.chat.chat_id] = None
request_context.add_default_fields(mode='download', session_id=session_id) request_context.add_default_fields(mode='download', session_id=session_id)

View File

@ -1,11 +1,14 @@
import asyncio import asyncio
import logging
import re import re
import time import time
from grpc import StatusCode from grpc import StatusCode
from grpc.experimental.aio import AioRpcError from grpc.experimental.aio import AioRpcError
from library.telegram.base import RequestContext from library.telegram.base import RequestContext
from nexus.bot.exceptions import (
BannedUserError,
MessageHasBeenDeletedError,
)
from nexus.bot.widgets.search_widget import SearchWidget from nexus.bot.widgets.search_widget import SearchWidget
from nexus.translations import t from nexus.translations import t
from nexus.views.telegram.common import close_button from nexus.views.telegram.common import close_button
@ -141,40 +144,41 @@ class SearchHandler(BaseSearchHandler):
should_reset_last_widget = False should_reset_last_widget = False
is_subscription_required_for_handler = True is_subscription_required_for_handler = True
async def ban_handler(self, event: events.ChatAction, request_context: RequestContext, ban_timeout: float): def check_search_ban_timeout(self, chat_id: int):
logging.getLogger('statbox').info({ ban_timeout = self.application.user_manager.check_search_ban_timeout(user_id=chat_id)
'bot_name': self.application.config['telegram']['bot_name'], if ban_timeout:
'action': 'user_flood_ban', raise BannedUserError(ban_timeout=ban_timeout)
'mode': 'search', self.application.user_manager.add_search_time(user_id=chat_id, search_time=time.time())
'ban_timeout_seconds': ban_timeout,
'chat_id': request_context.chat.chat_id, def parse_pattern(self, event: events.ChatAction):
}) search_prefix = event.pattern_match.group(1)
ban_reason = t( query = event.pattern_match.group(2)
'BAN_MESSAGE_TOO_MANY_REQUESTS', is_group_mode = event.is_group or event.is_channel
language=request_context.chat.language
) return search_prefix, query, is_group_mode
async def handler(self, event: events.ChatAction, request_context: RequestContext):
try:
self.check_search_ban_timeout(chat_id=request_context.chat.chat_id)
except BannedUserError as e:
request_context.error_log(e)
return await event.reply(t( return await event.reply(t(
'BANNED_FOR_SECONDS', 'BANNED_FOR_SECONDS',
language=request_context.chat.language language=request_context.chat.language
).format( ).format(
seconds=str(ban_timeout), seconds=e.ban_timeout,
reason=ban_reason, reason=t(
'BAN_MESSAGE_TOO_MANY_REQUESTS',
language=request_context.chat.language
),
)) ))
search_prefix, query, is_group_mode = self.parse_pattern(event)
async def handler(self, event: events.ChatAction, request_context: RequestContext):
ban_timeout = self.application.user_manager.check_search_ban_timeout(user_id=request_context.chat.chat_id)
if ban_timeout:
return await self.ban_handler(event, request_context, ban_timeout)
self.application.user_manager.add_search_time(user_id=request_context.chat.chat_id, search_time=time.time())
search_prefix = event.pattern_match.group(1)
query = event.pattern_match.group(2)
is_group_mode = event.is_group or event.is_channel
if is_group_mode and not search_prefix: if is_group_mode and not search_prefix:
return return
if not is_group_mode and search_prefix: if not is_group_mode and search_prefix:
query = event.raw_text query = event.raw_text
prefetch_message = await event.reply( prefetch_message = await event.reply(
t("SEARCHING", language=request_context.chat.language), t("SEARCHING", language=request_context.chat.language),
) )
@ -199,35 +203,43 @@ class SearchEditHandler(BaseSearchHandler):
is_group_handler = True is_group_handler = True
should_reset_last_widget = False should_reset_last_widget = False
async def handler(self, event: events.ChatAction, request_context: RequestContext): def parse_pattern(self, event: events.ChatAction):
request_context.add_default_fields(mode='search_edit')
search_prefix = event.pattern_match.group(1) search_prefix = event.pattern_match.group(1)
query = event.pattern_match.group(2) query = event.pattern_match.group(2)
is_group_mode = event.is_group or event.is_channel is_group_mode = event.is_group or event.is_channel
return search_prefix, query, is_group_mode
async def get_last_messages_in_chat(self, event: events.ChatAction):
return await self.application.telegram_client(functions.messages.GetMessagesRequest(
id=list(range(event.id + 1, event.id + 10)))
)
async def handler(self, event: events.ChatAction, request_context: RequestContext):
search_prefix, query, is_group_mode = self.parse_pattern(event)
request_context.add_default_fields(mode='search_edit')
if is_group_mode and not search_prefix: if is_group_mode and not search_prefix:
return return
if not is_group_mode and search_prefix: if not is_group_mode and search_prefix:
query = event.raw_text query = event.raw_text
result = await self.application.telegram_client(functions.messages.GetMessagesRequest(
id=list(range(event.id + 1, event.id + 10))) last_messages = await self.get_last_messages_in_chat(event)
) try:
if not result: if not last_messages:
request_context.statbox(action='failed') raise MessageHasBeenDeletedError()
return await event.reply( for next_message in last_messages.messages:
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
for next_message in result.messages:
if next_message.is_reply and event.id == next_message.reply_to_msg_id: if next_message.is_reply and event.id == next_message.reply_to_msg_id:
request_context.statbox(action='resolved') request_context.statbox(action='resolved')
await self.do_search( return await self.do_search(
event, event,
request_context, request_context,
prefetch_message=next_message, prefetch_message=next_message,
query=query, query=query,
is_group_mode=is_group_mode, is_group_mode=is_group_mode,
) )
return raise MessageHasBeenDeletedError()
request_context.statbox(action='failed') except MessageHasBeenDeletedError as e:
request_context.error_log(e)
return await event.reply( return await event.reply(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language), t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
) )
@ -243,19 +255,17 @@ class SearchPagingHandler(BaseCallbackQueryHandler):
page = int(event.pattern_match.group(3).decode()) page = int(event.pattern_match.group(3).decode())
request_context.add_default_fields(mode='search_paging', session_id=session_id) request_context.add_default_fields(mode='search_paging', session_id=session_id)
start_time = time.time()
message = await event.get_message() message = await event.get_message()
if not message: if not message:
return await event.answer() return await event.answer()
reply_message = await message.get_reply_message()
if not reply_message:
return await event.respond(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
start_time = time.time() reply_message = await message.get_reply_message()
query = reply_message.raw_text
try: try:
if not reply_message:
raise MessageHasBeenDeletedError()
query = reply_message.raw_text
search_widget = await SearchWidget.create( search_widget = await SearchWidget.create(
application=self.application, application=self.application,
chat=request_context.chat, chat=request_context.chat,
@ -265,6 +275,10 @@ class SearchPagingHandler(BaseCallbackQueryHandler):
query=query, query=query,
page=page, page=page,
) )
except MessageHasBeenDeletedError:
return await event.respond(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
except AioRpcError as e: except AioRpcError as e:
if e.code() == StatusCode.INVALID_ARGUMENT or e.code() == StatusCode.CANCELLED: if e.code() == StatusCode.INVALID_ARGUMENT or e.code() == StatusCode.CANCELLED:
request_context.error_log(e) request_context.error_log(e)

View File

@ -2,6 +2,7 @@ import asyncio
import re import re
from library.telegram.base import RequestContext from library.telegram.base import RequestContext
from nexus.bot.exceptions import MessageHasBeenDeletedError
from nexus.translations import t from nexus.translations import t
from telethon import ( from telethon import (
events, events,
@ -17,7 +18,7 @@ class ViewHandler(BaseHandler):
'([0-9]+)') '([0-9]+)')
should_reset_last_widget = False should_reset_last_widget = False
async def handler(self, event: events.ChatAction, request_context: RequestContext): def parse_pattern(self, event: events.ChatAction):
short_schema = event.pattern_match.group(1) short_schema = event.pattern_match.group(1)
parent_view_type = event.pattern_match.group(2) or 's' parent_view_type = event.pattern_match.group(2) or 's'
schema = self.short_schema_to_schema(short_schema) schema = self.short_schema_to_schema(short_schema)
@ -28,12 +29,10 @@ class ViewHandler(BaseHandler):
page = int(position / self.application.config['application']['page_size']) page = int(position / self.application.config['application']['page_size'])
request_context.add_default_fields(mode='view', session_id=session_id) return parent_view_type, schema, session_id, old_message_id, document_id, position, page
request_context.statbox(action='view', document_id=document_id, position=position, schema=schema)
found_old_widget = old_message_id == self.application.user_manager.last_widget.get(request_context.chat.chat_id)
try: async def process_widgeting(self, has_found_old_widget, old_message_id, request_context: RequestContext):
if found_old_widget: if has_found_old_widget:
message_id = old_message_id message_id = old_message_id
link_preview = None link_preview = None
else: else:
@ -48,15 +47,16 @@ class ViewHandler(BaseHandler):
self.application.user_manager.last_widget[request_context.chat.chat_id] = prefetch_message.id self.application.user_manager.last_widget[request_context.chat.chat_id] = prefetch_message.id
message_id = prefetch_message.id message_id = prefetch_message.id
link_preview = True link_preview = True
return message_id, link_preview
document_view = await self.resolve_document( async def compose_back_command(
schema, self,
document_id, parent_view_type,
position,
session_id, session_id,
request_context, old_message_id,
) message_id,
page,
):
back_command = None back_command = None
if parent_view_type == 's': if parent_view_type == 's':
back_command = f'/search_{session_id}_{message_id}_{page}' back_command = f'/search_{session_id}_{message_id}_{page}'
@ -65,13 +65,48 @@ class ViewHandler(BaseHandler):
functions.messages.GetMessagesRequest(id=[old_message_id]) functions.messages.GetMessagesRequest(id=[old_message_id])
)).messages )).messages
if not messages: if not messages:
return await event.respond( raise MessageHasBeenDeletedError()
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
message = messages[0] message = messages[0]
referencing_to = re.search(r'Linked to: ([0-9]+)', message.raw_text).group(1) referencing_to = re.search(r'Linked to: ([0-9]+)', message.raw_text).group(1)
back_command = f'/rp_{session_id}_{message_id}_{referencing_to}_{page}' back_command = f'/rp_{session_id}_{message_id}_{referencing_to}_{page}'
return back_command
async def handler(self, event: events.ChatAction, request_context: RequestContext):
parent_view_type, schema, session_id, old_message_id, document_id, position, page = self.parse_pattern(event)
request_context.add_default_fields(mode='view', session_id=session_id)
request_context.statbox(action='view', document_id=document_id, position=position, schema=schema)
has_found_old_widget = old_message_id == self.application.user_manager.last_widget.get(request_context.chat.chat_id)
try:
message_id, link_preview = await self.process_widgeting(
has_found_old_widget=has_found_old_widget,
old_message_id=old_message_id,
request_context=request_context
)
document_view = await self.resolve_document(
schema,
document_id,
position,
session_id,
request_context,
)
try:
back_command = await self.compose_back_command(
parent_view_type=parent_view_type,
session_id=session_id,
old_message_id=old_message_id,
message_id=message_id,
page=page,
)
except MessageHasBeenDeletedError:
return await event.respond(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
view, buttons = document_view.get_view( view, buttons = document_view.get_view(
language=request_context.chat.language, language=request_context.chat.language,
session_id=session_id, session_id=session_id,
@ -89,7 +124,7 @@ class ViewHandler(BaseHandler):
), ),
event.delete(), event.delete(),
] ]
if not found_old_widget: if not has_found_old_widget:
actions.append( actions.append(
self.application.telegram_client.delete_messages( self.application.telegram_client.delete_messages(
request_context.chat.chat_id, request_context.chat.chat_id,

View File

@ -15,14 +15,26 @@ from .base import BaseCallbackQueryHandler
class VoteHandler(BaseCallbackQueryHandler): class VoteHandler(BaseCallbackQueryHandler):
filter = events.CallbackQuery(pattern='^/vote([ab])?_([A-Za-z0-9]+)_([0-9]+)_([bo])$') filter = events.CallbackQuery(pattern='^/vote([ab])?_([A-Za-z0-9]+)_([0-9]+)_([bo])$')
async def handler(self, event: events.ChatAction, request_context: RequestContext): def parse_pattern(self, event: events.ChatAction):
short_schema = event.pattern_match.group(1) short_schema = event.pattern_match.group(1)
schema = self.short_schema_to_schema(short_schema) if short_schema else None schema = self.short_schema_to_schema(short_schema.decode()) if short_schema else None
session_id = event.pattern_match.group(2).decode() session_id = event.pattern_match.group(2).decode()
document_id = int(event.pattern_match.group(3).decode()) document_id = int(event.pattern_match.group(3).decode())
vote = event.pattern_match.group(4).decode() vote = event.pattern_match.group(4).decode()
vote_value = {'b': -1, 'o': 1}[vote] vote_value = {'b': -1, 'o': 1}[vote]
return schema, session_id, document_id, vote, vote_value
async def handler(self, event: events.ChatAction, request_context: RequestContext):
schema, session_id, document_id, vote, vote_value = self.parse_pattern(event)
request_context.add_default_fields(mode='vote', session_id=session_id) request_context.add_default_fields(mode='vote', session_id=session_id)
request_context.statbox(
action='vote',
document_id=document_id,
query=vote,
schema=schema,
)
document_operation_pb = DocumentOperationPb( document_operation_pb = DocumentOperationPb(
vote=VotePb( vote=VotePb(
@ -31,12 +43,6 @@ class VoteHandler(BaseCallbackQueryHandler):
voter_id=request_context.chat.chat_id, voter_id=request_context.chat.chat_id,
), ),
) )
request_context.statbox(
action='vote',
document_id=document_id,
schema=schema,
)
logging.getLogger('operation').info( logging.getLogger('operation').info(
msg=MessageToDict(document_operation_pb), msg=MessageToDict(document_operation_pb),
) )