#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import gc
import random
import sys
import time
from os import path

import numpy as np


kSampleSize = 16  # Number of cached entries sampled when choosing what to evict.
kMicrosInSecond = 1000 * 1000
kSecondsInMinute = 60
kSecondsInHour = 60 * kSecondsInMinute


class TraceRecord:
    """
    One block access read from a block cache trace.

    The fields mirror BlockCacheTraceRecord in
    trace_replay/block_cache_tracer.h. The raw integer flags ``no_insert``
    and ``is_hit`` are stored as booleans that are True exactly when the raw
    value equals 1.
    """

    def __init__(
        self,
        access_time,
        block_id,
        block_type,
        block_size,
        cf_id,
        cf_name,
        level,
        fd,
        caller,
        no_insert,
        get_id,
        key_id,
        kv_size,
        is_hit,
    ):
        # Block identity and placement.
        self.access_time = access_time
        self.block_id = block_id
        self.block_type = block_type
        self.block_size = block_size
        self.cf_id = cf_id
        self.cf_name = cf_name
        self.level = level
        self.fd = fd
        # Who accessed it and whether it may be inserted on a miss.
        self.caller = caller
        self.no_insert = bool(no_insert == 1)
        # Get-request (row key) information.
        self.get_id = get_id
        self.key_id = key_id
        self.kv_size = kv_size
        self.is_hit = bool(is_hit == 1)


class CacheEntry:
    """
    A value stored in the simulated cache.

    Args:
        value_size: size charged against the cache capacity.
        cf_id: column family id of the cached block/row.
        level: level recorded in the trace for this access.
        block_type: block type recorded in the trace.
        access_number: access counter value at insertion; refreshed on every
            hit and used by recency-based policies.
    """

    def __init__(self, value_size, cf_id, level, block_type, access_number):
        self.value_size = value_size
        self.last_access_number = access_number
        self.num_hits = 0
        # Store the caller's column family id (it was previously discarded and
        # hard-coded to 0).
        self.cf_id = cf_id
        self.level = level
        self.block_type = block_type

    def __repr__(self):
        """Debug string."""
        return "s={},last={},hits={},cf={},l={},bt={}".format(
            self.value_size,
            self.last_access_number,
            self.num_hits,
            self.cf_id,
            self.level,
            self.block_type,
        )


class HashEntry:
    """A key/hash/value triple stored in a HashTable bucket chain."""

    def __init__(self, key, hash, value):
        self.key = key
        # Caller-supplied hash; the bucket index is hash % number of buckets.
        self.hash = hash
        self.value = value

    def __repr__(self):
        return "k={},h={},v=[{}]".format(self.key, self.hash, self.value)


