rocksdb/tools/block_cache_analyzer/block_cache_pysim.py
haoyuhuang 70c7302fb5 Block cache simulator: Add pysim to simulate caches using reinforcement learning. (#5610)
Summary:
This PR implements cache eviction using reinforcement learning. It includes two implementations:
1. An implementation of Thompson Sampling for the Bernoulli Bandit [1].
2. An implementation of LinUCB with disjoint linear models [2].

The idea is that a cache uses multiple eviction policies, e.g., MRU, LRU, and LFU. The cache learns which eviction policy is the best and uses it upon a cache miss.
Thompson Sampling is contextless and does not include any features.
LinUCB includes features such as level, block type, caller, column family id to decide which eviction policy to use.

[1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband, and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found. Trends Mach. Learn. 11, 1 (July 2018), 1-96. DOI: https://doi.org/10.1561/2200000070
[2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010. A contextual-bandit approach to personalized news article recommendation. In Proceedings of the 19th international conference on World wide web (WWW '10). ACM, New York, NY, USA, 661-670. DOI=http://dx.doi.org/10.1145/1772690.1772758
Pull Request resolved: https://github.com/facebook/rocksdb/pull/5610

Differential Revision: D16435067

Pulled By: HaoyuHuang

fbshipit-source-id: 6549239ae14115c01cb1e70548af9e46d8dc21bb
2019-07-26 14:41:13 -07:00

865 lines
31 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import gc
import random
import sys
import time
from os import path
import numpy as np
# Number of hash-table entries sampled each time an eviction round runs.
kSampleSize = 16
# Time conversion factors (trace timestamps are in microseconds).
kMicrosInSecond = 10 ** 6
kSecondsInMinute = 60
kSecondsInHour = 60 * kSecondsInMinute
class TraceRecord:
    """
    One block access read from a block cache trace.

    The fields mirror BlockCacheTraceRecord in
    trace_replay/block_cache_tracer.h. 'no_insert' and 'is_hit' are stored
    as booleans that are True only when the raw value equals 1; every other
    argument is stored unchanged.
    """

    def __init__(
        self,
        access_time,
        block_id,
        block_type,
        block_size,
        cf_id,
        cf_name,
        level,
        fd,
        caller,
        no_insert,
        get_id,
        key_id,
        kv_size,
        is_hit,
    ):
        # Block identity and placement.
        self.block_id = block_id
        self.block_type = block_type
        self.block_size = block_size
        self.level = level
        self.fd = fd
        # Column family.
        self.cf_id = cf_id
        self.cf_name = cf_name
        # Request context.
        self.access_time = access_time
        self.caller = caller
        self.get_id = get_id
        self.key_id = key_id
        self.kv_size = kv_size
        # Integer flags from the trace, normalized to booleans.
        self.no_insert = no_insert == 1
        self.is_hit = is_hit == 1
class CacheEntry:
    """
    A value stored in the simulated cache.

    Tracks the value's size, the access number of its most recent use, how
    many hits it has received, and the column family / level / block type
    it came from.
    """

    def __init__(self, value_size, cf_id, level, block_type, access_number):
        self.value_size = value_size
        # Access number at creation; callers overwrite it on later hits.
        self.last_access_number = access_number
        self.num_hits = 0
        # Store the caller's column family id (it was previously discarded
        # and always recorded as 0).
        self.cf_id = cf_id
        self.level = level
        self.block_type = block_type

    def __repr__(self):
        """Debug string."""
        return "s={},last={},hits={},cf={},l={},bt={}".format(
            self.value_size,
            self.last_access_number,
            self.num_hits,
            self.cf_id,
            self.level,
            self.block_type,
        )
class HashEntry:
    """One (key, hash, value) triple held in a HashTable bucket."""

    def __init__(self, key, hash, value):
        self.key, self.hash, self.value = key, hash, value

    def __repr__(self):
        fields = (self.key, self.hash, self.value)
        return "k={},h={},v=[{}]".format(*fields)
class HashTable:
    """
    A custom hash table that supports fast random sampling.

    Buckets are Python lists (separate chaining). A deleted entry is
    replaced by None inside its bucket, and insert() reuses such holes.
    The bucket array grows by 1.2x when the element count reaches the
    bucket count and shrinks by 0.7x when it falls below half of it; tables
    holding fewer than 100 elements are never resized (see resize()).
    """

    def __init__(self):
        self.table = [None] * 32
        self.elements = 0

    def random_sample(self, sample_size):
        """
        Return up to 'sample_size' distinct hash entries.

        The scan starts at a uniformly random bucket and visits every bucket
        at most once (wrapping around), so an entry is never returned twice
        and the result never holds more than 'sample_size' entries. Fewer
        are returned only when the table holds fewer entries.
        """
        samples = []
        if sample_size <= 0:
            return samples
        nbuckets = len(self.table)
        # randrange excludes its upper bound. randint(0, nbuckets) could
        # return nbuckets, an out-of-range start that made the scan wrap
        # forever and return duplicate entries.
        start = random.randrange(nbuckets)
        for offset in range(nbuckets):
            bucket = self.table[(start + offset) % nbuckets]
            if bucket is None:
                continue
            for entry in bucket:
                if entry is None:
                    continue
                samples.append(entry)
                if len(samples) >= sample_size:
                    return samples
        return samples

    def insert(self, key, hash, value):
        """
        Insert a hash entry in the table. Replace the old entry if it already
        exists.
        """
        self.grow()
        index = hash % len(self.table)
        bucket = self.table[index]
        if bucket is None:
            bucket = []
            self.table[index] = bucket
        # Look for an existing entry for the key across the WHOLE bucket
        # before reusing a hole; filling the first hole eagerly could store
        # a second copy of a key that sits later in the bucket.
        free_slot = None
        for i, entry in enumerate(bucket):
            if entry is None:
                if free_slot is None:
                    free_slot = i
            elif entry.hash == hash and entry.key == key:
                # The entry already exists; the element count is unchanged.
                bucket[i] = HashEntry(key, hash, value)
                return
        new_entry = HashEntry(key, hash, value)
        if free_slot is None:
            bucket.append(new_entry)
        else:
            bucket[free_slot] = new_entry
        self.elements += 1

    def resize(self, new_size):
        """
        Rehash every live entry into 'new_size' buckets.

        No-op when the size is unchanged, when 'new_size' is 0, or when the
        table holds fewer than 100 elements.
        """
        if new_size == len(self.table) or new_size == 0:
            return
        if self.elements < 100:
            return
        new_table = [None] * new_size
        for bucket in self.table:
            if bucket is None:
                continue
            for entry in bucket:
                if entry is None:
                    continue
                index = entry.hash % new_size
                if new_table[index] is None:
                    new_table[index] = []
                new_table[index].append(entry)
        self.table = new_table
        del new_table
        # Manually call python gc here to free the memory as 'self.table'
        # might be very large.
        gc.collect()

    def grow(self):
        """Grow the bucket array by 1.2x once elements >= bucket count."""
        if self.elements < len(self.table):
            return
        self.resize(int(len(self.table) * 1.2))

    def delete(self, key, hash):
        """Remove the entry for (key, hash) if present, then maybe shrink."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return
        for i, entry in enumerate(entries):
            if entry is not None and entry.hash == hash and entry.key == key:
                # Leave a hole in the bucket; insert() reuses it.
                entries[i] = None
                self.elements -= 1
                self.shrink()
                return

    def shrink(self):
        """Shrink the bucket array by 0.7x once elements < half of it."""
        if self.elements * 2 >= len(self.table):
            return
        self.resize(int(len(self.table) * 0.7))

    def lookup(self, key, hash):
        """Return the value stored for (key, hash), or None if absent."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return None
        for entry in entries:
            if entry is not None and entry.hash == hash and entry.key == key:
                return entry.value
        return None
class MissRatioStats:
    """
    Hit/miss counters for a simulated cache.

    Keeps overall counters (num_accesses, num_misses) and per-time-bucket
    counts. Access times are in microseconds; each bucket spans
    'time_unit' seconds.
    """

    def __init__(self, time_unit):
        self.num_misses = 0
        self.num_accesses = 0
        self.time_unit = time_unit
        # Time bucket -> number of misses / accesses in that bucket.
        self.time_misses = {}
        self.time_accesses = {}

    def _time_bucket(self, timestamp):
        """Map a microsecond timestamp to its integer time bucket."""
        # Floor division keeps bucket ids integral. In Python 3 '/' produces a
        # distinct float key for almost every timestamp, and range() over
        # float bounds in the writers below raises TypeError.
        return timestamp // (kMicrosInSecond * self.time_unit)

    def update_metrics(self, access_time, is_hit):
        """Record one access at 'access_time' (microseconds)."""
        bucket = self._time_bucket(access_time)
        self.num_accesses += 1
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        if not is_hit:
            self.num_misses += 1
            self.time_misses[bucket] = self.time_misses.get(bucket, 0) + 1

    def reset_counter(self):
        """Reset the overall counters; per-bucket counts are kept."""
        self.num_misses = 0
        self.num_accesses = 0

    def miss_ratio(self):
        """Return the overall miss ratio in percent (0.0 with no accesses)."""
        if self.num_accesses == 0:
            # E.g. an empty trace or one that ends during warmup.
            return 0.0
        return float(self.num_misses) * 100.0 / float(self.num_accesses)

    def _write_timeline_header(self, header_file_path, start, end):
        """Write the 'time,<bucket>,...' header unless the file exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            header = "time" + "".join(
                ",{}".format(trace_time) for trace_time in range(start, end)
            )
            header_file.write(header + "\n")

    def write_miss_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write per-bucket miss counts for the buckets from bucket(start) up to,
        but not including, bucket(end). 'start'/'end' are in microseconds.
        """
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_timeline_header(
            "{}/header-ml-miss-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as out:
            row = "{}".format(cache_type)
            for trace_time in range(start, end):
                row += ",{}".format(self.time_misses.get(trace_time, 0))
            out.write(row + "\n")

    def write_miss_ratio_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write per-bucket miss ratios (percent, 2 decimals) for the buckets
        from bucket(start) up to, but not including, bucket(end).
        """
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_timeline_header(
            "{}/header-ml-miss-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as out:
            row = "{}".format(cache_type)
            for trace_time in range(start, end):
                naccesses = self.time_accesses.get(trace_time, 0)
                miss_ratio = 0
                if naccesses > 0:
                    miss_ratio = float(
                        self.time_misses.get(trace_time, 0) * 100.0
                    ) / float(naccesses)
                row += ",{0:.2f}".format(miss_ratio)
            out.write(row + "\n")
class PolicyStats:
    """
    Counts how often each eviction policy is selected, per time bucket.

    Access times are in microseconds; each bucket spans 'time_unit'
    seconds. Policies are identified by their index in 'policies'.
    """

    def __init__(self, time_unit, policies):
        # Time bucket -> {policy name -> selection count}. The attribute
        # name's spelling is kept for compatibility with existing readers.
        self.time_selected_polices = {}
        # Time bucket -> number of policy selections in that bucket.
        self.time_accesses = {}
        # Policy index -> policy name.
        self.policy_names = {}
        self.time_unit = time_unit
        for index, policy in enumerate(policies):
            self.policy_names[index] = policy.policy_name()

    def _time_bucket(self, timestamp):
        """Map a microsecond timestamp to its integer time bucket."""
        # Floor division keeps bucket ids integral; Python 3 '/' would make
        # float keys and break range() in the writers below.
        return timestamp // (kMicrosInSecond * self.time_unit)

    def update_metrics(self, access_time, selected_policy):
        """Record that policy index 'selected_policy' was chosen."""
        bucket = self._time_bucket(access_time)
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        selected = self.time_selected_polices.setdefault(bucket, {})
        policy_name = self.policy_names[selected_policy]
        selected[policy_name] = selected.get(policy_name, 0) + 1

    def _write_timeline_header(self, header_file_path, start, end):
        """Write the 'time,<bucket>,...' header unless the file exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            header = "time" + "".join(
                ",{}".format(trace_time) for trace_time in range(start, end)
            )
            header_file.write(header + "\n")

    def _selection_count(self, trace_time, policy_name):
        """Number of times 'policy_name' was chosen in bucket 'trace_time'."""
        return self.time_selected_polices.get(trace_time, {}).get(policy_name, 0)

    def write_policy_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write one row per policy with its selection count per bucket, for the
        buckets from bucket(start) up to, but not including, bucket(end).
        """
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_timeline_header(
            "{}/header-ml-policy-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-policy-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as out:
            for policy_name in self.policy_names.values():
                row = "{}-{}".format(cache_type, policy_name)
                for trace_time in range(start, end):
                    row += ",{}".format(
                        self._selection_count(trace_time, policy_name)
                    )
                out.write(row + "\n")

    def write_policy_ratio_timeline(
        self, cache_type, cache_size, file_path, start, end
    ):
        """
        Write one row per policy with its selection share (percent, 2
        decimals) per bucket, for the buckets from bucket(start) up to, but
        not including, bucket(end).

        'file_path' is the directory that receives the output files; the
        parameter name is kept for backward compatibility.
        """
        # Use the directory passed in; the previous body read a module-level
        # 'result_dir' that only exists when this file runs as a script.
        result_dir = file_path
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_timeline_header(
            "{}/header-ml-policy-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        data_path = "{}/data-ml-policy-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(data_path, "w+") as out:
            for policy_name in self.policy_names.values():
                row = "{}-{}".format(cache_type, policy_name)
                for trace_time in range(start, end):
                    naccesses = self.time_accesses.get(trace_time, 0)
                    ratio = 0
                    if naccesses > 0:
                        ratio = float(
                            self._selection_count(trace_time, policy_name) * 100.0
                        ) / float(naccesses)
                    row += ",{0:.2f}".format(ratio)
                out.write(row + "\n")
class Policy(object):
    """
    Base class for the eviction policies a learning cache chooses between.

    Each policy remembers the keys it has evicted. generate_reward(key)
    returns 0 when 'key' is one of them (the policy evicted something that
    was requested again) and 1 otherwise. Subclasses implement
    prioritize_samples() and policy_name().
    """

    def __init__(self):
        # Evicted key -> 0; only key membership is ever checked.
        self.evicted_keys = {}

    def evict(self, key, max_size):
        """Remember that this policy evicted 'key' ('max_size' is unused)."""
        self.evicted_keys[key] = 0

    def delete(self, key):
        """Forget 'key' if it was recorded as evicted; otherwise no-op."""
        if key in self.evicted_keys:
            del self.evicted_keys[key]

    def prioritize_samples(self, samples):
        """Return 'samples' ordered from first to last eviction candidate."""
        raise NotImplementedError

    def policy_name(self):
        """Short name used in reports."""
        raise NotImplementedError

    def generate_reward(self, key):
        """0 if this policy evicted 'key', else 1."""
        return 0 if key in self.evicted_keys else 1
class LRUPolicy(Policy):
    """Evict the least recently accessed sampled entry first."""

    def prioritize_samples(self, samples):
        """Order 'samples' by ascending last_access_number (oldest first)."""
        # sorted(cmp=...) exists only in Python 2 and raises TypeError in
        # Python 3; a key function gives the same stable ordering.
        return sorted(samples, key=lambda entry: entry.value.last_access_number)

    def policy_name(self):
        return "lru"
class MRUPolicy(Policy):
    """Evict the most recently accessed sampled entry first."""

    def prioritize_samples(self, samples):
        """Order 'samples' by descending last_access_number (newest first)."""
        # sorted(cmp=...) is Python-2 only. reverse=True keeps the sort
        # stable, matching the old descending comparator on ties.
        return sorted(
            samples,
            key=lambda entry: entry.value.last_access_number,
            reverse=True,
        )

    def policy_name(self):
        return "mru"
class LFUPolicy(Policy):
    """Evict the sampled entry with the fewest hits first."""

    def prioritize_samples(self, samples):
        """Order 'samples' by ascending num_hits."""
        # sorted(cmp=...) is Python-2 only; use an equivalent key function.
        return sorted(samples, key=lambda entry: entry.value.num_hits)

    def policy_name(self):
        return "lfu"
class MLCache(object):
    """
    Base class of the learning caches simulated in this file.

    When a missed value is inserted, the subclass's _select_policy() picks
    one of 'policies'. Eviction rounds then sample kSampleSize entries from
    the hash table and evict them in the order preferred by that policy
    until the new value fits. When 'enable_cache_row_key' is set, Get
    requests (caller == 1 with non-zero get_id and key_id) also cache the
    row key-value pair.
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        self.cache_size = cache_size
        self.used_size = 0
        self.miss_ratio_stats = MissRatioStats(kSecondsInMinute)
        self.policy_stats = PolicyStats(kSecondsInMinute, policies)
        self.per_hour_miss_ratio_stats = MissRatioStats(kSecondsInHour)
        self.per_hour_policy_stats = PolicyStats(kSecondsInHour, policies)
        self.table = HashTable()
        self.enable_cache_row_key = enable_cache_row_key
        # get_id -> {"h": whether the Get was already served by the cache,
        #            key_id -> whether its row key-value pair was inserted}.
        self.get_id_row_key_map = {}
        self.policies = policies
        # Monotonic logical clock used as the recency stamp of entries.
        # miss_ratio_stats.num_accesses cannot serve this purpose because
        # its counter is reset after warmup, which would make post-warmup
        # accesses look older than pre-warmup ones to LRU/MRU.
        self.access_number = 0

    def _lookup(self, key, hash):
        """On a hit, refresh the entry's recency and hit count; return hit."""
        value = self.table.lookup(key, hash)
        if value is None:
            return False
        value.last_access_number = self.access_number
        value.num_hits += 1
        return True

    def _select_policy(self, trace_record, key):
        raise NotImplementedError

    def cache_name(self):
        raise NotImplementedError

    def _evict(self, policy_index, value_size):
        """Run one sampled eviction round using policy 'policy_index'."""
        policy = self.policies[policy_index]
        # Randomly sample n entries.
        samples = policy.prioritize_samples(self.table.random_sample(kSampleSize))
        for hash_entry in samples:
            if self.table.lookup(hash_entry.key, hash_entry.hash) is None:
                # Already evicted (duplicate sample); subtracting its size
                # again would corrupt used_size.
                continue
            self.used_size -= hash_entry.value.value_size
            self.table.delete(hash_entry.key, hash_entry.hash)
            policy.evict(key=hash_entry.key, max_size=self.table.elements)
            if self.used_size + value_size <= self.cache_size:
                break

    def _insert(self, trace_record, key, hash, value_size):
        """Insert a missed value, evicting until it fits; skip if too big."""
        if value_size > self.cache_size:
            return
        policy_index = self._select_policy(trace_record, key)
        self.policies[policy_index].delete(key)
        self.policy_stats.update_metrics(trace_record.access_time, policy_index)
        self.per_hour_policy_stats.update_metrics(
            trace_record.access_time, policy_index
        )
        while self.used_size + value_size > self.cache_size:
            self._evict(policy_index, value_size)
        self.table.insert(
            key,
            hash,
            CacheEntry(
                value_size,
                trace_record.cf_id,
                trace_record.level,
                trace_record.block_type,
                self.access_number,
            ),
        )
        self.used_size += value_size

    def _access_kv(self, trace_record, key, hash, value_size, no_insert):
        """
        Return True on a hit. On a miss, insert the value unless 'no_insert'
        is set or 'value_size' is not positive, and return False.
        """
        if self._lookup(key, hash):
            return True
        if not no_insert and value_size > 0:
            self._insert(trace_record, key, hash, value_size)
        return False

    def _update_stats(self, access_time, is_hit):
        """Count one simulated access and advance the logical clock."""
        self.miss_ratio_stats.update_metrics(access_time, is_hit)
        self.per_hour_miss_ratio_stats.update_metrics(access_time, is_hit)
        self.access_number += 1

    def _access_block(self, trace_record):
        """Access the record's block; return whether it was a hit."""
        return self._access_kv(
            trace_record,
            key="b{}".format(trace_record.block_id),
            hash=trace_record.block_id,
            value_size=trace_record.block_size,
            no_insert=trace_record.no_insert,
        )

    def _access_row_key(self, trace_record):
        """Access the record's row key-value pair; return whether it hit."""
        return self._access_kv(
            trace_record,
            key="g{}".format(trace_record.key_id),
            hash=trace_record.key_id,
            value_size=trace_record.kv_size,
            no_insert=False,
        )

    def _access_get_request(self, trace_record):
        """Simulate one block access that belongs to a Get request."""
        row_keys = self.get_id_row_key_map.setdefault(
            trace_record.get_id, {"h": False}
        )
        if row_keys["h"]:
            # We treat future accesses as hits since this get request
            # completes.
            self._update_stats(trace_record.access_time, is_hit=True)
            return
        if trace_record.key_id not in row_keys:
            # First time seen this key.
            is_hit = self._access_row_key(trace_record)
            row_keys[trace_record.key_id] = trace_record.kv_size > 0
            row_keys["h"] = is_hit
            if is_hit:
                # We treat future accesses as hits since this get request
                # completes.
                self._update_stats(trace_record.access_time, is_hit=True)
                return
        # Access its blocks.
        is_hit = self._access_block(trace_record)
        self._update_stats(trace_record.access_time, is_hit)
        if trace_record.kv_size > 0 and not row_keys[trace_record.key_id]:
            # Insert the row key-value pair and mark it as inserted.
            self._access_row_key(trace_record)
            row_keys[trace_record.key_id] = True

    def access(self, trace_record):
        """Simulate one trace record and update the statistics."""
        assert self.used_size <= self.cache_size
        if (
            self.enable_cache_row_key
            and trace_record.caller == 1
            and trace_record.key_id != 0
            and trace_record.get_id != 0
        ):
            # This is a get request.
            self._access_get_request(trace_record)
            return
        # Access the block.
        is_hit = self._access_block(trace_record)
        self._update_stats(trace_record.access_time, is_hit)
class ThompsonSamplingCache(MLCache):
    """
    An implementation of Thompson Sampling for the Bernoulli Bandit [1].
    [1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband,
    and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found.
    Trends Mach. Learn. 11, 1 (July 2018), 1-96.
    DOI: https://doi.org/10.1561/2200000070

    Each policy keeps a Beta(a, b) posterior over its reward probability,
    starting from Beta(init_a, init_b).
    """

    def __init__(self, cache_size, enable_cache_row_key, policies, init_a=1, init_b=1):
        super(ThompsonSamplingCache, self).__init__(
            cache_size, enable_cache_row_key, policies
        )
        # Beta posterior parameters, one entry per policy (same order as
        # self.policies).
        self._as = [init_a] * len(self.policies)
        self._bs = [init_b] * len(self.policies)

    def _select_policy(self, trace_record, key):
        """
        Draw one sample from every policy's Beta posterior and choose the
        largest; then update the chosen policy's posterior with its reward
        for 'key' (a += reward, b += 1 - reward).
        """
        samples = [np.random.beta(a, b) for a, b in zip(self._as, self._bs)]
        selected_policy = max(range(len(self.policies)), key=lambda x: samples[x])
        reward = self.policies[selected_policy].generate_reward(key)
        assert 0 <= reward <= 1
        self._as[selected_policy] += reward
        self._bs[selected_policy] += 1 - reward
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid ThompsonSampling (ts_hybrid)"
        return "ThompsonSampling (ts)"
class LinUCBCache(MLCache):
    """
    An implementation of LinUCB with disjoint linear models [2].
    [2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010.
    A contextual-bandit approach to personalized news article recommendation.
    In Proceedings of the 19th international conference on World wide web
    (WWW '10). ACM, New York, NY, USA, 661-670.
    DOI=http://dx.doi.org/10.1145/1772690.1772758

    The context of a decision is (block type, caller, level, cf id) of the
    trace record that missed.
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        super(LinUCBCache, self).__init__(cache_size, enable_cache_row_key, policies)
        self.nfeatures = 4  # Block type, caller, level, cf.
        npolicies = len(self.policies)
        # 'th', 'eps' and 'p' are not read by _select_policy; kept as
        # attributes for compatibility.
        self.th = np.zeros((npolicies, self.nfeatures))
        self.eps = 0.2
        self.b = np.zeros_like(self.th)
        # Per-policy matrix A = I + sum(x x^T) and its inverse. A starts as
        # the identity (Algorithm 1 of [2]), so A_inv must start as the
        # identity too; a zero A_inv gave every not-yet-selected policy a
        # zero score with no exploration bonus.
        self.A = np.zeros((npolicies, self.nfeatures, self.nfeatures))
        self.A_inv = np.zeros_like(self.A)
        for i in range(npolicies):
            self.A[i] = np.identity(self.nfeatures)
            self.A_inv[i] = np.identity(self.nfeatures)
        self.th_hat = np.zeros_like(self.th)
        self.p = np.zeros(npolicies)
        # Multiplier of the confidence-bound term.
        self.alph = 0.2

    def _select_policy(self, trace_record, key):
        """
        Score each policy as th_hat . x + alph * sqrt(x^T A_inv x), choose
        the best, and update the chosen policy's model with its reward for
        'key'.
        """
        # The current context vector.
        x_i = np.array(
            [
                trace_record.block_type,
                trace_record.caller,
                trace_record.level,
                trace_record.cf_id,
            ],
            dtype=float,
        )
        p = np.zeros(len(self.policies))
        for a in range(len(self.policies)):
            self.th_hat[a] = self.A_inv[a].dot(self.b[a])
            ta = x_i.dot(self.A_inv[a]).dot(x_i)
            p[a] = self.th_hat[a].dot(x_i) + self.alph * np.sqrt(ta)
        # Tiny random noise breaks ties between equal scores.
        p = p + (np.random.random(len(p)) * 0.000001)
        selected_policy = int(p.argmax())
        reward = self.policies[selected_policy].generate_reward(key)
        assert 0 <= reward <= 1
        self.A[selected_policy] += np.outer(x_i, x_i)
        self.b[selected_policy] += reward * x_i
        self.A_inv[selected_policy] = np.linalg.inv(self.A[selected_policy])
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid LinUCB (linucb_hybrid)"
        return "LinUCB (linucb)"
def parse_cache_size(cs):
    """
    Parse a cache size such as "100M", "2G", "1T" or "4096" into bytes.

    Surrounding whitespace (including a trailing newline) is ignored.
    Suffixes K, M, G and T (case-insensitive) are binary multiples
    (1024 ** 1 .. 1024 ** 4); a bare number is taken as bytes.

    Raises:
        ValueError: if the numeric part is not an integer (including "").
    """
    cs = cs.strip()
    multipliers = {"K": 1024, "M": 1024 ** 2, "G": 1024 ** 3, "T": 1024 ** 4}
    # Slicing (not cs[-1]) so that an empty string reaches int() and raises
    # ValueError instead of IndexError.
    unit = cs[-1:].upper()
    if unit in multipliers:
        return int(cs[:-1]) * multipliers[unit]
    return int(cs)
def create_cache(cache_type, cache_size, downsample_size):
    """
    Build the simulated cache named by 'cache_type'.

    'cache_type' is "ts" or "linucb", optionally with a "_hybrid" suffix,
    which also enables caching row key-value pairs. Every cache chooses
    among LRU, MRU and LFU eviction. 'cache_size' (bytes) is scaled down by
    'downsample_size', the sampling frequency used to collect the trace.

    Raises:
        ValueError: if 'cache_type' is not recognized.
    """
    policies = [LRUPolicy(), MRUPolicy(), LFUPolicy()]
    # Floor division keeps the size an integer number of bytes in Python 3.
    cache_size = cache_size // downsample_size
    hybrid_suffix = "_hybrid"
    enable_cache_row_key = cache_type.endswith(hybrid_suffix)
    if enable_cache_row_key:
        cache_type = cache_type[: -len(hybrid_suffix)]
    if cache_type == "ts":
        return ThompsonSamplingCache(cache_size, enable_cache_row_key, policies)
    if cache_type == "linucb":
        return LinUCBCache(cache_size, enable_cache_row_key, policies)
    print("Unknown cache type {}".format(cache_type))
    # Raise instead of 'assert False': asserts are stripped under python -O,
    # which made this function silently return None.
    raise ValueError("Unknown cache type {}".format(cache_type))
def _parse_trace_record(fields):
    """Build a TraceRecord from the 14 comma-separated fields of a line."""
    return TraceRecord(
        access_time=int(fields[0]),
        block_id=int(fields[1]),
        block_type=int(fields[2]),
        block_size=int(fields[3]),
        cf_id=int(fields[4]),
        cf_name=fields[5],
        level=int(fields[6]),
        fd=int(fields[7]),
        caller=int(fields[8]),
        no_insert=int(fields[9]),
        get_id=int(fields[10]),
        key_id=int(fields[11]),
        kv_size=int(fields[12]),
        is_hit=int(fields[13]),
    )


def run(trace_file_path, cache_type, cache, warmup_seconds):
    """
    Replay the trace at 'trace_file_path' through 'cache'.

    Blank lines are skipped. cache.miss_ratio_stats counters are reset once
    the trace has advanced more than 'warmup_seconds' past its first
    timestamp. Progress is printed roughly every 10 seconds of wall time,
    followed by a final summary.

    Returns:
        (trace_start_time, trace_duration), both in microseconds.
    """
    warmup_complete = False
    num = 0
    trace_start_time = 0
    trace_duration = 0
    start_time = time.time()
    time_interval = 1
    trace_miss_ratio_stats = MissRatioStats(kSecondsInMinute)
    with open(trace_file_path, "r") as trace_file:
        for line in trace_file:
            if not line.strip():
                # A blank line (e.g. trailing newline) would crash int().
                continue
            num += 1
            if num % 1000000 == 0:
                # Force a python gc periodically to reduce memory usage.
                gc.collect()
            ts = line.split(",")
            timestamp = int(ts[0])
            if trace_start_time == 0:
                trace_start_time = timestamp
            trace_duration = timestamp - trace_start_time
            if (
                not warmup_complete
                and trace_duration > warmup_seconds * kMicrosInSecond
            ):
                cache.miss_ratio_stats.reset_counter()
                warmup_complete = True
            record = _parse_trace_record(ts)
            trace_miss_ratio_stats.update_metrics(
                record.access_time, is_hit=record.is_hit
            )
            cache.access(record)
            if num % 100 != 0:
                continue
            # Report progress every 10 seconds.
            now = time.time()
            if now - start_time > time_interval * 10:
                print(
                    "Take {} seconds to process {} trace records with trace "
                    "duration of {} seconds. Throughput: {} records/second. "
                    "Trace miss ratio {}".format(
                        now - start_time,
                        num,
                        trace_duration / 1000000,
                        num / (now - start_time),
                        trace_miss_ratio_stats.miss_ratio(),
                    )
                )
                time_interval += 1
    print(
        "{},0,0,{},{},{}".format(
            cache_type,
            cache.cache_size,
            cache.miss_ratio_stats.miss_ratio(),
            cache.miss_ratio_stats.num_accesses,
        )
    )
    now = time.time()
    elapsed = now - start_time
    print(
        "Take {} seconds to process {} trace records with trace duration of {} "
        "seconds. Throughput: {} records/second. Trace miss ratio {}".format(
            elapsed,
            num,
            trace_duration / 1000000,
            # A tiny/empty trace can finish within the clock resolution.
            num / elapsed if elapsed > 0 else 0,
            trace_miss_ratio_stats.miss_ratio(),
        )
    )
    return trace_start_time, trace_duration
def report_stats(
    cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
):
    """
    Write the simulation results into 'result_dir'.

    First the overall miss ratio goes to data-ml-mrc-<type>-<size>; then,
    for the per-minute stats followed by the per-hour stats, the policy
    timeline, policy ratio timeline, miss timeline and miss ratio timeline.
    """
    mrc_path = "{}/data-ml-mrc-{}".format(
        result_dir, "{}-{}".format(cache_type, cache_size)
    )
    with open(mrc_path, "w+") as mrc_file:
        stats = cache.miss_ratio_stats
        mrc_file.write(
            "{},0,0,{},{},{}\n".format(
                cache_type, cache_size, stats.miss_ratio(), stats.num_accesses
            )
        )
    timeline_args = (
        cache_type,
        cache_size,
        result_dir,
        trace_start_time,
        trace_end_time,
    )
    stat_pairs = (
        (cache.policy_stats, cache.miss_ratio_stats),
        (cache.per_hour_policy_stats, cache.per_hour_miss_ratio_stats),
    )
    for policy_stats, miss_stats in stat_pairs:
        policy_stats.write_policy_timeline(*timeline_args)
        policy_stats.write_policy_ratio_timeline(*timeline_args)
        miss_stats.write_miss_timeline(*timeline_args)
        miss_stats.write_miss_ratio_timeline(*timeline_args)
if __name__ == "__main__":
    # Command-line entry point: parse the six arguments, replay the trace
    # and write the results.
    if len(sys.argv) <= 6:
        print(
            "Must provide 6 arguments. "
            "1) cache_type (ts, ts_hybrid, linucb, linucb_hybrid). "
            "2) cache size (xM, xG, xT). "
            "3) The sampling frequency used to collect the trace. (The "
            "simulation scales down the cache size by the sampling frequency). "
            "4) Warmup seconds (The number of seconds used for warmup). "
            "5) Trace file path. "
            "6) Result directory (A directory that saves generated results)"
        )
        # The builtin exit() is added by the site module for interactive use
        # and is missing under 'python -S'; sys.exit() is always available.
        sys.exit(1)
    cache_type = sys.argv[1]
    cache_size = parse_cache_size(sys.argv[2])
    downsample_size = int(sys.argv[3])
    warmup_seconds = int(sys.argv[4])
    trace_file_path = sys.argv[5]
    result_dir = sys.argv[6]
    cache = create_cache(cache_type, cache_size, downsample_size)
    trace_start_time, trace_duration = run(
        trace_file_path, cache_type, cache, warmup_seconds
    )
    trace_end_time = trace_start_time + trace_duration
    report_stats(
        cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
    )