import os
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats


def mean_confidence_interval(data, confidence=0.95, axis=None):
    """Return the sample mean and the half-width of its Student-t confidence interval.

    Args:
        data: array-like of samples (e.g. a list of equal-length 1-D arrays).
        confidence: two-sided confidence level, e.g. 0.95.
        axis: axis along which to aggregate. ``None`` aggregates over every
            element of the flattened data, matching ``np.mean`` and
            ``scipy.stats.sem`` with ``axis=None``.

    Returns:
        ``(m, h)``: the mean and the interval half-width, so the interval is
        ``[m - h, m + h]``. Both are arrays when ``axis`` selects one axis of
        multi-dimensional data.
    """
    arr = np.asarray(data)
    # The degrees of freedom must come from the number of samples actually being
    # aggregated: the length along `axis`, or every element when axis is None.
    # (len(data) only matches that for 1-D input or axis=0.)
    n = arr.size if axis is None else arr.shape[axis]
    m = np.mean(arr, axis=axis)
    se = scipy.stats.sem(arr, axis=axis)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h


def _load_runs(environment_name, agent, alg_storage):
    """Load every stored run of `agent` on `environment_name` from disk.

    Runs are read from ``log/<environment_name>/<agent>_<alg_storage[agent]>/``.
    Only files whose names are all digits are treated as runs, and each one is
    parsed with ``np.loadtxt`` as a 2-D array.

    Raises:
        FileNotFoundError: if the directory is missing or contains no run files.
    """
    filename = agent + '_' + alg_storage[agent]
    read_directory = os.path.join('log', environment_name, filename)
    message = ('Stored data for {0} on {1} is not found. '
               'Need to remove {2} from the environments_name'
               ' or run the experiment for it.'.format(agent, environment_name, environment_name))
    try:
        names = os.listdir(read_directory)
    except FileNotFoundError as err:
        raise FileNotFoundError(message) from err
    run_names = [name for name in sorted(names) if name.isdigit()]
    if not run_names:
        # An existing but empty directory means the same thing: no stored data.
        raise FileNotFoundError(message)
    # ndmin=2 keeps a single-row file 2-D so column indexing works downstream.
    return [np.loadtxt(os.path.join(read_directory, name), ndmin=2) for name in run_names]


def regret_plot(environment_name, agents, alg_storage, legends, save_directory=None, run_dict=None):
    """Plot each agent's mean regret curve with a 95% confidence band on one figure.

    Args:
        environment_name: 'RandomMDPEnv' or 'GridWorldEnv'; selects the axis
            limits. The plot title is this name without its last three
            characters.
        agents: agent names to plot, in order.
        alg_storage: maps agent name -> suffix of its log directory name; only
            used when `run_dict` is None.
        legends: maps agent name -> legend label.
        save_directory: if truthy, the figure is saved as
            '<environment_name>.pdf' in this directory (created if needed) and
            then closed; otherwise ``plt.show()`` is called.
        run_dict: optional mapping agent name -> list of runs, used instead of
            reading from disk. Each run is indexed as a 2-D array: column 0 is
            used as the x-axis (episode) and column 1 as the regret.

    Raises:
        ValueError: for an unknown `environment_name` (before any data is loaded).
        FileNotFoundError: if an agent's stored runs cannot be found on disk.
    """
    axis_limits = {
        'RandomMDPEnv': ((0, 3000), (0, 120)),
        'GridWorldEnv': ((0, 3000), (0, 1300)),
    }
    if environment_name not in axis_limits:
        raise ValueError('unknown environment: {}'.format(environment_name))
    xlim, ylim = axis_limits[environment_name]

    plt.figure()
    for agent in agents:
        if run_dict is not None:
            runs = run_dict[agent]
        else:
            runs = _load_runs(environment_name, agent, alg_storage)
        # Assumes all runs share the same episode column; it is taken from the first run.
        t = runs[0][:, 0]
        regrets = [run[:, 1] for run in runs]
        avg_reg, conf = mean_confidence_interval(regrets, confidence=0.95, axis=0)

        # Q-learning agents are drawn dash-dot, all other agents solid.
        linestyle = '-.' if agent in ('QLearningAgent', 'HQLAgent') else '-'
        plt.plot(t, avg_reg, label=legends[agent], linestyle=linestyle)
        plt.fill_between(t, avg_reg - conf, avg_reg + conf, alpha=0.2)

    plt.xlabel('Episode')
    plt.ylabel('Regret')
    plt.xlim(*xlim)
    plt.ylim(*ylim)
    plt.legend(loc=1)
    plt.title(environment_name[:-3])
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        plt.savefig(os.path.join(save_directory, '{}.pdf'.format(environment_name)))
        # Release the saved figure so repeated calls do not accumulate open figures.
        plt.close()
    else:
        plt.show()