class HashTable:
    """
    A custom hash table that supports fast random sampling.

    ``self.table`` is a list of buckets; each bucket is ``None`` or a list
    (chain) of HashEntry objects. Deleting an entry leaves a ``None`` hole in
    its chain, which a later insertion into the same bucket may reuse. The
    bucket array grows by 1.2x when the element count reaches the number of
    buckets and shrinks to 0.7x when it drops below half of it, but
    ``resize`` is a no-op while fewer than 100 elements are stored.
    """

    def __init__(self):
        self.table = [None] * 32
        self.elements = 0

    def random_sample(self, sample_size):
        """
        Return up to ``sample_size`` entries.

        Buckets are scanned in order starting from a uniformly random bucket
        and wrapping around once, so every bucket is visited at most once and
        fewer than ``sample_size`` entries are returned only when fewer are
        stored.
        """
        samples = []
        if sample_size <= 0:
            return samples
        num_buckets = len(self.table)
        # randrange excludes num_buckets, so start is always a valid bucket
        # and the scan below terminates after exactly num_buckets steps.
        start = random.randrange(num_buckets)
        for offset in range(num_buckets):
            bucket = self.table[(start + offset) % num_buckets]
            if bucket is None:
                continue
            for entry in bucket:
                if entry is None:
                    continue
                samples.append(entry)
                if len(samples) >= sample_size:
                    return samples
        return samples

    def insert(self, key, hash, value):
        """
        Insert a hash entry in the table. Replace the old entry if it already
        exists.
        """
        self.grow()
        index = hash % len(self.table)
        bucket = self.table[index]
        if bucket is None:
            bucket = []
            self.table[index] = bucket
        # Scan the whole chain for an existing entry before reusing a hole;
        # filling the first hole eagerly could duplicate a key stored later
        # in the chain.
        free_slot = None
        for i, entry in enumerate(bucket):
            if entry is None:
                if free_slot is None:
                    free_slot = i
            elif entry.hash == hash and entry.key == key:
                bucket[i] = HashEntry(key, hash, value)
                return
        new_entry = HashEntry(key, hash, value)
        if free_slot is None:
            bucket.append(new_entry)
        else:
            bucket[free_slot] = new_entry
        self.elements += 1

    def resize(self, new_size):
        """
        Rehash every entry into ``new_size`` buckets.

        No-op when ``new_size`` equals the current size or is 0, or when
        fewer than 100 elements are stored. Holes are dropped by the rehash.
        """
        if new_size == len(self.table) or new_size == 0 or self.elements < 100:
            return
        new_table = [None] * new_size
        for entries in self.table:
            if entries is None:
                continue
            for entry in entries:
                if entry is None:
                    continue
                index = entry.hash % new_size
                if new_table[index] is None:
                    new_table[index] = []
                new_table[index].append(entry)
        self.table = new_table
        # Manually call python gc here to free the memory as the old table
        # might be very large.
        gc.collect()

    def grow(self):
        """Enlarge the bucket array by 1.2x once elements >= buckets."""
        if self.elements < len(self.table):
            return
        self.resize(int(len(self.table) * 1.2))

    def delete(self, key, hash):
        """Remove the entry for ``key``/``hash`` if present, then maybe shrink."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return
        for i, entry in enumerate(entries):
            if entry is not None and entry.hash == hash and entry.key == key:
                entries[i] = None
                self.elements -= 1
                self.shrink()
                return

    def shrink(self):
        """Reduce the bucket array to 0.7x once elements < half the buckets."""
        if self.elements * 2 >= len(self.table):
            return
        self.resize(int(len(self.table) * 0.7))

    def lookup(self, key, hash):
        """Return the value stored for ``key``/``hash``, or None if absent."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return None
        for entry in entries:
            if entry is not None and entry.hash == hash and entry.key == key:
                return entry.value
        return None


