diff --git a/trace_replay/trace_replay.cc b/trace_replay/trace_replay.cc index 5fd529568..6171d91ec 100644 --- a/trace_replay/trace_replay.cc +++ b/trace_replay/trace_replay.cc @@ -607,6 +607,25 @@ Status Replayer::Replay() { single_iter->SeekForPrev(iter_payload.iter_key); ops++; delete single_iter; + } else if (trace.type == kTraceMultiGet) { + MultiGetPayload multiget_payload; + assert(trace_file_version_ >= 2); + TracerHelper::DecodeMultiGetPayload(&trace, &multiget_payload); + std::vector v_cfd; + std::vector keys; + assert(multiget_payload.cf_ids.size() == + multiget_payload.multiget_keys.size()); + for (size_t i = 0; i < multiget_payload.cf_ids.size(); i++) { + assert(i < multiget_payload.cf_ids.size() && + i < multiget_payload.multiget_keys.size()); + if (cf_map_.find(multiget_payload.cf_ids[i]) == cf_map_.end()) { + return Status::Corruption("Invalid Column Family ID."); + } + v_cfd.push_back(cf_map_[multiget_payload.cf_ids[i]]); + keys.push_back(Slice(multiget_payload.multiget_keys[i])); + } + std::vector values; + std::vector ss = db_->MultiGet(roptions, v_cfd, keys, &values); } else if (trace.type == kTraceEnd) { // Do nothing for now. // TODO: Add some validations later. @@ -685,6 +704,10 @@ Status Replayer::MultiThreadReplay(uint32_t threads_num) { thread_pool.Schedule(&Replayer::BGWorkIterSeekForPrev, ra.release(), nullptr, nullptr); ops++; + } else if (ra->trace_entry.type == kTraceMultiGet) { + thread_pool.Schedule(&Replayer::BGWorkMultiGet, ra.release(), nullptr, + nullptr); + ops++; } else if (ra->trace_entry.type == kTraceEnd) { // Do nothing for now. // TODO: Add some validations later. @@ -861,4 +884,32 @@ void Replayer::BGWorkIterSeekForPrev(void* arg) { return; } +void Replayer::BGWorkMultiGet(void* arg) { + std::unique_ptr ra( + reinterpret_cast(arg)); + assert(ra != nullptr); + auto cf_map = static_cast*>( + ra->cf_map); + MultiGetPayload multiget_payload; + if (ra->trace_file_version < 2) { + return; + } + TracerHelper::DecodeMultiGetPayload(&(ra->trace_entry), &multiget_payload); + std::vector v_cfd; + std::vector keys; + if (multiget_payload.cf_ids.size() != multiget_payload.multiget_keys.size()) { + return; + } + for (size_t i = 0; i < multiget_payload.cf_ids.size(); i++) { + if (cf_map->find(multiget_payload.cf_ids[i]) == cf_map->end()) { + return; + } + v_cfd.push_back((*cf_map)[multiget_payload.cf_ids[i]]); + keys.push_back(Slice(multiget_payload.multiget_keys[i])); + } + std::vector values; + std::vector ss = ra->db->MultiGet(ra->roptions, v_cfd, keys, &values); + return; +} + } // namespace ROCKSDB_NAMESPACE diff --git a/trace_replay/trace_replay.h b/trace_replay/trace_replay.h index d3ad2d799..d10bc1b46 100644 --- a/trace_replay/trace_replay.h +++ b/trace_replay/trace_replay.h @@ -268,6 +268,10 @@ class Replayer { // (SeekForPrev) based on the trace records. static void BGWorkIterSeekForPrev(void* arg); + // The background function for MultiThreadReplay to execute MultiGet based on + // the trace records + static void BGWorkMultiGet(void* arg); + DBImpl* db_; Env* env_; std::unique_ptr trace_reader_;