238 lines
9.5 KiB
Python
Raw Normal View History

import asyncio
import logging
from contextlib import suppress
from timeit import default_timer
from aiokit import AioThing
from cachetools import TTLCache
from google.protobuf.json_format import MessageToDict
from grpc import StatusCode
from izihawa_utils.exceptions import NeedRetryError
from izihawa_utils.text import camel_to_snake
from library.aiogrpctools.base import aiogrpc_request_wrapper
from nexus.meta_api.proto.search_service_pb2 import \
ScoredDocument as ScoredDocumentPb
from nexus.meta_api.proto.search_service_pb2 import \
SearchResponse as SearchResponsePb
from nexus.meta_api.proto.search_service_pb2_grpc import (
SearchServicer,
add_SearchServicer_to_server,
)
from nexus.meta_api.query_extensionner import (
ClassicQueryProcessor,
QueryClass,
)
from nexus.meta_api.rescorers import ClassicRescorer
from nexus.models.proto.operation_pb2 import \
DocumentOperation as DocumentOperationPb
from nexus.models.proto.operation_pb2 import UpdateDocument as UpdateDocumentPb
from nexus.models.proto.scimag_pb2 import Scimag as ScimagPb
from nexus.models.proto.typed_document_pb2 import \
TypedDocument as TypedDocumentPb
from nexus.nlptools.utils import despace_full
from nexus.views.telegram.registry import pb_registry
from tenacity import (
AsyncRetrying,
RetryError,
retry_if_exception_type,
stop_after_attempt,
wait_fixed,
)
from .base import BaseService
class Searcher(BaseService):
    """Base meta-API searcher: queries Summa, rescores, caches and pages results.

    On a cache miss the query is processed, searched in every requested schema
    (fetching ``page_sizes[schema]`` documents per schema), passed through
    ``post_search_hook``, rescored and stored as a ``SearchResponsePb`` in a TTL
    cache keyed by ``(user_id, language, schemas, query)``. Individual pages are
    then sliced out of the cached response.
    """

    # Number of documents requested from Summa per schema on a cache miss.
    page_sizes = {
        'scimag': 100,
        'scitech': 100,
    }

    def __init__(self, summa_client, query_processor, rescorer, stat_provider):
        super().__init__(service_name='meta_api')
        self.summa_client = summa_client
        self.operation_logger = logging.getLogger('operation')
        self.class_name = camel_to_snake(self.__class__.__name__)
        # At most 4096 cached responses, each kept for 300 seconds.
        self.query_cache = TTLCache(maxsize=1024 * 4, ttl=300)
        self.query_processor = query_processor
        self.rescorer = rescorer
        self.stat_provider = stat_provider

    async def processed_response_hook(self, processor_response, context):
        """Inspect or replace the processed query before searching.

        The base implementation returns ``processor_response`` unchanged;
        subclasses may override it (e.g. to abort the RPC).
        """
        return processor_response

    async def post_search_hook(self, search_response, processor_response, request, context,
                               metadata, retry_state):
        """Inspect or modify the merged Summa response before rescoring.

        The base implementation returns ``search_response`` unchanged. An
        override may raise ``NeedRetryError`` to make the search be retried;
        ``retry_state`` exposes the current attempt.
        """
        return search_response

    def merge_search_responses(self, search_responses):
        """Merge per-schema Summa responses into a single response dict.

        Args:
            search_responses: list of dicts with ``scored_documents`` and
                ``has_next`` keys, one per searched schema.

        Returns:
            The only response unchanged when there is exactly one; otherwise a
            dict concatenating all ``scored_documents`` in input order, with
            ``has_next`` true if any input has it. An empty input yields an
            empty response rather than ``None``, so callers can always index
            ``scored_documents``.
        """
        if not search_responses:
            return dict(scored_documents=[], has_next=False)
        if len(search_responses) == 1:
            return search_responses[0]
        return dict(
            scored_documents=[
                scored_document
                for search_response in search_responses
                for scored_document in search_response['scored_documents']
            ],
            has_next=any(search_response['has_next'] for search_response in search_responses),
        )

    def cast_search_response(self, search_response):
        """Convert a response dict into a ``SearchResponsePb``.

        Each scored document's ``schema`` selects the protobuf class from
        ``pb_registry``; the document fields are passed to it as keyword
        arguments and wrapped into a ``TypedDocumentPb`` under that schema name.
        """
        scored_documents_pb = []
        for scored_document in search_response['scored_documents']:
            document_pb = pb_registry[scored_document['schema']](**scored_document['document'])
            scored_documents_pb.append(ScoredDocumentPb(
                position=scored_document['position'],
                score=scored_document['score'],
                typed_document=TypedDocumentPb(
                    **{scored_document['schema']: document_pb},
                ),
            ))
        return SearchResponsePb(
            scored_documents=scored_documents_pb,
            has_next=search_response['has_next'],
        )

    async def _fetch_search_response_pb(self, query, processor_response, schemas, request,
                                        context, metadata):
        """Search every schema in Summa, apply hooks and rescoring, return protobuf.

        Runs up to 3 attempts, 10 seconds apart, while an attempt raises
        ``NeedRetryError``. Returns ``None`` when every attempt asked for a retry.
        Any other exception propagates.
        """
        search_response_pb = None
        with suppress(RetryError):
            async for attempt in AsyncRetrying(
                retry=retry_if_exception_type(NeedRetryError),
                wait=wait_fixed(10),
                stop=stop_after_attempt(3),
            ):
                with attempt:
                    requests = [
                        self.summa_client.search(
                            schema=schema,
                            query=processor_response['query'],
                            page=0,
                            page_size=self.page_sizes[schema],
                            request_id=metadata['request-id'],
                        )
                        for schema in schemas
                    ]
                    search_response = self.merge_search_responses(await asyncio.gather(*requests))
                    search_response = await self.post_search_hook(
                        search_response,
                        processor_response=processor_response,
                        request=request,
                        context=context,
                        metadata=metadata,
                        retry_state=attempt.retry_state,
                    )
                    search_response['scored_documents'] = await self.rescorer.rescore(
                        scored_documents=search_response['scored_documents'],
                        query=query,
                        session_id=metadata['session-id'],
                        language=request.language,
                    )
                    search_response_pb = self.cast_search_response(search_response)
        return search_response_pb

    @aiogrpc_request_wrapper()
    async def search(self, request, context, metadata):
        """Serve one page of search results for ``request``.

        Args:
            request: search request. ``query``, ``language``, ``schemas``,
                ``page`` and ``page_size`` are read; ``page_size`` defaults to 5.
            context: gRPC context, forwarded to the hooks.
            metadata: request metadata. The ``user-id``, ``request-id`` and
                ``session-id`` keys are read.

        Returns:
            ``SearchResponsePb`` with the requested slice of documents, and
            ``has_next`` set when more documents follow. When every retry
            attempt is exhausted and nothing usable is cached, an empty
            response is returned instead of failing with ``KeyError``.
        """
        start = default_timer()
        processor_response = None
        page_size = request.page_size or 5
        schemas = tuple(sorted(request.schemas))
        user_id = metadata['user-id']
        cache_key = (user_id, request.language, schemas, request.query)

        # Keep the response in a local so the TTL cache is read only once; the
        # entry could otherwise expire between the check and a later lookup.
        search_response_pb = self.query_cache.get(cache_key)
        # An empty cached response counts as a miss, so empty results are re-queried.
        cache_hit = search_response_pb is not None and len(search_response_pb.scored_documents) > 0
        if not cache_hit:
            query = despace_full(request.query)
            processor_response = self.query_processor.process(query, request.language)
            processor_response = await self.processed_response_hook(processor_response, context)
            fetched_pb = await self._fetch_search_response_pb(
                query=query,
                processor_response=processor_response,
                schemas=schemas,
                request=request,
                context=context,
                metadata=metadata,
            )
            if fetched_pb is not None:
                self.query_cache[cache_key] = fetched_pb
                search_response_pb = fetched_pb

        logging.getLogger('query').info({
            'action': 'request',
            'cache_hit': cache_hit,
            'duration': default_timer() - start,
            'mode': 'search',
            'page': request.page,
            'page_size': page_size,
            'processed_query': processor_response['query'] if processor_response else None,
            'query': request.query,
            'query_class': processor_response['class'].value if processor_response else None,
            'request_id': metadata['request-id'],
            'schemas': schemas,
            'session_id': metadata['session-id'],
            'user_id': user_id,
        })

        if search_response_pb is None:
            # Every attempt asked for a retry and nothing was cached before.
            search_response_pb = SearchResponsePb()

        scored_documents = search_response_pb.scored_documents
        left_offset = request.page * page_size
        right_offset = left_offset + page_size
        return SearchResponsePb(
            scored_documents=scored_documents[left_offset:right_offset],
            has_next=len(scored_documents) > right_offset,
        )
