diff --git a/java/RocksDBSample.java b/java/RocksDBSample.java index 2005b15d9..2e27e9377 100644 --- a/java/RocksDBSample.java +++ b/java/RocksDBSample.java @@ -4,6 +4,9 @@ // of patent rights can be found in the PATENTS file in the same directory. import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.ArrayList; import org.rocksdb.*; import org.rocksdb.util.SizeUnit; import java.io.IOException; @@ -217,6 +220,25 @@ public class RocksDBSample { iterator.close(); System.out.println("iterator tests passed."); + + iterator = db.newIterator(); + List keys = new ArrayList(); + for (iterator.seekToLast(); iterator.isValid(); iterator.prev()) { + keys.add(iterator.key()); + } + iterator.close(); + + Map values = db.multiGet(keys); + assert(values.size() == keys.size()); + for(byte[] value1 : values.values()) { + assert(value1 != null); + } + + values = db.multiGet(new ReadOptions(), keys); + assert(values.size() == keys.size()); + for(byte[] value1 : values.values()) { + assert(value1 != null); + } } catch (RocksDBException e) { System.err.println(e); } diff --git a/java/org/rocksdb/RocksDB.java b/java/org/rocksdb/RocksDB.java index 0574851fa..9ff45707f 100644 --- a/java/org/rocksdb/RocksDB.java +++ b/java/org/rocksdb/RocksDB.java @@ -5,6 +5,9 @@ package org.rocksdb; +import java.util.List; +import java.util.Map; +import java.util.HashMap; import java.io.Closeable; import java.io.IOException; @@ -147,6 +150,7 @@ public class RocksDB { * returned if the specified key is not found. * * @param key the key retrieve the value. + * @param opt Read options. * @return a byte array storing the value associated with the input key if * any. null if it does not find the specified key. * @@ -156,6 +160,64 @@ public class RocksDB { return get(nativeHandle_, opt.nativeHandle_, key, key.length); } + /** + * Returns a map of keys for which values were found in DB. + * + * @param keys List of keys for which values need to be retrieved. + * @return Map where key of map is the key passed by user and value for map + * entry is the corresponding value in DB. + * + * @see RocksDBException + */ + public Map multiGet(List keys) + throws RocksDBException { + assert(keys.size() != 0); + + List values = multiGet( + nativeHandle_, keys, keys.size()); + + Map keyValueMap = new HashMap(); + for(int i = 0; i < values.size(); i++) { + if(values.get(i) == null) { + continue; + } + + keyValueMap.put(keys.get(i), values.get(i)); + } + + return keyValueMap; + } + + + /** + * Returns a map of keys for which values were found in DB. + * + * @param List of keys for which values need to be retrieved. + * @param opt Read options. + * @return Map where key of map is the key passed by user and value for map + * entry is the corresponding value in DB. + * + * @see RocksDBException + */ + public Map multiGet(ReadOptions opt, List keys) + throws RocksDBException { + assert(keys.size() != 0); + + List values = multiGet( + nativeHandle_, opt.nativeHandle_, keys, keys.size()); + + Map keyValueMap = new HashMap(); + for(int i = 0; i < values.size(); i++) { + if(values.get(i) == null) { + continue; + } + + keyValueMap.put(keys.get(i), values.get(i)); + } + + return keyValueMap; + } + /** * Remove the database entry (if any) for "key". Returns OK on * success, and a non-OK status on error. It is not an error if "key" @@ -229,6 +291,10 @@ public class RocksDB { protected native int get( long handle, long readOptHandle, byte[] key, int keyLen, byte[] value, int valueLen) throws RocksDBException; + protected native List multiGet( + long dbHandle, List keys, int keysCount); + protected native List multiGet( + long dbHandle, long rOptHandle, List keys, int keysCount); protected native byte[] get( long handle, byte[] key, int keyLen) throws RocksDBException; protected native byte[] get( diff --git a/java/rocksjni/portal.h b/java/rocksjni/portal.h index 4c4444329..7d70eecae 100644 --- a/java/rocksjni/portal.h +++ b/java/rocksjni/portal.h @@ -315,5 +315,69 @@ class FilterJni { reinterpret_cast(op)); } }; + +class ListJni { + public: + // Get the java class id of java.util.List. + static jclass getListClass(JNIEnv* env) { + static jclass jclazz = env->FindClass("java/util/List"); + assert(jclazz != nullptr); + return jclazz; + } + + // Get the java class id of java.util.ArrayList. + static jclass getArrayListClass(JNIEnv* env) { + static jclass jclazz = env->FindClass("java/util/ArrayList"); + assert(jclazz != nullptr); + return jclazz; + } + + // Get the java class id of java.util.Iterator. + static jclass getIteratorClass(JNIEnv* env) { + static jclass jclazz = env->FindClass("java/util/Iterator"); + assert(jclazz != nullptr); + return jclazz; + } + + // Get the java method id of java.util.List.iterator(). + static jmethodID getIteratorMethod(JNIEnv* env) { + static jmethodID mid = env->GetMethodID( + getListClass(env), "iterator", "()Ljava/util/Iterator;"); + assert(mid != nullptr); + return mid; + } + + // Get the java method id of java.util.Iterator.hasNext(). + static jmethodID getHasNextMethod(JNIEnv* env) { + static jmethodID mid = env->GetMethodID( + getIteratorClass(env), "hasNext", "()Z"); + assert(mid != nullptr); + return mid; + } + + // Get the java method id of java.util.Iterator.next(). + static jmethodID getNextMethod(JNIEnv* env) { + static jmethodID mid = env->GetMethodID( + getIteratorClass(env), "next", "()Ljava/lang/Object;"); + assert(mid != nullptr); + return mid; + } + + // Get the java method id of arrayList constructor. + static jmethodID getArrayListConstructorMethodId(JNIEnv* env, jclass jclazz) { + static jmethodID mid = env->GetMethodID( + jclazz, "", "(I)V"); + assert(mid != nullptr); + return mid; + } + + // Get the java method id of java.util.List.add(). + static jmethodID getListAddMethodId(JNIEnv* env) { + static jmethodID mid = env->GetMethodID( + getListClass(env), "add", "(Ljava/lang/Object;)Z"); + assert(mid != nullptr); + return mid; + } +}; } // namespace rocksdb #endif // JAVA_ROCKSJNI_PORTAL_H_ diff --git a/java/rocksjni/rocksjni.cc b/java/rocksjni/rocksjni.cc index 17c7b8b10..94c41392d 100644 --- a/java/rocksjni/rocksjni.cc +++ b/java/rocksjni/rocksjni.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include "include/org_rocksdb_RocksDB.h" #include "rocksjni/portal.h" @@ -244,6 +245,91 @@ jint rocksdb_get_helper( return cvalue_len; } +jobject multi_get_helper(JNIEnv* env, jobject jdb, rocksdb::DB* db, + const rocksdb::ReadOptions& rOpt, jobject jkey_list, jint jkeys_count) { + std::vector keys; + std::vector keys_to_free; + + // get iterator + jobject iteratorObj = env->CallObjectMethod( + jkey_list, rocksdb::ListJni::getIteratorMethod(env)); + + // iterate over keys and convert java byte array to slice + while(env->CallBooleanMethod( + iteratorObj, rocksdb::ListJni::getHasNextMethod(env)) == JNI_TRUE) { + jbyteArray jkey = (jbyteArray) env->CallObjectMethod( + iteratorObj, rocksdb::ListJni::getNextMethod(env)); + jint key_length = env->GetArrayLength(jkey); + + jbyte* key = new jbyte[key_length]; + env->GetByteArrayRegion(jkey, 0, key_length, key); + // store allocated jbyte to free it after multiGet call + keys_to_free.push_back(key); + + rocksdb::Slice key_slice( + reinterpret_cast(key), key_length); + keys.push_back(key_slice); + } + + std::vector values; + std::vector s = db->MultiGet(rOpt, keys, &values); + + // Don't reuse class pointer + jclass jclazz = env->FindClass("java/util/ArrayList"); + jmethodID mid = rocksdb::ListJni::getArrayListConstructorMethodId( + env, jclazz); + jobject jvalue_list = env->NewObject(jclazz, mid, jkeys_count); + + // insert in java list + for(std::vector::size_type i = 0; i != s.size(); i++) { + if(s[i].ok()) { + jbyteArray jvalue = env->NewByteArray(values[i].size()); + env->SetByteArrayRegion( + jvalue, 0, values[i].size(), + reinterpret_cast(values[i].c_str())); + env->CallBooleanMethod( + jvalue_list, rocksdb::ListJni::getListAddMethodId(env), jvalue); + } + else { + env->CallBooleanMethod( + jvalue_list, rocksdb::ListJni::getListAddMethodId(env), nullptr); + } + } + + // free up allocated byte arrays + for(std::vector::size_type i = 0; i != keys_to_free.size(); i++) { + delete[] keys_to_free[i]; + } + keys_to_free.clear(); + + return jvalue_list; +} + +/* + * Class: org_rocksdb_RocksDB + * Method: multiGet + * Signature: (JLjava/util/List;I)Ljava/util/List; + */ +jobject Java_org_rocksdb_RocksDB_multiGet__JLjava_util_List_2I( + JNIEnv* env, jobject jdb, jlong jdb_handle, + jobject jkey_list, jint jkeys_count) { + return multi_get_helper(env, jdb, reinterpret_cast(jdb_handle), + rocksdb::ReadOptions(), jkey_list, jkeys_count); +} + +/* + * Class: org_rocksdb_RocksDB + * Method: multiGet + * Signature: (JJLjava/util/List;I)Ljava/util/List; + */ +jobject Java_org_rocksdb_RocksDB_multiGet__JJLjava_util_List_2I( + JNIEnv* env, jobject jdb, jlong jdb_handle, + jlong jropt_handle, jobject jkey_list, jint jkeys_count) { + return multi_get_helper(env, jdb, reinterpret_cast(jdb_handle), + *reinterpret_cast(jropt_handle), jkey_list, + jkeys_count); +} + /* * Class: org_rocksdb_RocksDB * Method: get