- feat(nexus): Slightly refactor bot

- feat(nexus): Add vote to logging
1 internal commit(s)

GitOrigin-RevId: 8686e431de9a7af6c6763fba2b50d62d6275667b
This commit is contained in:
the-superpirate 2021-04-13 12:03:44 +03:00
parent b4a624fa12
commit 422959914c
6 changed files with 177 additions and 97 deletions

View File

@ -1,6 +1,21 @@
import logging
from izihawa_utils.exceptions import BaseError
class BannedUserError(BaseError):
level = logging.WARNING
code = 'banned_user_error'
def __init__(self, ban_timeout: int):
self.ban_timeout = ban_timeout
class MessageHasBeenDeletedError(BaseError):
level = logging.WARNING
code = 'message_has_been_deleted_error'
class UnknownFileFormatError(BaseError):
code = 'unknown_file_format_error'

View File

@ -15,12 +15,17 @@ from .admin import BaseAdminHandler
class BanHandler(BaseAdminHandler):
filter = events.NewMessage(incoming=True, pattern='^/ban ([0-9]+) ([A-Za-z0-9]+)\\s?(.*)?$')
async def handler(self, event: events.ChatAction, request_context: RequestContext):
def parse_pattern(self, event: events.ChatAction):
chat_id = int(event.pattern_match.group(1))
ban_duration = event.pattern_match.group(2)
ban_message = event.pattern_match.group(3)
ban_end_date = datetime.utcnow() + timedelta(seconds=timeparse(ban_duration))
return chat_id, ban_duration, ban_message, ban_end_date
async def handler(self, event: events.ChatAction, request_context: RequestContext):
chat_id, ban_duration, ban_message, ban_end_date = self.parse_pattern(event)
try:
await self.application.idm_client.update_chat(
chat_id=chat_id,

View File

@ -12,13 +12,18 @@ class DownloadHandler(BaseCallbackQueryHandler):
filter = events.CallbackQuery(pattern='^/dl([abcm])_([A-Za-z0-9]+)_([0-9]+)_([0-9]+)$')
is_group_handler = True
async def handler(self, event: events.ChatAction, request_context: RequestContext):
def parse_pattern(self, event: events.ChatAction):
short_schema = event.pattern_match.group(1).decode()
schema = self.short_schema_to_schema(short_schema)
session_id = event.pattern_match.group(2).decode()
document_id = int(event.pattern_match.group(3))
position = int(event.pattern_match.group(4).decode())
return short_schema, schema, session_id, document_id, position
async def handler(self, event: events.ChatAction, request_context: RequestContext):
short_schema, schema, session_id, document_id, position = self.parse_pattern(event)
self.application.user_manager.last_widget[request_context.chat.chat_id] = None
request_context.add_default_fields(mode='download', session_id=session_id)

View File

@ -1,11 +1,14 @@
import asyncio
import logging
import re
import time
from grpc import StatusCode
from grpc.experimental.aio import AioRpcError
from library.telegram.base import RequestContext
from nexus.bot.exceptions import (
BannedUserError,
MessageHasBeenDeletedError,
)
from nexus.bot.widgets.search_widget import SearchWidget
from nexus.translations import t
from nexus.views.telegram.common import close_button
@ -141,40 +144,41 @@ class SearchHandler(BaseSearchHandler):
should_reset_last_widget = False
is_subscription_required_for_handler = True
async def ban_handler(self, event: events.ChatAction, request_context: RequestContext, ban_timeout: float):
logging.getLogger('statbox').info({
'bot_name': self.application.config['telegram']['bot_name'],
'action': 'user_flood_ban',
'mode': 'search',
'ban_timeout_seconds': ban_timeout,
'chat_id': request_context.chat.chat_id,
})
ban_reason = t(
'BAN_MESSAGE_TOO_MANY_REQUESTS',
language=request_context.chat.language
)
def check_search_ban_timeout(self, chat_id: int):
ban_timeout = self.application.user_manager.check_search_ban_timeout(user_id=chat_id)
if ban_timeout:
raise BannedUserError(ban_timeout=ban_timeout)
self.application.user_manager.add_search_time(user_id=chat_id, search_time=time.time())
def parse_pattern(self, event: events.ChatAction):
search_prefix = event.pattern_match.group(1)
query = event.pattern_match.group(2)
is_group_mode = event.is_group or event.is_channel
return search_prefix, query, is_group_mode
async def handler(self, event: events.ChatAction, request_context: RequestContext):
try:
self.check_search_ban_timeout(chat_id=request_context.chat.chat_id)
except BannedUserError as e:
request_context.error_log(e)
return await event.reply(t(
'BANNED_FOR_SECONDS',
language=request_context.chat.language
).format(
seconds=str(ban_timeout),
reason=ban_reason,
seconds=e.ban_timeout,
reason=t(
'BAN_MESSAGE_TOO_MANY_REQUESTS',
language=request_context.chat.language
),
))
async def handler(self, event: events.ChatAction, request_context: RequestContext):
ban_timeout = self.application.user_manager.check_search_ban_timeout(user_id=request_context.chat.chat_id)
if ban_timeout:
return await self.ban_handler(event, request_context, ban_timeout)
self.application.user_manager.add_search_time(user_id=request_context.chat.chat_id, search_time=time.time())
search_prefix = event.pattern_match.group(1)
query = event.pattern_match.group(2)
is_group_mode = event.is_group or event.is_channel
search_prefix, query, is_group_mode = self.parse_pattern(event)
if is_group_mode and not search_prefix:
return
if not is_group_mode and search_prefix:
query = event.raw_text
prefetch_message = await event.reply(
t("SEARCHING", language=request_context.chat.language),
)
@ -199,35 +203,43 @@ class SearchEditHandler(BaseSearchHandler):
is_group_handler = True
should_reset_last_widget = False
async def handler(self, event: events.ChatAction, request_context: RequestContext):
request_context.add_default_fields(mode='search_edit')
def parse_pattern(self, event: events.ChatAction):
search_prefix = event.pattern_match.group(1)
query = event.pattern_match.group(2)
is_group_mode = event.is_group or event.is_channel
return search_prefix, query, is_group_mode
async def get_last_messages_in_chat(self, event: events.ChatAction):
return await self.application.telegram_client(functions.messages.GetMessagesRequest(
id=list(range(event.id + 1, event.id + 10)))
)
async def handler(self, event: events.ChatAction, request_context: RequestContext):
search_prefix, query, is_group_mode = self.parse_pattern(event)
request_context.add_default_fields(mode='search_edit')
if is_group_mode and not search_prefix:
return
if not is_group_mode and search_prefix:
query = event.raw_text
result = await self.application.telegram_client(functions.messages.GetMessagesRequest(
id=list(range(event.id + 1, event.id + 10)))
)
if not result:
request_context.statbox(action='failed')
return await event.reply(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
for next_message in result.messages:
last_messages = await self.get_last_messages_in_chat(event)
try:
if not last_messages:
raise MessageHasBeenDeletedError()
for next_message in last_messages.messages:
if next_message.is_reply and event.id == next_message.reply_to_msg_id:
request_context.statbox(action='resolved')
await self.do_search(
return await self.do_search(
event,
request_context,
prefetch_message=next_message,
query=query,
is_group_mode=is_group_mode,
)
return
request_context.statbox(action='failed')
raise MessageHasBeenDeletedError()
except MessageHasBeenDeletedError as e:
request_context.error_log(e)
return await event.reply(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
@ -243,19 +255,17 @@ class SearchPagingHandler(BaseCallbackQueryHandler):
page = int(event.pattern_match.group(3).decode())
request_context.add_default_fields(mode='search_paging', session_id=session_id)
start_time = time.time()
message = await event.get_message()
if not message:
return await event.answer()
reply_message = await message.get_reply_message()
if not reply_message:
return await event.respond(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
start_time = time.time()
query = reply_message.raw_text
reply_message = await message.get_reply_message()
try:
if not reply_message:
raise MessageHasBeenDeletedError()
query = reply_message.raw_text
search_widget = await SearchWidget.create(
application=self.application,
chat=request_context.chat,
@ -265,6 +275,10 @@ class SearchPagingHandler(BaseCallbackQueryHandler):
query=query,
page=page,
)
except MessageHasBeenDeletedError:
return await event.respond(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
except AioRpcError as e:
if e.code() == StatusCode.INVALID_ARGUMENT or e.code() == StatusCode.CANCELLED:
request_context.error_log(e)

View File

@ -2,6 +2,7 @@ import asyncio
import re
from library.telegram.base import RequestContext
from nexus.bot.exceptions import MessageHasBeenDeletedError
from nexus.translations import t
from telethon import (
events,
@ -17,7 +18,7 @@ class ViewHandler(BaseHandler):
'([0-9]+)')
should_reset_last_widget = False
async def handler(self, event: events.ChatAction, request_context: RequestContext):
def parse_pattern(self, event: events.ChatAction):
short_schema = event.pattern_match.group(1)
parent_view_type = event.pattern_match.group(2) or 's'
schema = self.short_schema_to_schema(short_schema)
@ -28,12 +29,10 @@ class ViewHandler(BaseHandler):
page = int(position / self.application.config['application']['page_size'])
request_context.add_default_fields(mode='view', session_id=session_id)
request_context.statbox(action='view', document_id=document_id, position=position, schema=schema)
found_old_widget = old_message_id == self.application.user_manager.last_widget.get(request_context.chat.chat_id)
return parent_view_type, schema, session_id, old_message_id, document_id, position, page
try:
if found_old_widget:
async def process_widgeting(self, has_found_old_widget, old_message_id, request_context: RequestContext):
if has_found_old_widget:
message_id = old_message_id
link_preview = None
else:
@ -48,15 +47,16 @@ class ViewHandler(BaseHandler):
self.application.user_manager.last_widget[request_context.chat.chat_id] = prefetch_message.id
message_id = prefetch_message.id
link_preview = True
return message_id, link_preview
document_view = await self.resolve_document(
schema,
document_id,
position,
async def compose_back_command(
self,
parent_view_type,
session_id,
request_context,
)
old_message_id,
message_id,
page,
):
back_command = None
if parent_view_type == 's':
back_command = f'/search_{session_id}_{message_id}_{page}'
@ -65,13 +65,48 @@ class ViewHandler(BaseHandler):
functions.messages.GetMessagesRequest(id=[old_message_id])
)).messages
if not messages:
return await event.respond(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
raise MessageHasBeenDeletedError()
message = messages[0]
referencing_to = re.search(r'Linked to: ([0-9]+)', message.raw_text).group(1)
back_command = f'/rp_{session_id}_{message_id}_{referencing_to}_{page}'
return back_command
async def handler(self, event: events.ChatAction, request_context: RequestContext):
parent_view_type, schema, session_id, old_message_id, document_id, position, page = self.parse_pattern(event)
request_context.add_default_fields(mode='view', session_id=session_id)
request_context.statbox(action='view', document_id=document_id, position=position, schema=schema)
has_found_old_widget = old_message_id == self.application.user_manager.last_widget.get(request_context.chat.chat_id)
try:
message_id, link_preview = await self.process_widgeting(
has_found_old_widget=has_found_old_widget,
old_message_id=old_message_id,
request_context=request_context
)
document_view = await self.resolve_document(
schema,
document_id,
position,
session_id,
request_context,
)
try:
back_command = await self.compose_back_command(
parent_view_type=parent_view_type,
session_id=session_id,
old_message_id=old_message_id,
message_id=message_id,
page=page,
)
except MessageHasBeenDeletedError:
return await event.respond(
t('REPLY_MESSAGE_HAS_BEEN_DELETED', language=request_context.chat.language),
)
view, buttons = document_view.get_view(
language=request_context.chat.language,
session_id=session_id,
@ -89,7 +124,7 @@ class ViewHandler(BaseHandler):
),
event.delete(),
]
if not found_old_widget:
if not has_found_old_widget:
actions.append(
self.application.telegram_client.delete_messages(
request_context.chat.chat_id,

View File

@ -15,14 +15,26 @@ from .base import BaseCallbackQueryHandler
class VoteHandler(BaseCallbackQueryHandler):
filter = events.CallbackQuery(pattern='^/vote([ab])?_([A-Za-z0-9]+)_([0-9]+)_([bo])$')
async def handler(self, event: events.ChatAction, request_context: RequestContext):
def parse_pattern(self, event: events.ChatAction):
short_schema = event.pattern_match.group(1)
schema = self.short_schema_to_schema(short_schema) if short_schema else None
schema = self.short_schema_to_schema(short_schema.decode()) if short_schema else None
session_id = event.pattern_match.group(2).decode()
document_id = int(event.pattern_match.group(3).decode())
vote = event.pattern_match.group(4).decode()
vote_value = {'b': -1, 'o': 1}[vote]
return schema, session_id, document_id, vote, vote_value
async def handler(self, event: events.ChatAction, request_context: RequestContext):
schema, session_id, document_id, vote, vote_value = self.parse_pattern(event)
request_context.add_default_fields(mode='vote', session_id=session_id)
request_context.statbox(
action='vote',
document_id=document_id,
query=vote,
schema=schema,
)
document_operation_pb = DocumentOperationPb(
vote=VotePb(
@ -31,12 +43,6 @@ class VoteHandler(BaseCallbackQueryHandler):
voter_id=request_context.chat.chat_id,
),
)
request_context.statbox(
action='vote',
document_id=document_id,
schema=schema,
)
logging.getLogger('operation').info(
msg=MessageToDict(document_operation_pb),
)