mirror of
https://github.com/nexus-stc/hyperboria
synced 2025-01-07 09:15:58 +01:00
dd23846059
- [nexus] Switch bot - [bot] Added extra receivers functionality GitOrigin-RevId: 68fc32d3e79ff411758f54f435fe8680fc42dead
349 lines
13 KiB
Python
349 lines
13 KiB
Python
import logging
|
|
import time
|
|
from abc import ABC
|
|
from datetime import datetime
|
|
from typing import Union
|
|
|
|
from grpc import StatusCode
|
|
from grpc.experimental.aio import AioRpcError
|
|
from idm.api.proto.chat_manager_service_pb2 import Chat as ChatPb
|
|
from izihawa_utils.exceptions import BaseError
|
|
from izihawa_utils.random import random_string
|
|
from library.logging import error_log
|
|
from library.telegram.base import RequestContext
|
|
from library.telegram.utils import safe_execution
|
|
from nexus.bot.application import TelegramApplication
|
|
from nexus.bot.exceptions import UnknownIndexAliasError
|
|
from nexus.models.proto.typed_document_pb2 import \
|
|
TypedDocument as TypedDocumentPb
|
|
from nexus.translations import t
|
|
from nexus.views.telegram.common import close_button
|
|
from nexus.views.telegram.scimag import ScimagView
|
|
from nexus.views.telegram.scitech import ScitechView
|
|
from telethon import (
|
|
TelegramClient,
|
|
events,
|
|
)
|
|
from telethon.errors import (
|
|
QueryIdInvalidError,
|
|
UserNotParticipantError,
|
|
)
|
|
|
|
|
|
def get_username(event: events.ChatAction, chat):
    """Return the value used as the username of the chat behind ``event``.

    Group and channel chats are identified by their numeric ``event.chat_id``
    rendered as a string; private chats use the ``username`` attribute of
    ``chat`` (the object returned by ``event.get_chat()``).
    """
    is_multi_user_chat = event.is_group or event.is_channel
    return str(event.chat_id) if is_multi_user_chat else chat.username
|
|
|
|
|
|
def get_language(event: events.ChatAction, chat):
    """Return the language code to associate with the chat behind ``event``.

    Private chats use the ``lang_code`` attribute of ``chat``; group and
    channel chats always get ``'en'``.
    """
    if not (event.is_group or event.is_channel):
        return chat.lang_code
    return 'en'
|
|
|
|
|
|
def is_banned(chat: ChatPb) -> bool:
    """Return True while ``chat`` is under an active ban.

    ``chat.ban_until`` is treated as a POSIX timestamp in seconds and the ban
    is active while the current time is strictly before it.
    """
    # Use the POSIX clock directly: ``datetime.utcnow().timestamp()`` builds a
    # naive datetime that ``.timestamp()`` interprets as *local* time, which
    # skews the comparison by the host's UTC offset on non-UTC machines.
    return chat.ban_until is not None and time.time() < chat.ban_until
|
|
|
|
|
|
class ReadOnlyModeError(BaseError):
    """Signals that a write (e.g. creating a chat) was refused in read-only mode."""

    # Machine-readable error identifier
    code = 'read_only_mode_error'
    # Severity of this error; presumably consumed by BaseError / error_log when logging
    level = logging.WARNING
