70c7302fb5
Summary: This PR implements cache eviction using reinforcement learning. It includes two implementations: 1. An implementation of Thompson Sampling for the Bernoulli Bandit [1]. 2. An implementation of LinUCB with disjoint linear models [2]. The idea is that a cache uses multiple eviction policies, e.g., MRU, LRU, and LFU. The cache learns which eviction policy is the best and uses it upon a cache miss. Thompson Sampling is contextless and does not include any features. LinUCB includes features such as level, block type, caller, column family id to decide which eviction policy to use. [1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband, and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found. Trends Mach. Learn. 11, 1 (July 2018), 1-96. DOI: https://doi.org/10.1561/2200000070 [2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010. A contextual-bandit approach to personalized news article recommendation. In Proceedings of the 19th international conference on World wide web (WWW '10). ACM, New York, NY, USA, 661-670. DOI=http://dx.doi.org/10.1145/1772690.1772758 Pull Request resolved: https://github.com/facebook/rocksdb/pull/5610 Differential Revision: D16435067 Pulled By: HaoyuHuang fbshipit-source-id: 6549239ae14115c01cb1e70548af9e46d8dc21bb
865 lines
31 KiB
Python
865 lines
31 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
|
|
import gc
|
|
import random
|
|
import sys
|
|
import time
|
|
from os import path
|
|
|
|
import numpy as np
|
|
|
|
|
|
# Number of hash entries drawn from the cache table for each eviction round.
kSampleSize = 16

# Time conversions; trace timestamps are compared against microsecond values.
kMicrosInSecond = 1000 * 1000
kSecondsInMinute = 60
kSecondsInHour = 60 * kSecondsInMinute
|
|
|
|
|
|
class TraceRecord:
    """
    A single block access taken from a block cache trace.

    The fields follow the layout of BlockCacheTraceRecord in
    trace_replay/block_cache_tracer.h. The 'no_insert' and 'is_hit' inputs
    are stored as booleans that are True only when the input equals 1.
    """

    def __init__(
        self,
        access_time,
        block_id,
        block_type,
        block_size,
        cf_id,
        cf_name,
        level,
        fd,
        caller,
        no_insert,
        get_id,
        key_id,
        kv_size,
        is_hit,
    ):
        # Fields copied through unchanged.
        self.access_time = access_time
        self.block_id = block_id
        self.block_type = block_type
        self.block_size = block_size
        self.cf_id = cf_id
        self.cf_name = cf_name
        self.level = level
        self.fd = fd
        self.caller = caller
        # Flag fields: any value other than 1 means False.
        self.no_insert = bool(no_insert == 1)
        self.get_id = get_id
        self.key_id = key_id
        self.kv_size = kv_size
        self.is_hit = bool(is_hit == 1)
|
|
|
|
|
|
class CacheEntry:
    """
    A cache entry stored in the cache.

    Besides its size, an entry keeps the access number of its last access,
    a hit counter, and the column family id, level and block type it was
    created with.
    """

    def __init__(self, value_size, cf_id, level, block_type, access_number):
        self.value_size = value_size
        self.last_access_number = access_number
        self.num_hits = 0
        # Keep the caller-supplied column family id (it used to be
        # overwritten with a constant 0).
        self.cf_id = cf_id
        self.level = level
        self.block_type = block_type

    def __repr__(self):
        """Debug string."""
        return "s={},last={},hits={},cf={},l={},bt={}".format(
            self.value_size,
            self.last_access_number,
            self.num_hits,
            self.cf_id,
            self.level,
            self.block_type,
        )
|
|
|
|
|
|
class HashEntry:
    """A (key, hash, value) triple kept in a hash table bucket."""

    def __init__(self, key, hash, value):
        self.key, self.hash, self.value = key, hash, value

    def __repr__(self):
        """Debug string showing key, hash and the wrapped value."""
        fields = (self.key, self.hash, self.value)
        return "k={},h={},v=[{}]".format(*fields)
|
|
|
|
|
|
class HashTable:
    """
    A custom implementation of hash table to support fast random sampling.

    Hash conflicts are resolved by chaining: every bucket of 'self.table' is
    either None or a list of HashEntry objects. Deleting an entry leaves a
    None slot in its bucket, which a later insert may reuse.
    The table grows/shrinks upon insertion/deletion to support fast lookups
    and random samplings.
    """

    def __init__(self):
        self.table = [None] * 32
        self.elements = 0

    def random_sample(self, sample_size):
        """
        Randomly sample at most 'sample_size' hash entries from the table.

        A start bucket is chosen uniformly at random and buckets are scanned
        in order, wrapping around, until 'sample_size' entries are collected
        or every bucket has been visited exactly once.
        """
        samples = []
        if sample_size <= 0:
            return samples
        num_buckets = len(self.table)
        # randrange excludes the upper bound, so 'start' is always a valid
        # bucket index (randint(0, len) could return len itself).
        start = random.randrange(num_buckets)
        for offset in range(num_buckets):
            bucket = self.table[(start + offset) % num_buckets]
            if bucket is None:
                continue
            for entry in bucket:
                if entry is None:
                    continue
                samples.append(entry)
                if len(samples) >= sample_size:
                    return samples
        return samples

    def insert(self, key, hash, value):
        """
        Insert a hash entry in the table. Replace the old entry if it already
        exists.
        """
        self.grow()
        index = hash % len(self.table)
        if self.table[index] is None:
            self.table[index] = []
        bucket = self.table[index]
        free_slot = None
        # Scan the whole bucket first: the key may live behind a freed slot,
        # and filling that slot right away would store the key twice.
        for i, entry in enumerate(bucket):
            if entry is None:
                if free_slot is None:
                    free_slot = i
            elif entry.hash == hash and entry.key == key:
                # The entry already exists in the table.
                bucket[i] = HashEntry(key, hash, value)
                return
        new_entry = HashEntry(key, hash, value)
        if free_slot is None:
            bucket.append(new_entry)
        else:
            bucket[free_slot] = new_entry
        self.elements += 1

    def resize(self, new_size):
        """
        Rehash all entries into a table with 'new_size' buckets.

        Nothing happens when the size is unchanged or zero, or while the
        table holds fewer than 100 elements.
        """
        if new_size == len(self.table) or new_size == 0:
            return
        if self.elements < 100:
            return
        new_table = [None] * new_size
        # Copy 'self.table' to new_table, dropping freed (None) slots.
        for entries in self.table:
            if entries is None:
                continue
            for entry in entries:
                if entry is None:
                    continue
                index = entry.hash % new_size
                if new_table[index] is None:
                    new_table[index] = []
                new_table[index].append(entry)
        self.table = new_table
        del new_table
        # Manually call python gc here to free the memory as 'self.table'
        # might be very large.
        gc.collect()

    def grow(self):
        """Grow the table by 20% once it holds one element per bucket."""
        if self.elements < len(self.table):
            return
        self.resize(int(len(self.table) * 1.2))

    def delete(self, key, hash):
        """Remove the entry for (key, hash) if present, then try to shrink."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return
        for i, entry in enumerate(entries):
            if entry is not None and entry.hash == hash and entry.key == key:
                # Leave a None slot behind; insert() may reuse it.
                entries[i] = None
                self.elements -= 1
                self.shrink()
                return

    def shrink(self):
        """Shrink the table to 70% when it is less than half full."""
        if self.elements * 2 >= len(self.table):
            return
        self.resize(int(len(self.table) * 0.7))

    def lookup(self, key, hash):
        """Return the value stored for (key, hash), or None if absent."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return None
        for entry in entries:
            if entry is not None and entry.hash == hash and entry.key == key:
                return entry.value
        return None
|
|
|
|
|
|
class MissRatioStats:
    """
    Counts cache accesses and misses, in total and per time bucket.

    Timestamps are divided by kMicrosInSecond * time_unit to obtain an
    integer bucket index, so each bucket spans 'time_unit' seconds of
    microsecond timestamps.
    """

    def __init__(self, time_unit):
        self.num_misses = 0
        self.num_accesses = 0
        self.time_unit = time_unit
        # Bucket index -> number of misses / accesses recorded in the bucket.
        self.time_misses = {}
        self.time_accesses = {}

    def _bucket(self, timestamp):
        """Return the integer time-bucket index of 'timestamp'."""
        # Floor division keeps bucket indices integral on Python 3 as well;
        # true division would yield a distinct float key per access and make
        # range(start, end) fail.
        return int(timestamp // (kMicrosInSecond * self.time_unit))

    def _write_header(self, header_file_path, start, end):
        """Write the 'time,<bucket>,...' header unless the file exists."""
        if path.exists(header_file_path):
            return
        columns = ["time"] + [str(bucket) for bucket in range(start, end)]
        with open(header_file_path, "w+") as header_file:
            header_file.write(",".join(columns) + "\n")

    def update_metrics(self, access_time, is_hit):
        """Record one access at 'access_time' and whether it was a hit."""
        bucket = self._bucket(access_time)
        self.num_accesses += 1
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        if not is_hit:
            self.num_misses += 1
            self.time_misses[bucket] = self.time_misses.get(bucket, 0) + 1

    def reset_counter(self):
        """Reset the total counters; per-bucket counters are kept."""
        self.num_misses = 0
        self.num_accesses = 0

    def miss_ratio(self):
        """Return the miss ratio in percent, or 0.0 with no accesses."""
        if self.num_accesses == 0:
            return 0.0
        return float(self.num_misses) * 100.0 / float(self.num_accesses)

    def write_miss_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write the number of misses per time bucket in [start, end) as one
        CSV row to 'result_dir', plus a header file if it does not exist.
        """
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-miss-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        row = ["{}".format(cache_type)]
        row.extend(
            str(self.time_misses.get(bucket, 0)) for bucket in range(start, end)
        )
        with open(file_path, "w+") as data_file:
            data_file.write(",".join(row) + "\n")

    def write_miss_ratio_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write the miss ratio (percent, two decimals) per time bucket in
        [start, end) as one CSV row to 'result_dir', plus a header file if
        it does not exist. Buckets without accesses report 0.00.
        """
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-miss-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        row = ["{}".format(cache_type)]
        for bucket in range(start, end):
            naccesses = self.time_accesses.get(bucket, 0)
            miss_ratio = 0
            if naccesses > 0:
                miss_ratio = float(
                    self.time_misses.get(bucket, 0) * 100.0
                ) / float(naccesses)
            row.append("{0:.2f}".format(miss_ratio))
        with open(file_path, "w+") as data_file:
            data_file.write(",".join(row) + "\n")
|
|
|
|
|
|
class PolicyStats:
    """
    Counts, per time bucket, how often each eviction policy was selected.

    Timestamps are divided by kMicrosInSecond * time_unit to obtain an
    integer bucket index. Policies are identified by their index in the
    'policies' list and reported by their policy_name().
    """

    def __init__(self, time_unit, policies):
        # Bucket index -> {policy name -> number of selections}.
        self.time_selected_polices = {}
        # Bucket index -> total number of selections.
        self.time_accesses = {}
        # Policy index -> policy name.
        self.policy_names = {}
        self.time_unit = time_unit
        for i, policy in enumerate(policies):
            self.policy_names[i] = policy.policy_name()

    def _bucket(self, timestamp):
        """Return the integer time-bucket index of 'timestamp'."""
        # Floor division keeps bucket indices integral on Python 3 as well;
        # true division would yield float keys and break range(start, end).
        return int(timestamp // (kMicrosInSecond * self.time_unit))

    def _write_header(self, header_file_path, start, end):
        """Write the 'time,<bucket>,...' header unless the file exists."""
        if path.exists(header_file_path):
            return
        columns = ["time"] + [str(bucket) for bucket in range(start, end)]
        with open(header_file_path, "w+") as header_file:
            header_file.write(",".join(columns) + "\n")

    def update_metrics(self, access_time, selected_policy):
        """Record that policy index 'selected_policy' was chosen."""
        bucket = self._bucket(access_time)
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        per_policy = self.time_selected_polices.setdefault(bucket, {})
        policy_name = self.policy_names[selected_policy]
        per_policy[policy_name] = per_policy.get(policy_name, 0) + 1

    def write_policy_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write one CSV row per policy with its selection count for every
        time bucket in [start, end), plus a header file if it does not
        exist.
        """
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-policy-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-policy-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as data_file:
            for policy in self.policy_names:
                policy_name = self.policy_names[policy]
                row = ["{}-{}".format(cache_type, policy_name)]
                row.extend(
                    str(self.time_selected_polices.get(bucket, {}).get(policy_name, 0))
                    for bucket in range(start, end)
                )
                data_file.write(",".join(row) + "\n")

    def write_policy_ratio_timeline(
        self, cache_type, cache_size, file_path, start, end
    ):
        """
        Write one CSV row per policy with its share of selections (percent,
        two decimals) for every time bucket in [start, end), plus a header
        file if it does not exist.

        Despite its name, 'file_path' is the result directory; the parameter
        name is kept so existing callers continue to work.
        """
        # The body used to read an undefined local 'result_dir', which only
        # resolved when a module-level global of that name happened to exist.
        result_dir = file_path
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-policy-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        data_file_path = "{}/data-ml-policy-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(data_file_path, "w+") as data_file:
            for policy in self.policy_names:
                policy_name = self.policy_names[policy]
                row = ["{}-{}".format(cache_type, policy_name)]
                for bucket in range(start, end):
                    naccesses = self.time_accesses.get(bucket, 0)
                    ratio = 0
                    if naccesses > 0:
                        ratio = float(
                            self.time_selected_polices.get(bucket, {}).get(
                                policy_name, 0
                            )
                            * 100.0
                        ) / float(naccesses)
                    row.append("{0:.2f}".format(ratio))
                data_file.write(",".join(row) + "\n")
|
|
|
|
|
|
class Policy(object):
    """
    Base class of an eviction policy.

    A policy remembers the keys it has evicted. Its reward for a key is 0
    when that key is among the remembered evictions and 1 otherwise.
    Subclasses provide prioritize_samples() and policy_name().
    """

    def __init__(self):
        # Evicted key -> 0; only membership is consulted.
        self.evicted_keys = dict()

    def evict(self, key, max_size):
        """Remember 'key' as evicted. 'max_size' is not used here."""
        self.evicted_keys[key] = 0

    def delete(self, key):
        """Forget 'key'; a key that was never evicted is ignored."""
        if key in self.evicted_keys:
            del self.evicted_keys[key]

    def prioritize_samples(self, samples):
        """Order 'samples' for eviction; must be provided by subclasses."""
        raise NotImplementedError

    def policy_name(self):
        """Name of the policy; must be provided by subclasses."""
        raise NotImplementedError

    def generate_reward(self, key):
        """Return 0 if this policy evicted 'key', else 1."""
        return 0 if key in self.evicted_keys else 1
|
|
|
|
|
|
class LRUPolicy(Policy):
    """Least-recently-used: smallest last access number is evicted first."""

    def prioritize_samples(self, samples):
        """Return 'samples' sorted by ascending last access number."""
        # sorted() no longer accepts cmp= on Python 3; a key function gives
        # the same (stable) ordering.
        return sorted(samples, key=lambda e: e.value.last_access_number)

    def policy_name(self):
        return "lru"
|
|
|
|
|
|
class MRUPolicy(Policy):
    """Most-recently-used: largest last access number is evicted first."""

    def prioritize_samples(self, samples):
        """Return 'samples' sorted by descending last access number."""
        # sorted() no longer accepts cmp= on Python 3. reverse=True keeps
        # the sort stable, so ties stay in their original order as before.
        return sorted(
            samples, key=lambda e: e.value.last_access_number, reverse=True
        )

    def policy_name(self):
        return "mru"
|
|
|
|
|
|
class LFUPolicy(Policy):
    """Least-frequently-used: fewest hits is evicted first."""

    def prioritize_samples(self, samples):
        """Return 'samples' sorted by ascending number of hits."""
        # sorted() no longer accepts cmp= on Python 3; a key function gives
        # the same (stable) ordering.
        return sorted(samples, key=lambda e: e.value.num_hits)

    def policy_name(self):
        return "lfu"
|
|
|
|
|
|
class MLCache(object):
    """
    Base class of a simulated cache that chooses among several eviction
    policies on every insertion.

    Subclasses implement _select_policy() to pick a policy index and
    cache_name() to label the cache. When 'enable_cache_row_key' is set,
    get requests (caller == 1 with non-zero key_id and get_id) also cache
    the row key-value pair alongside the accessed blocks.
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        self.cache_size = cache_size
        self.used_size = 0
        # Per-minute and per-hour statistics.
        self.miss_ratio_stats = MissRatioStats(kSecondsInMinute)
        self.policy_stats = PolicyStats(kSecondsInMinute, policies)
        self.per_hour_miss_ratio_stats = MissRatioStats(kSecondsInHour)
        self.per_hour_policy_stats = PolicyStats(kSecondsInHour, policies)
        self.table = HashTable()
        self.enable_cache_row_key = enable_cache_row_key
        # get_id -> {"h": row hit seen, key_id: row inserted}.
        self.get_id_row_key_map = {}
        self.policies = policies

    def _lookup(self, key, hash):
        """Return True on a hit, refreshing the entry's access metadata."""
        entry = self.table.lookup(key, hash)
        if entry is None:
            return False
        entry.last_access_number = self.miss_ratio_stats.num_accesses
        entry.num_hits += 1
        return True

    def _select_policy(self, trace_record, key):
        """Return the index of the policy to use; subclass hook."""
        raise NotImplementedError

    def cache_name(self):
        """Return a display name of the cache; subclass hook."""
        raise NotImplementedError

    def _evict(self, policy_index, value_size):
        """
        Evict sampled entries, in the order chosen by the given policy,
        until 'value_size' more bytes fit or the sample is exhausted.
        """
        policy = self.policies[policy_index]
        # Randomly sample n entries and let the policy rank them.
        candidates = policy.prioritize_samples(
            self.table.random_sample(kSampleSize)
        )
        for victim in candidates:
            self.used_size -= victim.value.value_size
            self.table.delete(victim.key, victim.hash)
            policy.evict(key=victim.key, max_size=self.table.elements)
            if self.used_size + value_size <= self.cache_size:
                return

    def _insert(self, trace_record, key, hash, value_size):
        """Insert a new entry, evicting with the selected policy if needed."""
        if value_size > self.cache_size:
            # The value can never fit.
            return
        policy_index = self._select_policy(trace_record, key)
        self.policies[policy_index].delete(key)
        access_time = trace_record.access_time
        self.policy_stats.update_metrics(access_time, policy_index)
        self.per_hour_policy_stats.update_metrics(access_time, policy_index)
        while self.used_size + value_size > self.cache_size:
            self._evict(policy_index, value_size)
        new_entry = CacheEntry(
            value_size,
            trace_record.cf_id,
            trace_record.level,
            trace_record.block_type,
            self.miss_ratio_stats.num_accesses,
        )
        self.table.insert(key, hash, new_entry)
        self.used_size += value_size

    def _access_kv(self, trace_record, key, hash, value_size, no_insert):
        """
        Look up 'key'; on a miss insert it unless 'no_insert' is set or the
        value is empty. Returns True on a hit and False on a miss.
        """
        if self._lookup(key, hash):
            return True
        if value_size > 0 and not no_insert:
            self._insert(trace_record, key, hash, value_size)
        return False

    def _update_stats(self, access_time, is_hit):
        """Feed one access into the per-minute and per-hour miss stats."""
        self.miss_ratio_stats.update_metrics(access_time, is_hit)
        self.per_hour_miss_ratio_stats.update_metrics(access_time, is_hit)

    def _access_block(self, trace_record):
        """Access the record's block and record the hit/miss."""
        block_hit = self._access_kv(
            trace_record,
            key="b{}".format(trace_record.block_id),
            hash=trace_record.block_id,
            value_size=trace_record.block_size,
            no_insert=trace_record.no_insert,
        )
        self._update_stats(trace_record.access_time, block_hit)

    def access(self, trace_record):
        """Replay one trace record against the cache."""
        assert self.used_size <= self.cache_size
        is_row_get = (
            self.enable_cache_row_key
            and trace_record.caller == 1
            and trace_record.key_id != 0
            and trace_record.get_id != 0
        )
        if not is_row_get:
            # Plain block access.
            self._access_block(trace_record)
            return

        # This is a get request.
        row_key = trace_record.key_id
        state = self.get_id_row_key_map.setdefault(
            trace_record.get_id, {"h": False}
        )
        if not state["h"] and row_key not in state:
            # First time seen this key.
            row_hit = self._access_kv(
                trace_record,
                key="g{}".format(row_key),
                hash=row_key,
                value_size=trace_record.kv_size,
                no_insert=False,
            )
            state[row_key] = trace_record.kv_size > 0
            state["h"] = row_hit
        if state["h"]:
            # We treat future accesses as hits since this get request
            # completes.
            self._update_stats(trace_record.access_time, is_hit=True)
            return

        # Access its blocks.
        self._access_block(trace_record)
        if trace_record.kv_size > 0 and not state[row_key]:
            # Insert the row key-value pair.
            self._access_kv(
                trace_record,
                key="g{}".format(row_key),
                hash=row_key,
                value_size=trace_record.kv_size,
                no_insert=False,
            )
            # Mark as inserted.
            state[row_key] = True
|
|
|
|
|
|
class ThompsonSamplingCache(MLCache):
    """
    An implementation of Thompson Sampling for the Bernoulli Bandit [1].
    [1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband,
    and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found.
    Trends Mach. Learn. 11, 1 (July 2018), 1-96.
    DOI: https://doi.org/10.1561/2200000070

    Every policy keeps a pair of Beta distribution parameters (a, b). On
    each selection one value is drawn per policy and the policy with the
    largest draw is used; its reward then updates its own (a, b).
    """

    def __init__(self, cache_size, enable_cache_row_key, policies, init_a=1, init_b=1):
        super(ThompsonSamplingCache, self).__init__(
            cache_size, enable_cache_row_key, policies
        )
        # One Beta parameter pair per policy, built once (the lists used to
        # be rebuilt on every iteration of a loop over the policies).
        self._as = [init_a] * len(self.policies)
        self._bs = [init_b] * len(self.policies)

    def _select_policy(self, trace_record, key):
        """Draw from each policy's Beta distribution and pick the largest."""
        samples = [np.random.beta(a, b) for a, b in zip(self._as, self._bs)]
        selected_policy = max(range(len(self.policies)), key=lambda x: samples[x])
        reward = self.policies[selected_policy].generate_reward(key)
        assert 0 <= reward <= 1
        # A reward of 1 counts towards 'a', a reward of 0 towards 'b'.
        self._as[selected_policy] += reward
        self._bs[selected_policy] += 1 - reward
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid ThompsonSampling (ts_hybrid)"
        return "ThompsonSampling (ts)"
|
|
|
|
|
|
class LinUCBCache(MLCache):
    """
    An implementation of LinUCB with disjoint linear models [2].
    [2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010.
    A contextual-bandit approach to personalized news article recommendation.
    In Proceedings of the 19th international conference on World wide web
    (WWW '10). ACM, New York, NY, USA, 661-670.
    DOI=http://dx.doi.org/10.1145/1772690.1772758

    Each policy owns a matrix A, its inverse A_inv and a vector b. The
    context vector holds the record's block type, caller, level and column
    family id.
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        super(LinUCBCache, self).__init__(cache_size, enable_cache_row_key, policies)
        self.nfeatures = 4  # Block type, caller, level, cf.
        num_policies = len(self.policies)
        # 'th', 'eps' and 'p' are initialised here but are not read by the
        # methods of this class.
        self.th = np.zeros((num_policies, self.nfeatures))
        self.eps = 0.2
        self.b = np.zeros_like(self.th)
        self.A = np.zeros((num_policies, self.nfeatures, self.nfeatures))
        self.A_inv = np.zeros((num_policies, self.nfeatures, self.nfeatures))
        for i in range(num_policies):
            self.A[i] = np.identity(self.nfeatures)
            # Keep A_inv consistent with A: the inverse of the identity is
            # the identity. Leaving it at zero gave policies that were never
            # selected a zero confidence bonus.
            self.A_inv[i] = np.identity(self.nfeatures)
        self.th_hat = np.zeros_like(self.th)
        self.p = np.zeros(num_policies)
        self.alph = 0.2

    def _select_policy(self, trace_record, key):
        """
        Score every policy as estimated reward plus confidence bonus for the
        record's context, pick the best, and update it with its reward.
        """
        # The current context vector.
        x_i = np.zeros(self.nfeatures)
        x_i[0] = trace_record.block_type
        x_i[1] = trace_record.caller
        x_i[2] = trace_record.level
        x_i[3] = trace_record.cf_id
        p = np.zeros(len(self.policies))
        for a in range(len(self.policies)):
            self.th_hat[a] = self.A_inv[a].dot(self.b[a])
            ta = x_i.dot(self.A_inv[a]).dot(x_i)
            a_upper_ci = self.alph * np.sqrt(ta)
            a_mean = self.th_hat[a].dot(x_i)
            p[a] = a_mean + a_upper_ci
        # Tiny random noise breaks ties between equal scores.
        p = p + (np.random.random(len(p)) * 0.000001)
        selected_policy = p.argmax()
        reward = self.policies[selected_policy].generate_reward(key)
        assert 0 <= reward <= 1
        self.A[selected_policy] += np.outer(x_i, x_i)
        self.b[selected_policy] += reward * x_i
        self.A_inv[selected_policy] = np.linalg.inv(self.A[selected_policy])
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid LinUCB (linucb_hybrid)"
        return "LinUCB (linucb)"
|
|
|
|
|
|
def parse_cache_size(cs):
    """
    Convert a cache size string to a number of bytes.

    A trailing "M", "G" or "T" multiplies the integer in front of it by
    1024^2, 1024^3 or 1024^4; a string without such a suffix is parsed as a
    plain integer. Newlines and surrounding whitespace are ignored.

    Raises ValueError if the string is empty or not a valid integer.
    """
    cs = cs.replace("\n", "").strip()
    if not cs:
        # Indexing cs[-1] below would fail with an IndexError.
        raise ValueError("cache size must not be empty")
    multipliers = {"M": 1024 ** 2, "G": 1024 ** 3, "T": 1024 ** 4}
    unit = cs[-1]
    if unit in multipliers:
        return int(cs[:-1]) * multipliers[unit]
    return int(cs)
|
|
|
|
|
|
def create_cache(cache_type, cache_size, downsample_size):
    """
    Build the cache simulator named by 'cache_type'.

    'cache_type' is "ts" or "linucb", optionally followed by "_hybrid" to
    enable row key caching. The capacity is 'cache_size' scaled down by
    'downsample_size'. Every cache is given LRU, MRU and LFU policies.
    """
    policies = [LRUPolicy(), MRUPolicy(), LFUPolicy()]
    # Floor division keeps the capacity an integer number of bytes on
    # Python 3, where "/" would produce a float.
    cache_size = cache_size // downsample_size
    enable_cache_row_key = False
    if "hybrid" in cache_type:
        enable_cache_row_key = True
        # Drop the 7-character "_hybrid" suffix.
        cache_type = cache_type[:-7]
    if cache_type == "ts":
        return ThompsonSamplingCache(cache_size, enable_cache_row_key, policies)
    elif cache_type == "linucb":
        return LinUCBCache(cache_size, enable_cache_row_key, policies)
    else:
        print("Unknown cache type {}".format(cache_type))
        assert False
    return None
|
|
|
|
|
|
def run(trace_file_path, cache_type, cache, warmup_seconds):
    """
    Replay the comma-separated trace at 'trace_file_path' against 'cache'.

    Once the trace has run for more than 'warmup_seconds', the cache's
    per-minute miss ratio counters are reset a single time. Progress and a
    final summary are printed. Returns (trace_start_time, trace_duration),
    where the duration is the last timestamp minus the first.
    """
    progress_template = (
        "Take {} seconds to process {} trace records with trace "
        "duration of {} seconds. Throughput: {} records/second. "
        "Trace miss ratio {}"
    )
    warmed_up = False
    records_seen = 0
    trace_start_time = 0
    trace_duration = 0
    wall_start = time.time()
    next_report = 1
    trace_miss_ratio_stats = MissRatioStats(kSecondsInMinute)
    with open(trace_file_path, "r") as trace_file:
        for line in trace_file:
            records_seen += 1
            if records_seen % 1000000 == 0:
                # Force a python gc periodically to reduce memory usage.
                gc.collect()
            fields = line.split(",")
            timestamp = int(fields[0])
            if trace_start_time == 0:
                trace_start_time = timestamp
            trace_duration = timestamp - trace_start_time
            if not warmed_up and trace_duration > warmup_seconds * 1000000:
                cache.miss_ratio_stats.reset_counter()
                warmed_up = True
            record = TraceRecord(
                access_time=timestamp,
                block_id=int(fields[1]),
                block_type=int(fields[2]),
                block_size=int(fields[3]),
                cf_id=int(fields[4]),
                cf_name=fields[5],
                level=int(fields[6]),
                fd=int(fields[7]),
                caller=int(fields[8]),
                no_insert=int(fields[9]),
                get_id=int(fields[10]),
                key_id=int(fields[11]),
                kv_size=int(fields[12]),
                is_hit=int(fields[13]),
            )
            trace_miss_ratio_stats.update_metrics(
                record.access_time, is_hit=record.is_hit
            )
            cache.access(record)
            del record
            if records_seen % 100 == 0:
                # Only look at the clock every 100 records; report progress
                # every 10 seconds.
                now = time.time()
                elapsed = now - wall_start
                if elapsed > next_report * 10:
                    print(
                        progress_template.format(
                            elapsed,
                            records_seen,
                            trace_duration / 1000000,
                            records_seen / elapsed,
                            trace_miss_ratio_stats.miss_ratio(),
                        )
                    )
                    next_report += 1
    summary = "{},0,0,{},{},{}".format(
        cache_type,
        cache.cache_size,
        cache.miss_ratio_stats.miss_ratio(),
        cache.miss_ratio_stats.num_accesses,
    )
    print(summary)
    now = time.time()
    print(
        progress_template.format(
            now - wall_start,
            records_seen,
            trace_duration / 1000000,
            records_seen / (now - wall_start),
            trace_miss_ratio_stats.miss_ratio(),
        )
    )
    return trace_start_time, trace_duration
|
|
|
|
|
|
def report_stats(
    cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
):
    """
    Write the simulation results of 'cache' into 'result_dir': one miss
    ratio summary file, followed by the policy and miss timelines of the
    per-minute and then the per-hour statistics.
    """
    mrc_path = "{}/data-ml-mrc-{}".format(
        result_dir, "{}-{}".format(cache_type, cache_size)
    )
    with open(mrc_path, "w+") as mrc_file:
        summary = "{},0,0,{},{},{}\n".format(
            cache_type,
            cache_size,
            cache.miss_ratio_stats.miss_ratio(),
            cache.miss_ratio_stats.num_accesses,
        )
        mrc_file.write(summary)

    timeline_args = (
        cache_type,
        cache_size,
        result_dir,
        trace_start_time,
        trace_end_time,
    )
    # Per-minute statistics first, then per-hour; within each pair the
    # policy timelines are written before the miss timelines.
    stats_pairs = (
        (cache.policy_stats, cache.miss_ratio_stats),
        (cache.per_hour_policy_stats, cache.per_hour_miss_ratio_stats),
    )
    for policy_stats, miss_stats in stats_pairs:
        policy_stats.write_policy_timeline(*timeline_args)
        policy_stats.write_policy_ratio_timeline(*timeline_args)
        miss_stats.write_miss_timeline(*timeline_args)
        miss_stats.write_miss_ratio_timeline(*timeline_args)
|
|
|
|
|
|
# Script entry point: parse the six command line arguments, replay the trace
# and write the result files. The variables stay at module level on purpose,
# since code elsewhere in this file may read them as globals.
if __name__ == "__main__":
    if len(sys.argv) <= 6:
        print(
            "Must provide 6 arguments. "
            "1) cache_type (ts, ts_hybrid, linucb, linucb_hybrid). "
            "2) cache size (xM, xG, xT). "
            "3) The sampling frequency used to collect the trace. (The "
            "simulation scales down the cache size by the sampling frequency). "
            "4) Warmup seconds (The number of seconds used for warmup). "
            "5) Trace file path. "
            "6) Result directory (A directory that saves generated results)"
        )
        # sys.exit is always available; the bare exit() helper is only
        # installed by the site module.
        sys.exit(1)
    cache_type = sys.argv[1]
    cache_size = parse_cache_size(sys.argv[2])
    downsample_size = int(sys.argv[3])
    warmup_seconds = int(sys.argv[4])
    trace_file_path = sys.argv[5]
    result_dir = sys.argv[6]
    cache = create_cache(cache_type, cache_size, downsample_size)
    trace_start_time, trace_duration = run(
        trace_file_path, cache_type, cache, warmup_seconds
    )
    trace_end_time = trace_start_time + trace_duration
    report_stats(
        cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
    )
|