import json
import numpy as np
import matplotlib.pyplot as plt

# Default run set: plot_data reads data/<PREFIX><i>.json for i = 1..NUM_FILES
# and shows SUBTITLE in parentheses in the plot title.
PREFIX, NUM_FILES, SUBTITLE = 'linear_', 20, "linear model"

def _load_runs(prefix, num_files):
    """Load runs ``1..num_files`` from ``data/<prefix><i>.json``.

    Returns a list with one parsed JSON object per file, in run order. Each
    file is opened in a ``with`` block so its handle is closed even if
    ``json.load`` raises.
    """
    runs = []
    for i in range(1, num_files + 1):
        path = 'data/{}{}.json'.format(prefix, i)
        with open(path, 'r', encoding='utf-8') as f:
            runs.append(json.load(f))
    return runs


def _aggregate(runs, key, confidence=0.95):
    """Return the per-round mean of ``run[key]`` over ``runs`` and its CI half-width.

    Every ``run[key]`` must be a numeric sequence of the same length.

    The half-width is ``t * s / sqrt(M)``, where ``M`` is the number of runs,
    ``s`` is the per-round sample standard deviation (``ddof=1``), and ``t``
    is the two-sided Student-t critical value at level ``confidence`` with
    ``M - 1`` degrees of freedom. With fewer than two runs no spread can be
    estimated, so the half-width is all zeros.
    """
    values = np.array([run[key] for run in runs], dtype=float)  # shape (M, N)
    num_runs = values.shape[0]
    average = values.mean(axis=0)
    if num_runs < 2:
        return average, np.zeros_like(average)

    # Local import: scipy is only needed for the t quantile.
    from scipy import stats

    # np.std is numerically stable; the one-pass sum-of-squares formula can
    # cancel to a tiny negative variance and produce NaN after sqrt.
    std_deviation = values.std(axis=0, ddof=1)
    # Degrees of freedom are M - 1 to match the ddof=1 sample std. A fixed
    # constant would be wrong whenever the number of runs changes.
    t_statistic = stats.t.ppf(0.5 + confidence / 2.0, num_runs - 1)
    return average, t_statistic * std_deviation / np.sqrt(num_runs)


def plot_data(prefix, num_files, title, *, confidence=0.95):
    """Plot per-round means with confidence bands for each algorithm, then show.

    Loads ``data/<prefix>1.json`` .. ``data/<prefix><num_files>.json``. Each
    file is expected to be a dict mapping the algorithm keys listed below to
    per-round numeric sequences. Per the y-label these are presumably
    cumulative regret (TODO confirm against the data producer). For each
    algorithm the mean across runs is drawn as a line, with a shaded
    ``confidence`` interval around it.

    Args:
        prefix: filename prefix of the run files inside ``data/``.
        num_files: number of runs to load; must be at least 1.
        title: text shown in parentheses at the end of the plot title.
        confidence: two-sided confidence level of the shaded band
            (default 0.95).

    Raises:
        ValueError: if ``num_files`` is less than 1.
    """
    if num_files < 1:
        raise ValueError(
            'num_files must be at least 1, got {}'.format(num_files))

    runs = _load_runs(prefix, num_files)

    # (JSON key, legend label) for each plotted algorithm.
    alg_types = [('etc', "Explore-then-commit"),
                 ('greedy_decomposition', "Greedy decomposition"),
                 ('random_decomposition', "Random decomposition"),
                 ('lp_decomposition', "LP decomposition"),
                 ('ucb_regret', "UCB (non-anonymous)")]

    # A dedicated figure avoids mutating the global rcParams, which would
    # leak this size into every later plot in the session.
    plt.figure(figsize=(8, 6))

    for key, desc in alg_types:
        average, half_width = _aggregate(runs, key, confidence)
        rounds = np.arange(len(average))
        plt.plot(rounds, average, label=desc)
        plt.fill_between(rounds,
                         average - half_width,
                         average + half_width,
                         color='grey',
                         alpha=0.1)

    plt.title("N=50, K=5, C=4, T=10^5 ({})".format(title))
    plt.ylabel("Cumulative Regret")
    plt.xlabel("Round")
    plt.legend()
    plt.show()

if __name__ == '__main__':
    # Script entry point: plot the default run set configured above.
    plot_data(prefix=PREFIX, num_files=NUM_FILES, title=SUBTITLE)