|
|
|
|
|
|
class BaseHandler(ABC):
    """Base class for the bot's Telethon event handlers.

    Subclasses implement :meth:`handler` and tune the class-level flags below.
    :meth:`register_for` attaches :meth:`_wrapped_handler` to a Telethon
    client; the wrapper runs the maintenance, read-only, access, subscription
    and ban checks before delegating to :meth:`handler` inside
    ``safe_execution``.
    """

    # Is handler working in the groups
    is_group_handler = False
    # Is subscription to the `config['telegram']['related_channel']` required to use this handler
    is_subscription_required_for_handler = False
    # Telethon filter
    filter = events.NewMessage(incoming=True)
    # Reset the chat's last widget (see `reset_last_widget`) before handling
    should_reset_last_widget = True
    # Raises StopPropagation in the end of handling. It means this handler would be the last one in chain
    stop_propagation = True
    # If set to True then read_only mode will disable handler
    writing_handler = False

    def __init__(self, application: TelegramApplication):
        self.application = application
        # Full index alias -> coroutine method that builds the document view
        self.index_alias_to_resolver = {
            'scimag': self.resolve_scimag,
            'scitech': self.resolve_scitech,
        }
        # One-letter index alias -> full index alias
        self.short_index_alias_to_index_alias_dict = {
            'a': 'scimag',
            'b': 'scitech',
        }

    def generate_session_id(self) -> str:
        """Return a random session id of the configured length."""
        return random_string(self.application.config['application']['session_id_length'])

    def short_index_alias_to_index_alias(self, short_index_alias: str) -> str:
        """Expand a one-letter index alias; raises ``KeyError`` for unknown aliases."""
        return self.short_index_alias_to_index_alias_dict[short_index_alias]

    async def get_typed_document_pb(
        self,
        index_alias: str,
        document_id: int,
        request_context: RequestContext,
        session_id: str,
        position: int,
    ) -> TypedDocumentPb:
        """Fetch one document from the meta API on behalf of the requesting chat."""
        return await self.application.meta_api_client.get(
            index_alias=index_alias,
            document_id=document_id,
            session_id=session_id,
            position=position,
            request_id=request_context.request_id,
            user_id=str(request_context.chat.chat_id),
        )

    async def resolve_scimag(
        self,
        document_id: int,
        position: int,
        request_context: RequestContext,
        session_id: str,
    ) -> ScimagView:
        """Build a :class:`ScimagView` for a document of the ``scimag`` index."""
        typed_document_pb = await self.get_typed_document_pb(
            index_alias='scimag',
            document_id=document_id,
            position=position,
            request_context=request_context,
            session_id=session_id,
        )
        return ScimagView(document_pb=typed_document_pb.scimag)

    async def resolve_scitech(
        self,
        document_id: int,
        position: int,
        request_context: RequestContext,
        session_id: str,
    ) -> ScitechView:
        """Build a :class:`ScitechView` for a ``scitech`` document and its duplicates.

        Duplicates are documents whose ``original_id`` equals ``document_id``;
        at most 16 of them (a single search page) are attached.
        """
        typed_document_pb = await self.get_typed_document_pb(
            index_alias='scitech',
            document_id=document_id,
            position=position,
            request_context=request_context,
            session_id=session_id,
        )
        search_response_duplicates = await self.application.meta_api_client.search(
            index_aliases=('scitech',),
            query=f'original_id:{document_id}',
            page_size=16,
            request_id=request_context.request_id,
            session_id=session_id,
            user_id=str(request_context.chat.chat_id),
        )
        duplicates = [
            scored_document.typed_document.scitech
            for scored_document in search_response_duplicates.scored_documents
        ]
        return ScitechView(
            document_pb=typed_document_pb.scitech,
            duplicates=duplicates,
        )

    async def resolve_document(
        self,
        index_alias: str,
        document_id: int,
        position: int,
        session_id: str,
        request_context: RequestContext
    ) -> Union[ScimagView, ScitechView]:
        """Dispatch to the resolver registered for ``index_alias``.

        Raises:
            UnknownIndexAliasError: no resolver is registered for ``index_alias``.
        """
        resolver = self.index_alias_to_resolver.get(index_alias)
        if resolver is None:
            raise UnknownIndexAliasError(index_alias=index_alias)
        return await resolver(
            document_id=document_id,
            position=position,
            request_context=request_context,
            session_id=session_id,
        )

    def reset_last_widget(self, chat_id: int):
        """Forget the last widget remembered for ``chat_id``."""
        self.application.user_manager.last_widget[chat_id] = None

    def register_for(self, telegram_client: TelegramClient):
        """Attach this handler to ``telegram_client`` and return the registered callback."""
        telegram_client.add_event_handler(self._wrapped_handler, self.filter)
        return self._wrapped_handler

    async def _send_fail_response(self, event: events.ChatAction, request_context: RequestContext):
        """Reply with the MAINTENANCE message after a handler failure.

        Connection errors and invalid query ids while replying are logged
        through ``request_context`` instead of being raised.
        """
        try:
            await event.reply(
                t('MAINTENANCE', language=request_context.chat.language).format(
                    maintenance_picture_url=self.application.config['application']['maintenance_picture_url'],
                ),
                buttons=[close_button()]
            )
        except (ConnectionError, QueryIdInvalidError) as e:
            request_context.error_log(e)

    async def _put_chat(self, event: events.ChatAction, request_id: str):
        """Return the chat for ``event``, creating it in IDM on first contact.

        Without an IDM client a transient ``ChatPb`` is built from Telegram data.

        Raises:
            AioRpcError: any IDM error other than ``NOT_FOUND``.
            ReadOnlyModeError: the chat is unknown to IDM and read-only mode is on.
        """
        event_chat = await event.get_chat()
        username = get_username(event, event_chat)
        language = get_language(event, event_chat)
        # Only 'en' and 'ru' are stored; normalize before both paths so the
        # IDM-less ChatPb agrees with a freshly created IDM chat.
        if language not in {'en', 'ru'}:
            language = 'en'
        if not self.application.idm_client:
            return ChatPb(
                chat_id=event.chat_id,
                username=username,
                language=language,
            )
        try:
            return await self.application.idm_client.get_chat(
                chat_id=event.chat_id,
                request_id=request_id,
            )
        except AioRpcError as e:
            if e.code() != StatusCode.NOT_FOUND:
                raise
        # The chat is unknown to IDM: create it unless writes are disabled.
        if self.application.config['application']['is_read_only_mode']:
            raise ReadOnlyModeError()
        return await self.application.idm_client.create_chat(
            chat_id=event.chat_id,
            username=username,
            language=language,
            request_id=request_id,
        )

    async def _check_ban(self, event: events.ChatAction, request_context: RequestContext, chat: ChatPb):
        """Stop event propagation for a banned chat, first telling it the ban details."""
        if is_banned(chat):
            # NOTE(review): ChatPb is a protobuf message, where a plain string
            # field is never None, so this is presumably always true — confirm intent.
            if chat.ban_message is not None:
                async with safe_execution(
                    request_context=request_context,
                    on_fail=lambda: self._send_fail_response(event, request_context),
                ):
                    await event.reply(t(
                        'BANNED',
                        language=chat.language
                    ).format(
                        datetime=str(time.ctime(chat.ban_until)),
                        reason=chat.ban_message,
                    ))
            raise events.StopPropagation()

    async def _check_maintenance(self, event: events.ChatAction):
        """In maintenance mode, reply with the upgrade notice and stop propagation.

        Chats listed in ``config['application']['bypass_maintenance']`` are exempt.
        """
        if (
            self.application.config['application']['is_maintenance_mode']
            and event.chat_id not in self.application.config['application']['bypass_maintenance']
        ):
            await event.reply(
                t('UPGRADE_MAINTENANCE', language='en').format(
                    upgrade_maintenance_picture_url=self.application.config['application']['upgrade_maintenance_picture_url']
                ),
            )
            raise events.StopPropagation()

    async def _check_read_only(self, event: events.ChatAction):
        """In read-only mode, refuse writing handlers and stop propagation.

        Only handlers with ``writing_handler = True`` are disabled, as the flag
        documents; other handlers keep working (new chats then fall back to a
        default ChatPb via ``ReadOnlyModeError`` in ``_put_chat``).
        """
        if self.writing_handler and self.application.config['application']['is_read_only_mode']:
            await event.reply(
                t("READ_ONLY_MODE", language='en'),
            )
            raise events.StopPropagation()

    async def is_subscribed(self, chat: ChatPb) -> bool:
        """Return whether ``chat`` may use subscription-gated handlers.

        Negative (group/channel) chat ids and chats created less than ten
        minutes ago pass without a check. Otherwise the chat is looked up in
        the related channel; ``UserNotParticipantError`` means "not subscribed".
        """
        grace_period_seconds = 10 * 60
        if chat.chat_id < 0 or chat.created_at > time.time() - grace_period_seconds:
            return True
        try:
            await self.application.telegram_client.get_permissions(
                self.application.config['telegram']['related_channel'],
                chat.chat_id,
            )
        except UserNotParticipantError:
            return False
        return True

    async def _check_subscription(self, event: events.ChatAction, request_context: RequestContext, chat: ChatPb):
        """Ask unsubscribed chats to join the related channel and stop propagation.

        Applies only when subscriptions are required globally and by this handler.
        """
        if (
            self.application.config['application']['is_subscription_required']
            and self.is_subscription_required_for_handler
            and not await self.is_subscribed(chat)
        ):
            async with safe_execution(
                request_context=request_context,
                on_fail=lambda: self._send_fail_response(event, request_context),
            ):
                await event.reply(t(
                    'SUBSCRIBE_TO_CHANNEL',
                    language=chat.language
                ).format(related_channel=self.application.config['telegram']['related_channel']))
            raise events.StopPropagation()

    def _has_access(self, chat: ChatPb) -> bool:
        """Access hook for subclasses; the base implementation allows every chat."""
        return True

    async def _process_chat(self, event: events.ChatAction, request_id: str):
        """Resolve the chat for ``event``, degrading to a default English ChatPb.

        Any ``AioRpcError`` or ``BaseError`` (including ``ReadOnlyModeError``)
        from ``_put_chat`` is logged and replaced by the fallback chat.
        """
        try:
            chat = await self._put_chat(event, request_id=request_id)
        except (AioRpcError, BaseError) as e:
            error_log(e)
            event_chat = await event.get_chat()
            username = get_username(event, event_chat)
            chat = ChatPb(
                chat_id=event.chat_id,
                is_system_messaging_enabled=True,
                is_discovery_enabled=True,
                language='en',
                username=username,
                is_admin=False,
            )
        return chat

    async def _wrapped_handler(self, event: events.ChatAction) -> None:
        """Telethon entry point: run all gate checks, then :meth:`handler`.

        Order: group filter, maintenance, read-only (writing handlers only),
        chat resolution, access, subscription, ban, optional widget reset, and
        finally the handler itself inside ``safe_execution``.
        """
        # Checking group permissions
        if (event.is_group or event.is_channel) and not self.is_group_handler:
            return

        await self._check_maintenance(event=event)
        await self._check_read_only(event=event)

        request_id = RequestContext.generate_request_id(self.application.config['application']['request_id_length'])
        chat = await self._process_chat(event=event, request_id=request_id)

        request_context = RequestContext(
            bot_name=self.application.config['telegram']['bot_name'],
            chat=chat,
            request_id=request_id,
            request_id_length=self.application.config['application']['request_id_length'],
        )

        if not self._has_access(chat):
            return

        await self._check_subscription(event=event, request_context=request_context, chat=chat)
        await self._check_ban(event=event, request_context=request_context, chat=chat)

        if self.should_reset_last_widget:
            self.reset_last_widget(request_context.chat.chat_id)

        async with safe_execution(
            request_context=request_context,
            on_fail=lambda: self._send_fail_response(event, request_context),
        ):
            await self.handler(
                event,
                request_context=request_context,
            )
            if self.stop_propagation:
                raise events.StopPropagation()

    async def handler(self, event: events.ChatAction, request_context: RequestContext):
        """Handle ``event``; subclasses must override this."""
        raise NotImplementedError()
|
|
|
|
|
|
class BaseCallbackQueryHandler(BaseHandler, ABC):
    """Base class for handlers of callback queries (inline button presses).

    Callback queries are answered rather than replied to, so the failure
    response is a short text-only maintenance notice sent via ``event.answer``.
    """

    async def _send_fail_response(self, event, request_context: RequestContext):
        """Answer the query with MAINTENANCE_WO_PIC, logging send errors instead of raising."""
        try:
            notice = t('MAINTENANCE_WO_PIC', language=request_context.chat.language)
            await event.answer(notice)
        except (QueryIdInvalidError, ConnectionError) as send_error:
            request_context.error_log(send_error)
|