the-superpirate dd23846059 - [nexus] Refactoring
- [nexus] Switch bot
  - [bot] Added extra receivers functionality

GitOrigin-RevId: 68fc32d3e79ff411758f54f435fe8680fc42dead
2022-03-28 17:42:18 +03:00

238 lines
9.4 KiB
Python

import asyncio
import json
import logging
from contextlib import suppress
from timeit import default_timer
from aiokit import AioThing
from cachetools import TTLCache
from grpc import StatusCode
from izihawa_utils.exceptions import NeedRetryError
from izihawa_utils.pb_to_json import MessageToDict
from izihawa_utils.text import camel_to_snake
from library.aiogrpctools.base import (
BaseService,
aiogrpc_request_wrapper,
)
from nexus.meta_api.proto.search_service_pb2 import \
ScoredDocument as ScoredDocumentPb
from nexus.meta_api.proto.search_service_pb2 import \
SearchResponse as SearchResponsePb
from nexus.meta_api.proto.search_service_pb2_grpc import (
SearchServicer,
add_SearchServicer_to_server,
)
from nexus.meta_api.query_extensions import (
ClassicQueryProcessor,
QueryClass,
)
from nexus.meta_api.rescorers import ClassicRescorer
from nexus.models.proto.operation_pb2 import \
DocumentOperation as DocumentOperationPb
from nexus.models.proto.operation_pb2 import UpdateDocument as UpdateDocumentPb
from nexus.models.proto.scimag_pb2 import Scimag as ScimagPb
from nexus.models.proto.typed_document_pb2 import \
TypedDocument as TypedDocumentPb
from nexus.nlptools.utils import despace_full
from nexus.views.telegram.registry import pb_registry
from tenacity import (
AsyncRetrying,
RetryError,
retry_if_exception_type,
stop_after_attempt,
wait_fixed,
)
class ClassicSearcher(BaseService):
    """Meta API search over Summa indices with a per-user query cache.

    On a cache miss a query is despaced, classified by ``query_processor``,
    sent to every requested index alias concurrently (``asyncio.gather``),
    rescored by ``rescorer`` and stored in ``query_cache``
    (``TTLCache(maxsize=4096, ttl=300)``) under the key
    ``(user_id, language, sorted index aliases, raw query)``. Pages are then
    sliced out of the cached result list.
    """

    # Number of documents requested from Summa per index alias; client pages
    # are sliced out of this window in ``search``.
    page_sizes = {
        'scimag': 100,
        'scitech': 100,
    }
    # Document fields kept as lists after JSON decoding; every other field is
    # unwrapped to its first element in ``cast_search_response``.
    multi_valued_fields = frozenset({'authors', 'ipfs_multihashes', 'isbns', 'issns', 'references', 'tags'})

    def __init__(self, summa_client, query_processor, rescorer, stat_provider):
        super().__init__(service_name='meta_api')
        self.summa_client = summa_client
        # DOI delivery requests are emitted as records on this logger
        # (see ``request_doi_delivery``).
        self.operation_logger = logging.getLogger('operation')
        self.class_name = camel_to_snake(self.__class__.__name__)
        self.query_cache = TTLCache(maxsize=1024 * 4, ttl=300)
        self.query_processor = query_processor
        self.rescorer = rescorer
        self.stat_provider = stat_provider

    async def processed_query_hook(self, processed_query, context):
        """Validate a processed query before it is sent to Summa.

        URL-class queries are rejected by aborting the call with
        ``INVALID_ARGUMENT``; anything else is returned unchanged. Kept as a
        separate coroutine, presumably so subclasses can override it.
        """
        if processed_query['class'] == QueryClass.URL:
            await context.abort(StatusCode.INVALID_ARGUMENT, 'url_query_error')
        return processed_query

    def merge_search_responses(self, search_responses):
        """Concatenate per-index responses into a single ``SearchResponsePb``.

        Documents keep their per-index order; ``has_next`` is true if any index
        reported more results. A single response is returned as-is and an
        empty input yields an empty response, so callers can always read
        ``.scored_documents``.
        """
        if not search_responses:
            return SearchResponsePb()
        if len(search_responses) == 1:
            return search_responses[0]
        return SearchResponsePb(
            scored_documents=[
                scored_document
                for search_response in search_responses
                for scored_document in search_response.scored_documents
            ],
            has_next=any(search_response.has_next for search_response in search_responses),
        )

    def cast_search_response(self, name, search_response):
        """Convert one Summa response (as a ``MessageToDict`` dict) to protobuf.

        :param name: index alias; used as the key into ``pb_registry`` and as
            the ``TypedDocumentPb`` field that receives the document.
        :param search_response: dict with ``scored_documents`` (each holding a
            JSON ``document``, ``position`` and ``score``) and ``has_next``.
        :return: ``SearchResponsePb`` with typed documents, enriched with
            ``downloads_count`` when ``stat_provider`` has a non-zero value.
        """
        scored_documents_pb = []
        for scored_document in search_response['scored_documents']:
            document = json.loads(scored_document['document'])
            # Stored fields appear to arrive as JSON lists; unwrap the
            # single-valued ones. Reassigning existing keys while iterating is
            # safe because the dict size does not change.
            for field in document:
                if field not in self.multi_valued_fields:
                    document[field] = document[field][0]
            original_id = document.get('original_id') or document['id']
            download_stats = self.stat_provider.get_download_stats(original_id)
            if download_stats and download_stats.downloads_count:
                document['downloads_count'] = download_stats.downloads_count
            scored_documents_pb.append(ScoredDocumentPb(
                position=scored_document['position'],
                score=scored_document['score'],
                typed_document=TypedDocumentPb(
                    **{name: pb_registry[name](**document)},
                ),
            ))
        return SearchResponsePb(
            scored_documents=scored_documents_pb,
            has_next=search_response['has_next'],
        )

    async def _search_indices(self, index_aliases, query, request_id):
        """Query every index alias concurrently and merge the cast responses."""
        summa_responses = await asyncio.gather(*[
            self.summa_client.search(
                index_alias=index_alias,
                query=query,
                offset=0,
                limit=self.page_sizes[index_alias],
                request_id=request_id,
            )
            for index_alias in index_aliases
        ])
        return self.merge_search_responses([
            self.cast_search_response(
                index_alias,
                MessageToDict(
                    summa_response,
                    preserving_proto_field_name=True,
                    including_default_value_fields=True,
                ),
            )
            for index_alias, summa_response in zip(index_aliases, summa_responses)
        ])

    @aiogrpc_request_wrapper()
    async def search(self, request, context, metadata):
        """Return one page of results for ``request``, searching Summa on a cache miss.

        A missing or empty cached entry triggers a fresh, rescored search whose
        result is cached. The page is sliced using ``request.page`` and
        ``request.page_size`` (``5`` when unset/zero). ``metadata`` must contain
        ``user-id``, ``request-id`` and ``session-id`` (presumably supplied by
        ``aiogrpc_request_wrapper``).
        """
        start = default_timer()
        processed_query = None
        cache_hit = True
        page_size = request.page_size or 5
        index_aliases = tuple(sorted(request.index_aliases))
        user_id = metadata['user-id']
        cache_key = (user_id, request.language, index_aliases, request.query)

        # Single lookup: a membership test followed by a second read can raise
        # KeyError if the TTL entry expires in between.
        try:
            search_response_pb = self.query_cache[cache_key]
        except KeyError:
            search_response_pb = None

        # Empty cached results are treated as misses so they are re-fetched.
        if search_response_pb is None or not search_response_pb.scored_documents:
            cache_hit = False
            query = despace_full(request.query)
            processed_query = self.query_processor.process(query, request.language)
            processed_query = await self.processed_query_hook(processed_query, context)
            # Fallback so the cache write below never sees an unbound name if
            # every attempt fails with NeedRetryError before building a response.
            search_response_pb = SearchResponsePb()

            # An empty DOI lookup requests delivery and is retried once after
            # 15 s; if it is still empty, RetryError is swallowed and the empty
            # (un-rescored) response is cached.
            with suppress(RetryError):
                async for attempt in AsyncRetrying(
                    retry=retry_if_exception_type(NeedRetryError),
                    wait=wait_fixed(15),
                    stop=stop_after_attempt(2),
                ):
                    with attempt:
                        search_response_pb = await self._search_indices(
                            index_aliases=index_aliases,
                            query=processed_query['query'],
                            request_id=metadata['request-id'],
                        )
                        if not search_response_pb.scored_documents and processed_query['class'] == QueryClass.DOI:
                            # Request delivery only on the first attempt.
                            if attempt.retry_state.attempt_number == 1:
                                await self.request_doi_delivery(doi=processed_query['doi'])
                            raise NeedRetryError()
                        rescored_documents = await self.rescorer.rescore(
                            scored_documents=search_response_pb.scored_documents,
                            query=query,
                            session_id=metadata['session-id'],
                            language=request.language,
                        )
                        search_response_pb = SearchResponsePb(
                            scored_documents=rescored_documents,
                            has_next=search_response_pb.has_next,
                        )
            self.query_cache[cache_key] = search_response_pb

        logging.getLogger('query').info({
            'action': 'request',
            'cache_hit': cache_hit,
            'duration': default_timer() - start,
            'index_aliases': index_aliases,
            'mode': 'search',
            'page': request.page,
            'page_size': page_size,
            'processed_query': processed_query['query'] if processed_query else None,
            'query': request.query,
            'query_class': processed_query['class'].value if processed_query else None,
            'request_id': metadata['request-id'],
            'session_id': metadata['session-id'],
            'user_id': user_id,
        })

        # Slice from the local reference rather than re-reading the cache.
        scored_documents = search_response_pb.scored_documents
        left_offset = request.page * page_size
        right_offset = left_offset + page_size
        return SearchResponsePb(
            scored_documents=scored_documents[left_offset:right_offset],
            has_next=len(scored_documents) > right_offset,
        )

    async def request_doi_delivery(self, doi):
        """Emit an ``UpdateDocument`` operation for ``doi`` on the operation logger.

        The operation carries ``commit``, ``reindex`` and
        ``should_fill_from_external_source`` set to true and a ``Scimag``
        document containing only the DOI. Whatever consumes the ``operation``
        log stream is expected to act on it (not visible here).
        """
        document_operation = DocumentOperationPb(
            update_document=UpdateDocumentPb(
                commit=True,
                reindex=True,
                should_fill_from_external_source=True,
                typed_document=TypedDocumentPb(scimag=ScimagPb(doi=doi)),
            ),
        )
        self.operation_logger.info(MessageToDict(document_operation, preserving_proto_field_name=True))
class SearchService(SearchServicer, AioThing):
    """gRPC ``Search`` servicer that forwards every call to a ``ClassicSearcher``."""

    def __init__(self, server, summa_client, stat_provider, learn_logger=None):
        super().__init__()
        self.server = server
        # Build the searcher's collaborators first, in the same order as before.
        query_processor = ClassicQueryProcessor()
        rescorer = ClassicRescorer(learn_logger=learn_logger)
        self.searcher = ClassicSearcher(
            summa_client=summa_client,
            query_processor=query_processor,
            rescorer=rescorer,
            stat_provider=stat_provider,
        )
        # Registered in ``self.starts`` (from AioThing), presumably so the
        # searcher is started together with this service.
        self.starts.append(self.searcher)

    async def start(self):
        """Register this servicer on ``self.server``."""
        add_SearchServicer_to_server(self, self.server)

    async def search(self, request, context):
        """Handle a search RPC by delegating to the searcher.

        ``metadata`` is not passed here; presumably the searcher's
        ``aiogrpc_request_wrapper`` supplies it.
        """
        return await self.searcher.search(request, context)