class ClassicSearcher(Searcher):
    """Searcher that rejects URL queries, handles missing DOIs and adds download counts."""

    async def processed_response_hook(self, processor_response, context):
        """Abort the RPC with ``INVALID_ARGUMENT`` when the query is classified as a URL.

        Otherwise the processed query is returned unchanged.
        """
        is_url_query = processor_response['class'] == QueryClass.URL
        if is_url_query:
            await context.abort(StatusCode.INVALID_ARGUMENT, 'url_query_error')
        return processor_response

    async def post_search_hook(self, search_response, processor_response, request, context, metadata,
                               retry_state):
        """Retry empty DOI searches and annotate documents with download counts.

        When a DOI query returns no documents, a delivery request is issued on
        the first attempt only, and ``NeedRetryError`` is raised on every
        attempt so the caller retries the search.

        Otherwise every document gets ``downloads_count`` set from
        ``stat_provider``, keyed by ``original_id`` or, failing that, ``id``.
        The count is set only when the stats exist and the count is non-zero.
        The response dict is mutated in place and returned.
        """
        documents = search_response['scored_documents']
        if len(documents) == 0 and processor_response['class'] == QueryClass.DOI:
            is_first_attempt = retry_state.attempt_number == 1
            if is_first_attempt:
                await self.request_doi_delivery(doi=processor_response['doi'])
            raise NeedRetryError()

        for item in documents:
            doc = item['document']
            stats = self.stat_provider.get_download_stats(doc.get('original_id') or doc['id'])
            if stats and stats.downloads_count:
                doc['downloads_count'] = stats.downloads_count
        return search_response

    async def request_doi_delivery(self, doi):
        """Log an ``UpdateDocument`` operation for a scimag document with ``doi``.

        The operation has ``commit``, ``reindex`` and
        ``should_fill_from_external_source`` all set. It is written as a dict
        to ``operation_logger``; presumably a downstream consumer of that log
        performs the delivery (not visible here — verify).
        """
        update_pb = UpdateDocumentPb(
            commit=True,
            reindex=True,
            should_fill_from_external_source=True,
            typed_document=TypedDocumentPb(scimag=ScimagPb(doi=doi)),
        )
        operation_pb = DocumentOperationPb(update_document=update_pb)
        self.operation_logger.info(MessageToDict(operation_pb))
class SearchService(SearchServicer, AioThing):
    """gRPC ``SearchServicer`` that forwards ``search`` calls to a ``ClassicSearcher``."""

    def __init__(self, server, summa_client, stat_provider, learn_logger=None):
        super().__init__()
        self.server = server
        # Build the collaborators in the same order as before: query processor, then rescorer.
        query_processor = ClassicQueryProcessor()
        rescorer = ClassicRescorer(learn_logger=learn_logger)
        self.searcher = ClassicSearcher(
            summa_client=summa_client,
            query_processor=query_processor,
            rescorer=rescorer,
            stat_provider=stat_provider,
        )
        # Presumably lets AioThing start the searcher together with this service — verify.
        self.starts.append(self.searcher)

    async def start(self):
        """Register this servicer on ``self.server``."""
        add_SearchServicer_to_server(self, self.server)

    async def search(self, request, context):
        """Delegate to the searcher.

        ``metadata`` is not passed here; presumably the
        ``aiogrpc_request_wrapper`` decorator on ``Searcher.search`` supplies
        it — verify.
        """
        searcher = self.searcher
        return await searcher.search(request, context)