import numpy as np

import bandits
import decomposition
import framework
import parameters

# Candidate values tried, one at a time, for whichever parameter is tuned.
PARAM_GRID = [
    0.1,
    0.5,
    1.0,
    1.5,
    2.0,
    2.5,
    3.0,
]
# Independent simulations averaged for each candidate value.
NUM_RUNS = 5

# Instance configuration: used as MUS_MODEL(N, K), AnonymousBandits(mus, C, T)
# and (K) MatchAlgorithm(..., K). Their exact meaning is defined by the
# framework/bandits modules, not here.
N, K, C = 50, 5, 4
# Horizon length; the tuners read regret_timeseries()[T - 1] as the final value.
T = 10 ** 5

# Generator for each simulated instance's means: every tuning run calls
# MUS_MODEL(N, K) and hands the result to framework.AnonymousBandits.
# Presumably per-arm reward means (per the name) — confirm in framework.
MUS_MODEL = framework.generate_means

def _tune_param(param_name, make_algorithm, param_grid=None, num_runs=None):
    """Sweep one tunable parameter and report the mean final regret per value.

    For each value in ``param_grid`` the parameter is overridden through
    ``parameters.override_param`` and ``num_runs`` fresh
    ``framework.AnonymousBandits`` instances are simulated to completion with
    the algorithm built by ``make_algorithm``. The entry at index ``T - 1`` of
    each run's ``regret_timeseries()`` is averaged and printed as
    ``"<value>: <mean>"``.

    Args:
        param_name: Key passed to ``parameters.override_param``.
        make_algorithm: Callable taking the bandit instance and returning an
            object with a ``do_round()`` method.
        param_grid: Values to try; defaults to the module-level ``PARAM_GRID``.
        num_runs: Runs averaged per value; defaults to ``NUM_RUNS``.

    Returns:
        dict mapping each grid value to its mean final regret, in grid order.

    Raises:
        ValueError: if ``num_runs`` is less than 1 (the mean would be NaN).
    """
    if param_grid is None:
        param_grid = PARAM_GRID
    if num_runs is None:
        num_runs = NUM_RUNS
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1, got {}".format(num_runs))

    # NOTE(review): the override is never undone, so after a sweep the
    # parameter stays at the last grid value for anything run later in this
    # process (e.g. another tuner). Restore the default here if `parameters`
    # offers a way to read/reset it.
    avg_regrets = {}
    for val in param_grid:
        final_regrets = []
        for _ in range(num_runs):
            # Applied before every run so each run starts from the swept value.
            parameters.override_param(param_name, val)

            mus = MUS_MODEL(N, K)
            instance = framework.AnonymousBandits(mus, C, T)
            alg = make_algorithm(instance)

            while instance.is_running():
                alg.do_round()

            # Presumably the cumulative regret after the final round —
            # confirm against AnonymousBandits.regret_timeseries.
            final_regrets.append(instance.regret_timeseries()[T - 1])

        avg_regret = np.mean(final_regrets)
        print("{}: {}".format(val, avg_regret))
        avg_regrets[val] = avg_regret
    return avg_regrets


def tune_base_threshold(param_grid=None, num_runs=None):
    """Sweep ``base_threshold`` for ``MatchAlgorithm`` with greedy decomposition.

    Prints and returns the ``{value: mean final regret}`` dict produced by
    ``_tune_param``; ``param_grid``/``num_runs`` default to the module constants.
    """
    return _tune_param(
        'base_threshold',
        lambda instance: bandits.MatchAlgorithm(
            instance, decomposition.GreedyDecomposition, K),
        param_grid,
        num_runs,
    )


def tune_etc_explore_length(param_grid=None, num_runs=None):
    """Sweep ``etc_explore_length`` for ``ExploreThenCommit``.

    Prints and returns the ``{value: mean final regret}`` dict produced by
    ``_tune_param``; ``param_grid``/``num_runs`` default to the module constants.
    """
    return _tune_param(
        'etc_explore_length',
        bandits.ExploreThenCommit,
        param_grid,
        num_runs,
    )


def tune_ucb_ci_size(param_grid=None, num_runs=None):
    """Sweep ``ucb_ci_size`` for ``IndependentUCB``.

    Prints and returns the ``{value: mean final regret}`` dict produced by
    ``_tune_param``; ``param_grid``/``num_runs`` default to the module constants.
    """
    return _tune_param(
        'ucb_ci_size',
        bandits.IndependentUCB,
        param_grid,
        num_runs,
    )

if __name__ == '__main__':
    import sys

    # Sweeps selectable by name on the command line, e.g.
    #   python <this script> etc_explore_length ucb_ci_size
    # With no arguments only the base_threshold sweep runs.
    tuners = {
        'base_threshold': tune_base_threshold,
        'etc_explore_length': tune_etc_explore_length,
        'ucb_ci_size': tune_ucb_ci_size,
    }
    selected = sys.argv[1:] or ['base_threshold']
    unknown = [name for name in selected if name not in tuners]
    if unknown:
        sys.exit("unknown parameter(s): {}; choose from: {}".format(
            ", ".join(unknown), ", ".join(sorted(tuners))))

    # NOTE(review): parameter overrides from earlier sweeps are not reset
    # before later ones in the same invocation.
    for name in selected:
        if len(selected) > 1:
            # Label each sweep so the "value: regret" lines stay attributable.
            print("== {} ==".format(name))
        tuners[name]()
