68 lines
2.3 KiB
Python
Raw Normal View History

import heapq
from typing import List
from summa.proto import search_service_pb2
class TopDocsIterator:
    """Forward cursor over the scored documents of one top-docs collector output.

    Instances compare by the score of the document currently under the
    cursor, with the *higher* score ordering first. Because ``heapq`` is a
    min-heap, this makes the iterator holding the best-scoring document come
    out of the heap first.
    """

    def __init__(self, top_docs_collector: search_service_pb2.TopDocsCollectorOutput):
        self.top_docs_collector = top_docs_collector
        # Index into ``scored_documents`` of the next document to hand out.
        self._current = 0

    def __lt__(self, other: 'TopDocsIterator'):
        """Return True when this iterator's current document outscores *other*'s.

        Both iterators must still have a current document; ``current()``
        indexes ``scored_documents`` directly and does no bounds check.
        """
        left = self.current().score
        right = other.current().score
        # ``score`` is a oneof: read whichever variant is populated on each side.
        left_value = getattr(left, left.WhichOneof('score'))
        right_value = getattr(right, right.WhichOneof('score'))
        # Inverted on purpose so that the highest score sorts first.
        return left_value > right_value

    def __iter__(self):
        return self

    def __next__(self) -> search_service_pb2.ScoredDocument:
        """Return the current document and advance the cursor by one."""
        if not self.has_any():
            raise StopIteration
        document = self.current()
        self._current += 1
        return document

    def current(self):
        """Return the document under the cursor without advancing."""
        documents = self.top_docs_collector.scored_documents
        return documents[self._current]

    def has_next(self) -> bool:
        """Return the ``has_next`` flag carried by the wrapped collector output."""
        return self.top_docs_collector.has_next

    def has_any(self) -> bool:
        """Return True while the cursor has not run past the last document."""
        return len(self.top_docs_collector.scored_documents) > self._current
class TopDocsMerger:
    """Merge several top-docs collector outputs into one list ordered by score.

    Each input is wrapped in a ``TopDocsIterator`` and kept in a heap keyed on
    the score of its current document, so the best remaining document across
    all inputs is always at the top of the heap.
    """

    def __init__(self, top_docs_collectors: List[search_service_pb2.TopDocsCollectorOutput]):
        """Build the heap from every collector output that has documents.

        Args:
            top_docs_collectors: collector outputs to merge. Outputs without
                scored documents are skipped, because heap ordering reads the
                document under each iterator's cursor.
        """
        self.top_docs_heap = [
            top_docs_iterator
            for top_docs_iterator in map(TopDocsIterator, top_docs_collectors)
            if top_docs_iterator.has_any()
        ]
        heapq.heapify(self.top_docs_heap)

    def merge(self) -> search_service_pb2.CollectorOutput:
        """Drain the heap and return the merged documents in score order.

        Each emitted document has its ``position`` overwritten in place with
        its index in the merged list. The heap is consumed, so a second call
        returns an empty result.

        Returns:
            A ``CollectorOutput`` whose ``top_docs`` holds the merged
            documents, with ``has_next`` set when at least one merged
            collector reported ``has_next``.
        """
        # Ask each iterator for its collector's flag. Testing the iterator
        # objects themselves would always be truthy and would only report
        # whether the heap is non-empty. Read before the heap is drained.
        has_next = any(
            top_docs_iterator.has_next() for top_docs_iterator in self.top_docs_heap
        )

        scored_documents = []
        while self.top_docs_heap:
            largest_top_docs_iterator = heapq.heappop(self.top_docs_heap)
            largest_item = next(largest_top_docs_iterator)
            # The position is the index this document takes in the merged list.
            largest_item.position = len(scored_documents)
            scored_documents.append(largest_item)
            # Re-queue the iterator only while it still has a document to compare on.
            if largest_top_docs_iterator.has_any():
                heapq.heappush(self.top_docs_heap, largest_top_docs_iterator)

        return search_service_pb2.CollectorOutput(
            top_docs=search_service_pb2.TopDocsCollectorOutput(
                has_next=has_next,
                scored_documents=scored_documents,
            )
        )