363 lines
14 KiB
Python
363 lines
14 KiB
Python
import asyncio
|
|
import dataclasses
|
|
import logging
|
|
import sys
|
|
from contextlib import suppress
|
|
from datetime import timedelta
|
|
from timeit import default_timer
|
|
from typing import (
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
import orjson as json
|
|
from aiosumma.eval_scorer_builder import EvalScorerBuilder
|
|
from aiosumma.parser.errors import ParseError
|
|
from aiosumma.processor import ProcessedQuery
|
|
from cachetools import TTLCache
|
|
from grpc import StatusCode
|
|
from izihawa_utils.exceptions import NeedRetryError
|
|
from izihawa_utils.pb_to_json import MessageToDict
|
|
from library.aiogrpctools.base import aiogrpc_request_wrapper
|
|
from nexus.meta_api.mergers import (
|
|
AggregationMerger,
|
|
CountMerger,
|
|
ReservoirSamplingMerger,
|
|
TopDocsMerger,
|
|
)
|
|
from nexus.meta_api.proto import search_service_pb2 as meta_search_service_pb2
|
|
from nexus.meta_api.proto.search_service_pb2_grpc import (
|
|
SearchServicer,
|
|
add_SearchServicer_to_server,
|
|
)
|
|
from nexus.meta_api.services.base import BaseService
|
|
from nexus.models.proto import (
|
|
operation_pb2,
|
|
scimag_pb2,
|
|
typed_document_pb2,
|
|
)
|
|
from summa.proto import search_service_pb2
|
|
from tenacity import (
|
|
AsyncRetrying,
|
|
RetryError,
|
|
retry_if_exception_type,
|
|
stop_after_attempt,
|
|
wait_fixed,
|
|
)
|
|
|
|
|
|
def to_bool(b: Union[str, None, bool, int]):
    """Coerce a loosely typed flag to a bool.

    Strings are true only when they are exactly ``'1'``; every other value
    (including ``None``) follows ordinary Python truthiness.
    """
    return b == '1' if isinstance(b, str) else bool(b)
|
|
|
|
|
|
@dataclasses.dataclass
class SearchRequest:
    """A single-index search: the target index alias, the query and its collectors."""

    index_alias: str
    query: ProcessedQuery
    # NOTE(review): annotated as dicts, but every item is handed to
    # MessageToDict in cache_key(), so these are presumably protobuf
    # messages - confirm against callers.
    collectors: List[Dict]

    def cache_key(self):
        """Return a hashable tuple identifying this request for result caching.

        The tuple holds the index alias, the string form of the query and the
        collectors serialized to JSON with sorted keys.
        """
        collectors_as_dicts = [
            MessageToDict(collector, preserving_proto_field_name=True)
            for collector in self.collectors
        ]
        serialized_collectors = json.dumps(collectors_as_dicts, option=json.OPT_SORT_KEYS)
        return self.index_alias, str(self.query), serialized_collectors
|
|
|
|
|
|
class SearchService(SearchServicer, BaseService):
    """gRPC search servicer that queries Summa per index alias and merges the results.

    Merged responses are kept in an in-memory TTL cache keyed by the
    per-index search requests.
    """

    # Per-index snippet settings passed to the top-docs collector, keyed by
    # field name (presumably the snippet length per field - TODO confirm).
    snippets = {
        'scimag': {
            'title': 1024,
            'abstract': 100,
        },
        'scitech': {
            'title': 1024,
            'description': 100,
        }
    }

    # Merger class per collector kind. The kind is also the name of the field
    # read from each collector output.
    _MERGER_BY_COLLECTOR = {
        'aggregation': AggregationMerger,
        'count': CountMerger,
        'top_docs': TopDocsMerger,
        'reservoir_sampling': ReservoirSamplingMerger,
    }

    def __init__(self, application, stat_provider, summa_client, query_preprocessor, query_transformers, learn_logger=None):
        """Store collaborators and create the response cache.

        :param application: forwarded to the base service; its ``server`` is used in ``start``
        :param stat_provider: forwarded to the base service
        :param summa_client: client whose ``search`` coroutine performs the per-index search
        :param query_preprocessor: object whose ``process(query, languages)`` parses raw queries
        :param query_transformers: mapping of index alias to a transformer exposing
            ``apply_tree_transformers``
        :param learn_logger: optional logger, only stored here
        """
        super().__init__(
            application=application,
            stat_provider=stat_provider,
            summa_client=summa_client,
        )
        # Up to 4096 merged responses, each kept for 300 seconds.
        self.query_cache = TTLCache(maxsize=1024 * 4, ttl=300)
        self.query_preprocessor = query_preprocessor
        self.query_transformers = query_transformers
        self.learn_logger = learn_logger

    async def start(self):
        """Register this servicer on the application's gRPC server."""
        add_SearchServicer_to_server(self, self.application.server)

    def merge_search_responses(self, search_responses, collector_descriptors: List[str]):
        """Merge per-index responses into one ``MetaSearchResponse``.

        :param search_responses: responses of the individual index searches
        :param collector_descriptors: collector kind for every collector position,
            in the same order as the collectors of the requests
        :return: a ``MetaSearchResponse`` with one merged output per collector
        :raises RuntimeError: if a descriptor names an unsupported collector kind
        """
        collector_outputs = []
        for position, descriptor in enumerate(collector_descriptors):
            merger_cls = self._MERGER_BY_COLLECTOR.get(descriptor)
            if merger_cls is None:
                raise RuntimeError("Unsupported collector")
            merger = merger_cls([
                getattr(search_response.collector_outputs[position], descriptor)
                for search_response in search_responses
                # Responses that carry no collector outputs are skipped.
                if search_response.collector_outputs
            ])
            collector_outputs.append(merger.merge())
        return meta_search_service_pb2.MetaSearchResponse(collector_outputs=collector_outputs)

    async def check_if_need_new_documents_by_dois(self, requested_dois, scored_documents, should_request):
        """Ask for a retry while some explicitly requested DOI is absent from the results.

        :param requested_dois: DOIs extracted from the query; may be empty
        :param scored_documents: scored documents returned by the search
        :param should_request: when true, delivery of every missing DOI is
            requested before signalling the retry
        :raises NeedRetryError: if at least one requested DOI was not found
        """
        if not requested_dois:
            return

        found_dois = {
            getattr(
                scored_document.typed_document,
                scored_document.typed_document.WhichOneof('document')
            ).doi
            for scored_document in scored_documents
        }
        # Test membership instead of comparing collection sizes: unrelated hits
        # must not hide a missing DOI, and a DOI requested twice must not look
        # like a missing one. dict.fromkeys de-duplicates while keeping order.
        missing_dois = [doi for doi in dict.fromkeys(requested_dois) if doi not in found_dois]
        if not missing_dois:
            return

        if should_request:
            for doi in missing_dois:
                await self.request_doi_delivery(doi=doi)
        raise NeedRetryError()

    def resolve_index_aliases(self, request_index_aliases, processed_query):
        """
        Derives requested indices through request and query

        Index aliases named by the query are used when present, otherwise the
        ones from the request; either way only aliases that the request allows
        are kept. The result is a sorted tuple.
        """
        allowed_aliases = set(request_index_aliases)
        candidate_aliases = processed_query.context.index_aliases or allowed_aliases
        return tuple(sorted(
            index_alias for index_alias in candidate_aliases if index_alias in allowed_aliases
        ))

    def scorer(self, processed_query, index_alias):
        """Build the scorer for a top-docs collector.

        :return: an ``order_by`` scorer when the query requests ordering, ``None``
            for an empty query, otherwise an eval scorer tuned per index alias
        """
        if processed_query.context.order_by:
            return search_service_pb2.Scorer(order_by=processed_query.context.order_by[0])

        if processed_query.is_empty():
            return None

        eval_scorer_builder = EvalScorerBuilder()
        if index_alias == 'scimag':
            query_point_of_time = processed_query.context.query_point_of_time
            eval_scorer_builder.add_exp_decay(
                field_name='issued_at',
                # Rounded down to a multiple of 86400 (a whole day, assuming the
                # value is in seconds - TODO confirm).
                origin=query_point_of_time - query_point_of_time % 86400,
                scale=timedelta(days=365.25 * 14),
                offset=timedelta(days=30),
                decay=0.85,
            )
            eval_scorer_builder.add_fastsigm('page_rank + 1', 0.45)
        elif index_alias == 'scitech':
            eval_scorer_builder.ops.append('0.7235')
        return eval_scorer_builder.build()

    async def process_query(self, query, languages, context):
        """Parse a raw query, aborting the RPC with ``INVALID_ARGUMENT`` on a parse error."""
        try:
            return self.query_preprocessor.process(query, languages)
        except ParseError:
            return await context.abort(StatusCode.INVALID_ARGUMENT, 'parse_error')

    async def base_search(
        self,
        search_requests: List[SearchRequest],
        collector_descriptors: List[str],
        request_id: str,
        session_id: str,
        user_id: Optional[str] = None,
        skip_cache_loading: Optional[Union[bool, str]] = None,
        skip_cache_saving: Optional[Union[bool, str]] = None,
        original_query: Optional[str] = None,
        query_tags: Optional[List[str]] = None,
    ):
        """Run all per-index searches (or reuse a cached result) and merge them.

        :param search_requests: one request per index alias
        :param collector_descriptors: collector kind per collector position
        :param request_id: forwarded to the search client and written to the query log
        :param session_id: forwarded to the search client and written to the query log
        :param user_id: written to the query log
        :param skip_cache_loading: when truthy (see ``to_bool``) the cache is not read
        :param skip_cache_saving: when truthy (see ``to_bool``) a fresh result is not cached
        :param original_query: raw query text, written to the query log
        :param query_tags: written to the query log
        :return: the merged ``MetaSearchResponse``
        """
        started_at = default_timer()

        skip_cache_saving = to_bool(skip_cache_saving)
        skip_cache_loading = to_bool(skip_cache_loading)

        cache_key = tuple(search_request.cache_key() for search_request in search_requests)
        # Consult the cache only when loading is allowed, so that `cache_hit`
        # in the log means "this response was served from the cache".
        meta_search_response = None if skip_cache_loading else self.query_cache.get(cache_key)
        cache_hit = meta_search_response is not None

        if not cache_hit:
            search_responses = await asyncio.gather(*[
                self.summa_client.search(
                    index_alias=search_request.index_alias,
                    query=search_request.query,
                    collectors=search_request.collectors,
                    request_id=request_id,
                    session_id=session_id,
                    ignore_not_found=True,
                )
                for search_request in search_requests
            ])
            meta_search_response = self.merge_search_responses(search_responses, collector_descriptors)
            if not skip_cache_saving:
                self.query_cache[cache_key] = meta_search_response

        logging.getLogger('query').info({
            'action': 'request',
            'cache_hit': cache_hit,
            'duration': default_timer() - started_at,
            'mode': 'search',
            'query': original_query,
            'request_id': request_id,
            'session_id': session_id,
            'query_tags': query_tags,
            'user_id': user_id,
        })

        return meta_search_response

    @aiogrpc_request_wrapper(log=False)
    async def meta_search(self, request, context, metadata):
        """Search with caller-supplied collectors and return the merged collector outputs."""
        processed_query = await self.process_query(
            query=request.query,
            languages=dict(request.languages),
            context=context,
        )
        index_aliases = self.resolve_index_aliases(
            request_index_aliases=request.index_aliases,
            processed_query=processed_query,
        )
        search_requests = [
            SearchRequest(
                index_alias=index_alias,
                query=self.query_transformers[index_alias].apply_tree_transformers(processed_query).to_summa_query(),
                collectors=request.collectors,
            ) for index_alias in index_aliases
        ]
        collector_descriptors = [collector.WhichOneof('collector') for collector in request.collectors]

        return await self.base_search(
            search_requests=search_requests,
            collector_descriptors=collector_descriptors,
            request_id=metadata['request-id'],
            session_id=metadata['session-id'],
            user_id=metadata.get('user-id'),
            skip_cache_loading=metadata.get('skip-cache-loading'),
            skip_cache_saving=metadata.get('skip-cache-saving'),
            original_query=request.query,
            query_tags=list(request.query_tags),
        )

    @aiogrpc_request_wrapper(log=False)
    async def search(self, request, context, metadata):
        """Paged search returning scored documents, a total count and a ``has_next`` flag.

        Every index is asked for at most 50 top documents, so pages that start
        beyond the merged result list come back empty. When the 'scimag' query
        names DOIs that are not in the results, their delivery is requested on
        the first attempt and the search is retried (up to 6 attempts, 10
        seconds apart); after the last attempt whatever was found is returned.
        """
        preprocessed_query = await self.process_query(query=request.query, languages=request.language, context=context)
        logging.getLogger('debug').info({
            'action': 'preprocess_query',
            'preprocessed_query': str(preprocessed_query),
        })
        index_aliases = self.resolve_index_aliases(
            request_index_aliases=request.index_aliases,
            processed_query=preprocessed_query,
        )

        page_size = request.page_size or 5
        left_offset = request.page * page_size
        right_offset = left_offset + page_size

        search_requests = []
        processed_queries = {}
        for index_alias in index_aliases:
            processed_query = self.query_transformers[index_alias].apply_tree_transformers(preprocessed_query)
            processed_queries[index_alias] = processed_query
            logging.getLogger('debug').info({
                'action': 'process_query',
                'index_alias': index_alias,
                'processed_query': str(processed_query),
                'order_by': processed_query.context.order_by,
                'has_invalid_fields': processed_query.context.has_invalid_fields,
            })
            search_requests.append(
                SearchRequest(
                    index_alias=index_alias,
                    query=processed_query.to_summa_query(),
                    collectors=[
                        search_service_pb2.Collector(
                            top_docs=search_service_pb2.TopDocsCollector(
                                limit=50,
                                scorer=self.scorer(processed_query, index_alias),
                                snippets=self.snippets[index_alias],
                                explain=processed_query.context.explain,
                            )
                        ),
                        search_service_pb2.Collector(count=search_service_pb2.CountCollector())
                    ],
                )
            )

        # RetryError only means the attempts ran out while DOIs were still
        # missing; the results of the last attempt are already bound to the
        # names below, so they are returned as they are.
        with suppress(RetryError):
            async for attempt in AsyncRetrying(
                retry=retry_if_exception_type(NeedRetryError),
                wait=wait_fixed(10),
                stop=stop_after_attempt(6)
            ):
                with attempt:
                    meta_search_response = await self.base_search(
                        search_requests=search_requests,
                        collector_descriptors=['top_docs', 'count'],
                        request_id=metadata['request-id'],
                        session_id=metadata['session-id'],
                        user_id=metadata.get('user-id'),
                        # Retries must not be answered from the cache.
                        skip_cache_loading=attempt.retry_state.attempt_number > 1 or metadata.get('skip-cache-loading'),
                        skip_cache_saving=metadata.get('skip-cache-saving'),
                        original_query=request.query,
                        query_tags=list(request.query_tags),
                    )
                    # cast_top_docs_collector is not defined in this class;
                    # presumably inherited from BaseService.
                    new_scored_documents = self.cast_top_docs_collector(
                        meta_search_response.collector_outputs[0].top_docs.scored_documents,
                    )
                    has_next = len(new_scored_documents) > right_offset
                    if 'scimag' in processed_queries:
                        await self.check_if_need_new_documents_by_dois(
                            requested_dois=processed_queries['scimag'].context.dois,
                            scored_documents=new_scored_documents,
                            should_request=attempt.retry_state.attempt_number == 1
                        )

        search_response_pb = meta_search_service_pb2.SearchResponse(
            scored_documents=new_scored_documents[left_offset:right_offset],
            has_next=has_next,
            count=meta_search_response.collector_outputs[1].count.count,
            query_language=preprocessed_query.context.query_language,
        )

        return search_response_pb

    async def request_doi_delivery(self, doi):
        """Log an ``UpdateDocument`` operation for a scimag document with the given DOI.

        The operation asks to fill the document from an external source and to
        index and commit its full text. It is only written to
        ``self.operation_logger`` here (not defined in this class; presumably
        provided by BaseService).
        """
        document_operation = operation_pb2.DocumentOperation(
            update_document=operation_pb2.UpdateDocument(
                should_fill_from_external_source=True,
                full_text_index=True,
                full_text_index_commit=True,
                typed_document=typed_document_pb2.TypedDocument(scimag=scimag_pb2.Scimag(doi=doi)),
            ),
        )
        self.operation_logger.info(MessageToDict(document_operation, preserving_proto_field_name=True))
|