class MissRatioStats:
    """
    Counts cache accesses and misses, overall and per time bucket.

    Access times are in microseconds and each bucket spans ``time_unit``
    seconds: an access at time ``t`` falls in bucket
    ``t // (kMicrosInSecond * time_unit)``.
    """

    def __init__(self, time_unit):
        self.num_misses = 0
        self.num_accesses = 0
        self.time_unit = time_unit
        # Bucket index -> number of misses / accesses seen in that bucket.
        self.time_misses = {}
        self.time_accesses = {}

    def _time_bucket(self, timestamp):
        """Return the integer bucket index of a microsecond timestamp."""
        # Floor division keeps keys integral so they match the integer bucket
        # indices produced by range() in the write_* methods; true division
        # would yield fractional keys and make range() raise TypeError.
        return int(timestamp // (kMicrosInSecond * self.time_unit))

    def update_metrics(self, access_time, is_hit):
        """Record one access at ``access_time`` (microseconds)."""
        bucket = self._time_bucket(access_time)
        self.num_accesses += 1
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        if not is_hit:
            self.num_misses += 1
            self.time_misses[bucket] = self.time_misses.get(bucket, 0) + 1

    def reset_counter(self):
        """Zero the overall counters; the per-bucket timelines are kept."""
        self.num_misses = 0
        self.num_accesses = 0

    def miss_ratio(self):
        """Return the overall miss ratio in percent (0.0 with no accesses)."""
        if self.num_accesses == 0:
            return 0.0
        return float(self.num_misses) * 100.0 / float(self.num_accesses)

    def _write_header(self, header_file_path, start, end):
        """Write the 'time,<bucket>,...' header for [start, end) unless it exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            header = ["time"] + ["{}".format(t) for t in range(start, end)]
            header_file.write(",".join(header) + "\n")

    def write_miss_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write per-bucket miss counts for buckets [start, end) to result_dir.

        ``start`` and ``end`` are microsecond timestamps.
        """
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_header(
            "{}/header-ml-miss-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        row = ["{}".format(cache_type)]
        for trace_time in range(start, end):
            row.append("{}".format(self.time_misses.get(trace_time, 0)))
        with open(file_path, "w+") as file:
            file.write(",".join(row) + "\n")

    def write_miss_ratio_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write per-bucket miss ratios (percent, 2 decimals) for buckets
        [start, end) to result_dir; buckets without accesses report 0.

        ``start`` and ``end`` are microsecond timestamps.
        """
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_header(
            "{}/header-ml-miss-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        row = ["{}".format(cache_type)]
        for trace_time in range(start, end):
            naccesses = self.time_accesses.get(trace_time, 0)
            miss_ratio = 0
            if naccesses > 0:
                miss_ratio = float(
                    self.time_misses.get(trace_time, 0) * 100.0
                ) / float(naccesses)
            row.append("{0:.2f}".format(miss_ratio))
        with open(file_path, "w+") as file:
            file.write(",".join(row) + "\n")


class PolicyStats:
    """
    Counts, per time bucket, how often each eviction policy was selected.

    Access times are in microseconds and each bucket spans ``time_unit``
    seconds. ``policy_names`` maps a policy's index in ``policies`` to its
    ``policy_name()``.
    """

    def __init__(self, time_unit, policies):
        # Bucket index -> {policy name -> times selected}. The attribute name
        # (including its spelling) is kept for compatibility.
        self.time_selected_polices = {}
        # Bucket index -> number of policy selections in that bucket.
        self.time_accesses = {}
        self.policy_names = {i: policy.policy_name() for i, policy in enumerate(policies)}
        self.time_unit = time_unit

    def _time_bucket(self, timestamp):
        """Return the integer bucket index of a microsecond timestamp."""
        # Floor division keeps keys integral so they match range() below.
        return int(timestamp // (kMicrosInSecond * self.time_unit))

    def update_metrics(self, access_time, selected_policy):
        """Record that policy index ``selected_policy`` was chosen at access_time."""
        bucket = self._time_bucket(access_time)
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        selected = self.time_selected_polices.setdefault(bucket, {})
        policy_name = self.policy_names[selected_policy]
        selected[policy_name] = selected.get(policy_name, 0) + 1

    def _write_header(self, header_file_path, start, end):
        """Write the 'time,<bucket>,...' header for [start, end) unless it exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            header = ["time"] + ["{}".format(t) for t in range(start, end)]
            header_file.write(",".join(header) + "\n")

    def write_policy_timeline(self, cache_type, cache_size, result_dir, start, end):
        """
        Write, per policy, selection counts for buckets [start, end) to
        result_dir; ``start``/``end`` are microsecond timestamps.
        """
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_header(
            "{}/header-ml-policy-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-policy-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as file:
            for policy_name in self.policy_names.values():
                row = ["{}-{}".format(cache_type, policy_name)]
                for trace_time in range(start, end):
                    count = self.time_selected_polices.get(trace_time, {}).get(
                        policy_name, 0
                    )
                    row.append("{}".format(count))
                file.write(",".join(row) + "\n")

    def write_policy_ratio_timeline(
        self, cache_type, cache_size, file_path, start, end
    ):
        """
        Write, per policy, the percentage of selections in each bucket
        [start, end); ``start``/``end`` are microsecond timestamps.

        ``file_path`` is the result directory (the parameter keeps its
        original name for keyword callers). Files are written under it; no
        module-level ``result_dir`` is consulted.
        """
        result_dir = file_path
        start = self._time_bucket(start)
        end = self._time_bucket(end)
        self._write_header(
            "{}/header-ml-policy-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        data_file_path = "{}/data-ml-policy-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(data_file_path, "w+") as file:
            for policy_name in self.policy_names.values():
                row = ["{}-{}".format(cache_type, policy_name)]
                for trace_time in range(start, end):
                    naccesses = self.time_accesses.get(trace_time, 0)
                    ratio = 0
                    if naccesses > 0:
                        ratio = float(
                            self.time_selected_polices.get(trace_time, {}).get(
                                policy_name, 0
                            )
                            * 100.0
                        ) / float(naccesses)
                    row.append("{0:.2f}".format(ratio))
                file.write(",".join(row) + "\n")


class Policy(object):
    """
    Base class of eviction policies.

    A policy maintains a set of evicted keys. It returns a reward of one to
    itself if it has not evicted a missing key. Otherwise, it gives itself 0
    reward. Subclasses define ``prioritize_samples`` (eviction order of a
    list of hash entries, first evicted first) and ``policy_name``.
    """

    def __init__(self):
        # Keys this policy has evicted; only membership is used.
        self.evicted_keys = {}

    def evict(self, key, max_size):
        """Remember that this policy evicted ``key`` (``max_size`` is unused)."""
        self.evicted_keys[key] = 0

    def delete(self, key):
        """Forget ``key``; no-op if it was never evicted by this policy."""
        self.evicted_keys.pop(key, None)

    def prioritize_samples(self, samples):
        raise NotImplementedError

    def policy_name(self):
        raise NotImplementedError

    def generate_reward(self, key):
        """Return 0 if this policy evicted ``key``, else 1."""
        return 0 if key in self.evicted_keys else 1


class LRUPolicy(Policy):
    def prioritize_samples(self, samples):
        """Least recently accessed first (stable for ties)."""
        # Python 3's sorted() has no cmp argument; use a key function.
        return sorted(samples, key=lambda e: e.value.last_access_number)

    def policy_name(self):
        return "lru"


class MRUPolicy(Policy):
    def prioritize_samples(self, samples):
        """Most recently accessed first (reverse=True keeps ties stable)."""
        return sorted(
            samples, key=lambda e: e.value.last_access_number, reverse=True
        )

    def policy_name(self):
        return "mru"


class LFUPolicy(Policy):
    def prioritize_samples(self, samples):
        """Fewest hits first (stable for ties)."""
        return sorted(samples, key=lambda e: e.value.num_hits)

    def policy_name(self):
        return "lfu"


class MLCache(object):
    """
    Base class of simulated caches that choose an eviction policy online.

    Subclasses implement ``_select_policy`` (return an index into
    ``policies`` for each insertion) and ``cache_name``. Blocks are cached
    under keys "b<block_id>"; when ``enable_cache_row_key`` is set, row
    key-value pairs of get requests are also cached under "g<key_id>".
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        self.cache_size = cache_size
        self.used_size = 0
        self.miss_ratio_stats = MissRatioStats(kSecondsInMinute)
        self.policy_stats = PolicyStats(kSecondsInMinute, policies)
        self.per_hour_miss_ratio_stats = MissRatioStats(kSecondsInHour)
        self.per_hour_policy_stats = PolicyStats(kSecondsInHour, policies)
        self.table = HashTable()
        self.enable_cache_row_key = enable_cache_row_key
        # get_id -> {"h": request already satisfied, key_id: row inserted?}
        self.get_id_row_key_map = {}
        self.policies = policies
        # Monotonic number of accesses seen so far, used as the recency stamp
        # of entries. It is deliberately separate from
        # miss_ratio_stats.num_accesses: reset_counter() zeroes that count
        # (e.g. after warmup), which would make entries touched before the
        # reset look more recent than entries touched after it.
        self.access_number = 0

    def _lookup(self, key, hash):
        """Return True on a hit, refreshing the entry's recency and hit count."""
        value = self.table.lookup(key, hash)
        if value is None:
            return False
        value.last_access_number = self.access_number
        value.num_hits += 1
        return True

    def _select_policy(self, trace_record, key):
        raise NotImplementedError

    def cache_name(self):
        raise NotImplementedError

    def _evict(self, policy_index, value_size):
        """
        Evict entries from a random sample, in the order chosen by policy
        ``policy_index``, until ``value_size`` more bytes fit or the sample
        is exhausted.
        """
        policy = self.policies[policy_index]
        samples = policy.prioritize_samples(self.table.random_sample(kSampleSize))
        for hash_entry in samples:
            self.used_size -= hash_entry.value.value_size
            self.table.delete(hash_entry.key, hash_entry.hash)
            # Record the eviction so generate_reward() can later return 0
            # for this key.
            policy.evict(key=hash_entry.key, max_size=self.table.elements)
            if self.used_size + value_size <= self.cache_size:
                break

    def _insert(self, trace_record, key, hash, value_size):
        """Insert ``key``, evicting as needed; values larger than the cache are skipped."""
        if value_size > self.cache_size:
            return
        policy_index = self._select_policy(trace_record, key)
        # The key is being re-admitted, so the selected policy forgets it.
        self.policies[policy_index].delete(key)
        self.policy_stats.update_metrics(trace_record.access_time, policy_index)
        self.per_hour_policy_stats.update_metrics(
            trace_record.access_time, policy_index
        )
        while self.used_size + value_size > self.cache_size:
            self._evict(policy_index, value_size)
        self.table.insert(
            key,
            hash,
            CacheEntry(
                value_size,
                trace_record.cf_id,
                trace_record.level,
                trace_record.block_type,
                self.access_number,
            ),
        )
        self.used_size += value_size

    def _access_kv(self, trace_record, key, hash, value_size, no_insert):
        """Look up ``key``; on a miss insert it unless no_insert or size is 0."""
        if self._lookup(key, hash):
            return True
        if not no_insert and value_size > 0:
            self._insert(trace_record, key, hash, value_size)
        return False

    def _update_stats(self, access_time, is_hit):
        """Count one access in the recency counter and both miss-ratio stats."""
        self.access_number += 1
        self.miss_ratio_stats.update_metrics(access_time, is_hit)
        self.per_hour_miss_ratio_stats.update_metrics(access_time, is_hit)

    def access(self, trace_record):
        """Simulate one trace record and update the statistics."""
        assert self.used_size <= self.cache_size
        if (
            self.enable_cache_row_key
            and trace_record.caller == 1
            and trace_record.key_id != 0
            and trace_record.get_id != 0
        ):
            # This is a get request.
            if trace_record.get_id not in self.get_id_row_key_map:
                self.get_id_row_key_map[trace_record.get_id] = {"h": False}
            row_keys = self.get_id_row_key_map[trace_record.get_id]
            if row_keys["h"]:
                # We treat future accesses as hits since this get request
                # completes.
                self._update_stats(trace_record.access_time, is_hit=True)
                return
            if trace_record.key_id not in row_keys:
                # First time seen this key: look up the row (inserting it on
                # a miss when kv_size > 0).
                is_hit = self._access_kv(
                    trace_record,
                    key="g{}".format(trace_record.key_id),
                    hash=trace_record.key_id,
                    value_size=trace_record.kv_size,
                    no_insert=False,
                )
                row_keys[trace_record.key_id] = trace_record.kv_size > 0
                row_keys["h"] = is_hit
            if row_keys["h"]:
                # We treat future accesses as hits since this get request
                # completes.
                self._update_stats(trace_record.access_time, is_hit=True)
                return
            # Access its blocks.
            is_hit = self._access_kv(
                trace_record,
                key="b{}".format(trace_record.block_id),
                hash=trace_record.block_id,
                value_size=trace_record.block_size,
                no_insert=trace_record.no_insert,
            )
            self._update_stats(trace_record.access_time, is_hit)
            if trace_record.kv_size > 0 and not row_keys[trace_record.key_id]:
                # Insert the row key-value pair.
                self._access_kv(
                    trace_record,
                    key="g{}".format(trace_record.key_id),
                    hash=trace_record.key_id,
                    value_size=trace_record.kv_size,
                    no_insert=False,
                )
                # Mark as inserted.
                row_keys[trace_record.key_id] = True
            return
        # Access the block.
        is_hit = self._access_kv(
            trace_record,
            key="b{}".format(trace_record.block_id),
            hash=trace_record.block_id,
            value_size=trace_record.block_size,
            no_insert=trace_record.no_insert,
        )
        self._update_stats(trace_record.access_time, is_hit)


class ThompsonSamplingCache(MLCache):
    """
    An implementation of Thompson Sampling for the Bernoulli Bandit [1].
    [1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband,
    and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found.
    Trends Mach. Learn. 11, 1 (July 2018), 1-96.
    DOI: https://doi.org/10.1561/2200000070
    """

    def __init__(self, cache_size, enable_cache_row_key, policies, init_a=1, init_b=1):
        super(ThompsonSamplingCache, self).__init__(
            cache_size, enable_cache_row_key, policies
        )
        # Beta(a, b) parameters per policy, in the order of self.policies:
        # rewards of 1 are added to a, rewards of 0 to b.
        self._as = [init_a] * len(self.policies)
        self._bs = [init_b] * len(self.policies)

    def _select_policy(self, trace_record, key):
        """
        Draw one sample from each policy's Beta distribution, select the
        largest, and update the selected policy with its reward for ``key``.
        """
        samples = [
            np.random.beta(self._as[x], self._bs[x]) for x in range(len(self.policies))
        ]
        selected_policy = max(range(len(self.policies)), key=lambda x: samples[x])
        reward = self.policies[selected_policy].generate_reward(key)
        assert reward <= 1 and reward >= 0
        self._as[selected_policy] += reward
        self._bs[selected_policy] += 1 - reward
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid ThompsonSampling (ts_hybrid)"
        return "ThompsonSampling (ts)"


class LinUCBCache(MLCache):
    """
    An implementation of LinUCB with disjoint linear models [2].
    [2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010.
    A contextual-bandit approach to personalized news article recommendation.
    In Proceedings of the 19th international conference on World wide web
    (WWW '10). ACM, New York, NY, USA, 661-670.
    DOI=http://dx.doi.org/10.1145/1772690.1772758
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        super(LinUCBCache, self).__init__(cache_size, enable_cache_row_key, policies)
        self.nfeatures = 4  # Block type, caller, level, cf.
        self.th = np.zeros((len(self.policies), self.nfeatures))
        self.eps = 0.2
        self.b = np.zeros_like(self.th)
        self.A = np.zeros((len(self.policies), self.nfeatures, self.nfeatures))
        self.A_inv = np.zeros((len(self.policies), self.nfeatures, self.nfeatures))
        for i in range(len(self.policies)):
            self.A[i] = np.identity(self.nfeatures)
            # A starts as the identity, so its inverse does too. A zero A_inv
            # would give every not-yet-selected policy a zero mean and a zero
            # exploration bonus, suppressing exploration of those policies.
            self.A_inv[i] = np.identity(self.nfeatures)
        self.th_hat = np.zeros_like(self.th)
        self.p = np.zeros(len(self.policies))
        self.alph = 0.2

    def _select_policy(self, trace_record, key):
        """
        Score each policy by its upper confidence bound for the current
        context, select the best, and update that policy's model with its
        reward for ``key``.
        """
        # The current context vector.
        x_i = np.zeros(self.nfeatures)
        x_i[0] = trace_record.block_type
        x_i[1] = trace_record.caller
        x_i[2] = trace_record.level
        x_i[3] = trace_record.cf_id
        p = np.zeros(len(self.policies))
        for a in range(len(self.policies)):
            self.th_hat[a] = self.A_inv[a].dot(self.b[a])
            ta = x_i.dot(self.A_inv[a]).dot(x_i)
            a_upper_ci = self.alph * np.sqrt(ta)
            a_mean = self.th_hat[a].dot(x_i)
            p[a] = a_mean + a_upper_ci
        # Tiny random noise breaks ties between equal scores.
        p = p + (np.random.random(len(p)) * 0.000001)
        selected_policy = p.argmax()
        reward = self.policies[selected_policy].generate_reward(key)
        assert reward <= 1 and reward >= 0
        self.A[selected_policy] += np.outer(x_i, x_i)
        self.b[selected_policy] += reward * x_i
        self.A_inv[selected_policy] = np.linalg.inv(self.A[selected_policy])
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid LinUCB (linucb_hybrid)"
        return "LinUCB (linucb)"


def parse_cache_size(cs):
    """
    Parse a cache size such as "4096", "8K", "64M", "2G" or "1T" into bytes.

    Suffixes are binary multiples (K = 1024, M = 1024**2, G = 1024**3,
    T = 1024**4). Newlines and surrounding whitespace (e.g. a trailing
    "\\r\\n") are ignored.

    Raises:
        ValueError: if ``cs`` is empty or its numeric part is not an integer.
    """
    cs = cs.replace("\n", "").strip()
    if not cs:
        raise ValueError("empty cache size")
    multipliers = {
        "K": 1024,
        "M": 1024 ** 2,
        "G": 1024 ** 3,
        "T": 1024 ** 4,
    }
    multiplier = multipliers.get(cs[-1])
    if multiplier is None:
        return int(cs)
    return int(cs[:-1]) * multiplier


def create_cache(cache_type, cache_size, downsample_size):
    """
    Build the simulated cache named by ``cache_type``.

    Args:
        cache_type: "ts" or "linucb", optionally suffixed with "_hybrid" to
            also cache row key-value pairs.
        cache_size: full cache size in bytes.
        downsample_size: trace sampling frequency; the simulated cache is
            ``cache_size // downsample_size`` bytes.

    Raises:
        ValueError: for an unknown cache type.
    """
    policies = [LRUPolicy(), MRUPolicy(), LFUPolicy()]
    # Integer division keeps the size an integral byte count; true division
    # would turn it into a float in Python 3.
    cache_size = cache_size // downsample_size
    enable_cache_row_key = False
    hybrid_suffix = "_hybrid"
    if cache_type.endswith(hybrid_suffix):
        enable_cache_row_key = True
        cache_type = cache_type[: -len(hybrid_suffix)]
    if cache_type == "ts":
        return ThompsonSamplingCache(cache_size, enable_cache_row_key, policies)
    if cache_type == "linucb":
        return LinUCBCache(cache_size, enable_cache_row_key, policies)
    # Raise instead of 'assert False', which is stripped under -O and would
    # make this function silently return None.
    raise ValueError("Unknown cache type {}".format(cache_type))


def _parse_trace_record(line):
    """
    Build a TraceRecord from one comma-separated trace line.

    Field order: access_time, block_id, block_type, block_size, cf_id,
    cf_name, level, fd, caller, no_insert, get_id, key_id, kv_size, is_hit.
    """
    ts = line.split(",")
    return TraceRecord(
        access_time=int(ts[0]),
        block_id=int(ts[1]),
        block_type=int(ts[2]),
        block_size=int(ts[3]),
        cf_id=int(ts[4]),
        cf_name=ts[5],
        level=int(ts[6]),
        fd=int(ts[7]),
        caller=int(ts[8]),
        no_insert=int(ts[9]),
        get_id=int(ts[10]),
        key_id=int(ts[11]),
        kv_size=int(ts[12]),
        is_hit=int(ts[13]),
    )


def _print_progress(elapsed, num, trace_duration, trace_miss_ratio_stats):
    """Print elapsed time, record count, trace duration, throughput and miss ratio."""
    # Guard against a zero elapsed time (coarse clocks) and an empty trace,
    # both of which would otherwise raise ZeroDivisionError.
    throughput = num / elapsed if elapsed > 0 else 0.0
    if trace_miss_ratio_stats.num_accesses > 0:
        miss_ratio = trace_miss_ratio_stats.miss_ratio()
    else:
        miss_ratio = 0.0
    print(
        "Take {} seconds to process {} trace records with trace duration of {} "
        "seconds. Throughput: {} records/second. Trace miss ratio {}".format(
            elapsed,
            num,
            trace_duration / kMicrosInSecond,
            throughput,
            miss_ratio,
        )
    )


def run(trace_file_path, cache_type, cache, warmup_seconds):
    """
    Replay the trace at ``trace_file_path`` through ``cache``.

    ``cache.miss_ratio_stats`` counters are reset once the trace has advanced
    more than ``warmup_seconds`` past its first record. Progress is printed
    roughly every 10 seconds of wall time.

    Returns:
        (trace_start_time, trace_duration): the first record's timestamp and
        the last record's offset from it, both in microseconds.
    """
    warmup_complete = False
    warmup_micros = warmup_seconds * kMicrosInSecond
    num = 0
    trace_start_time = 0
    trace_duration = 0
    start_time = time.time()
    time_interval = 1
    trace_miss_ratio_stats = MissRatioStats(kSecondsInMinute)
    with open(trace_file_path, "r") as trace_file:
        for line in trace_file:
            num += 1
            if num % 1000000 == 0:
                # Force a python gc periodically to reduce memory usage.
                gc.collect()
            record = _parse_trace_record(line)
            if trace_start_time == 0:
                trace_start_time = record.access_time
            trace_duration = record.access_time - trace_start_time
            if not warmup_complete and trace_duration > warmup_micros:
                cache.miss_ratio_stats.reset_counter()
                warmup_complete = True
            trace_miss_ratio_stats.update_metrics(
                record.access_time, is_hit=record.is_hit
            )
            cache.access(record)
            if num % 100 != 0:
                continue
            # Report progress every 10 seconds.
            now = time.time()
            if now - start_time > time_interval * 10:
                _print_progress(
                    now - start_time, num, trace_duration, trace_miss_ratio_stats
                )
                time_interval += 1
                print(
                    "{},0,0,{},{},{}".format(
                        cache_type,
                        cache.cache_size,
                        cache.miss_ratio_stats.miss_ratio(),
                        cache.miss_ratio_stats.num_accesses,
                    )
                )
    _print_progress(
        time.time() - start_time, num, trace_duration, trace_miss_ratio_stats
    )
    return trace_start_time, trace_duration


def report_stats(
    cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
):
    """
    Write the final miss-ratio summary and all timelines of ``cache``.

    The summary line goes to "<result_dir>/data-ml-mrc-<type>-<size>". The
    per-minute policy/miss timelines are written first, then the per-hour ones.
    """
    mrc_path = "{}/data-ml-mrc-{}".format(
        result_dir, "{}-{}".format(cache_type, cache_size)
    )
    with open(mrc_path, "w+") as mrc_file:
        stats = cache.miss_ratio_stats
        mrc_file.write(
            "{},0,0,{},{},{}\n".format(
                cache_type, cache_size, stats.miss_ratio(), stats.num_accesses
            )
        )
    timeline_args = (
        cache_type,
        cache_size,
        result_dir,
        trace_start_time,
        trace_end_time,
    )
    granularities = (
        (cache.policy_stats, cache.miss_ratio_stats),
        (cache.per_hour_policy_stats, cache.per_hour_miss_ratio_stats),
    )
    for policy_stats, miss_stats in granularities:
        policy_stats.write_policy_timeline(*timeline_args)
        policy_stats.write_policy_ratio_timeline(*timeline_args)
        miss_stats.write_miss_timeline(*timeline_args)
        miss_stats.write_miss_ratio_timeline(*timeline_args)


if __name__ == "__main__":
    if len(sys.argv) <= 6:
        print(
            "Must provide 6 arguments. "
            "1) cache_type (ts, ts_hybrid, linucb, linucb_hybrid). "
            "2) cache size (xM, xG, xT). "
            "3) The sampling frequency used to collect the trace. (The "
            "simulation scales down the cache size by the sampling frequency). "
            "4) Warmup seconds (The number of seconds used for warmup). "
            "5) Trace file path. "
            "6) Result directory (A directory that saves generated results)"
        )
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive interpreter and is missing under "python -S".
        sys.exit(1)
    cache_type = sys.argv[1]
    cache_size = parse_cache_size(sys.argv[2])
    downsample_size = int(sys.argv[3])
    warmup_seconds = int(sys.argv[4])
    trace_file_path = sys.argv[5]
    result_dir = sys.argv[6]
    cache = create_cache(cache_type, cache_size, downsample_size)
    trace_start_time, trace_duration = run(
        trace_file_path, cache_type, cache, warmup_seconds
    )
    trace_end_time = trace_start_time + trace_duration
    report_stats(
        cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
    )