# Load Distributed before using `@everywhere`: the macro is defined there, so it
# must be in scope first. When no worker processes exist (no `-p N`/`addprocs`),
# `@everywhere` simply evaluates on the master process, so the script also runs
# serially without modification.
using Distributed;
using DataFrames;

# Make the algorithm implementations available on every process used by `pmap`.
@everywhere include("../Top2Algos.jl")

# Threshold function passed to the algorithms as the `b` keyword:
# log((log(t) + 1) / d). The second argument is accepted but unused.
# NOTE(review): `t` is presumably the current time step and `d` the confidence
# parameter δ — confirm against Top2Algos.jl.
@everywhere beta(t,n,d)=log((log(t)+1)/d);


# Command-line overrides (each optional, positional):
#   ARGS[1] — arm distribution, "Gaussian" (default) or "Bernoulli"
#   ARGS[2] — number of independent simulation runs per algorithm (default 32)
typeDistribution = "Gaussian"
niter = 32;  # Number of simulation runs to perform

# Handle each argument independently so a lone distribution argument is honored
# instead of being silently ignored.
if length(ARGS) >= 1
    typeDistribution = ARGS[1];
end
if length(ARGS) >= 2
    niter = parse(Int, ARGS[2]);
end

# Only these two instances are built below; any other string would create
# Gaussian arms while passing a mismatched `dist` to the algorithms.
typeDistribution in ("Gaussian", "Bernoulli") ||
    throw(ArgumentError("unsupported distribution \"$typeDistribution\"; " *
                        "expected \"Gaussian\" or \"Bernoulli\""))
niter > 0 || throw(ArgumentError("number of runs must be positive, got $niter"))

# Default bandit instance (example 1) with arm means 10, 9.4, 7 and 6.5.
# Gaussian mode: unit-variance Normal arms at those means.
# Bernoulli mode: means are divided by 1.2 * max(mu), rounded to two decimals,
# and used as Bernoulli success probabilities.
mu = [10, 9.4, 7, 6.5] # Eg 1
if typeDistribution == "Bernoulli"
    scale = maximum(mu)*1.2;
    mu = round.(mu./scale, digits = 2);
    arm1, arm2, arm3, arm4 = (Bernoulli(p) for p in mu)
else
    arm1, arm2, arm3, arm4 = (Normal(m, 1) for m in mu)
end

MAB = (arm1, arm2, arm3, arm4);
println(MAB);

K = length(mu);
# Ground-truth means taken from the arm objects themselves, as a Vector.
mu = collect(mean.(MAB))
best = argmax(mu);  # index of the arm with the largest mean

# Parameters shared by every algorithm run below.
δ = 0.001;         # passed as δ to each algorithm (presumably the target error probability)
α_local = 0.05     # passed as the `α` keyword to each algorithm
seed = 1123;       # base seed; run i uses seed + i
T = 3000; # Max time for which the dynamics should evolve
βs = [0.5]         # β values evaluated for β-EB-TCB

println("mu=$(mu), δ=$δ");

# Columns of the summary table printed at the end of the script; one entry is
# appended per algorithm (per β value for β-EB-TCB).
Alg = String[];       # algorithm label
MeanST = Float64[];   # mean stopping time
Sd_ST = Float64[];    # standard error of the stopping time
RunTime = Float64[];  # mean run time
Sd_RT = Float64[];    # standard error of the run time

# β-EB-TCB: one table row per β in βs, each summarizing `niter` independent runs.
mean_st_BetaEBTCB = Vector{Float64}();
std_st_BetaEBTCB = Vector{Float64}();

for β in βs
    println("\n ~~~", β, "-EB-TCB  for ", niter, " iterations ~~~");
    # Run exactly `niter` simulations (previously a hard-coded 32) so the
    # standard errors below, which divide by niter, match the sample size.
    # `dist` is forwarded like for every other algorithm in this script.
    data_BetaEBTCB = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "BetaEBTCB";
            check_stop = true, b = beta, β = β, α = α_local, dist = typeDistribution),
        1:niter);
    results = getindex.(data_BetaEBTCB, 1)
    # Length of each run's first output is used as its stopping time.
    stop_times = length.(getindex.(results, 1))
    run_times = getindex.(results, 5)

    m = mean(stop_times);
    s = sqrt(var(stop_times)/niter);
    mean_rt_l = mean(run_times);
    std_rt_l = sqrt(var(run_times)/niter);

    push!(mean_st_BetaEBTCB, round(m, digits = 2));
    push!(std_st_BetaEBTCB, round(s, digits = 2));
    push!(Alg, string(β, "-EB-TCB"));
    push!(MeanST, round(m, digits = 2));
    push!(Sd_ST, round(s, digits = 2));
    push!(RunTime, round(mean_rt_l, digits = 2));
    push!(Sd_RT, round(std_rt_l, digits = 2));
end

# AT2: summarize stopping time, error rate and run time over `niter` runs.
println("\n ~~~ AT2 ~~~")
data_AT2 = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2";
        check_stop = true, b = beta, α = α_local, dist = typeDistribution), 1:niter);
results_AT2 = getindex.(data_AT2, 1)
# Length of each run's first output is used as its stopping time.
stop_time_AT2 = length.(getindex.(results_AT2, 1));
mean_st_AT2_local = mean(stop_time_AT2);
std_st_AT2_local = sqrt(var(stop_time_AT2)/niter);
estimated_BA_AT2 = getindex.(results_AT2, 4);
# Fraction of runs whose recommended arm differs from the true best arm.
# The denominator is the number of runs, not the horizon T.
error_frac_AT2 = count(!=(best), estimated_BA_AT2) / length(estimated_BA_AT2);
run_time_AT2 = getindex.(results_AT2, 5);
mean_rt_AT2 = mean(run_time_AT2);
std_rt_AT2 = sqrt(var(run_time_AT2)/niter);

# Per-arm component l of every entry in the first run's second output.
first_run_index_AT2 = results_AT2[1][2]
index_AT2 = Vector{Float64}[getindex.(first_run_index_AT2, l) for l in 1:K];

push!(Alg, "AT2");
push!(MeanST, round(mean_st_AT2_local, digits = 2));
push!(Sd_ST, round(std_st_AT2_local, digits = 2));
push!(RunTime, round(mean_rt_AT2, digits = 2));
push!(Sd_RT, round(std_rt_AT2, digits = 2));

# IAT2 ("AT2I"): same summary as AT2 over `niter` independent runs.
println("\n ~~~ IAT2 ~~~")
data_AT2I = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2I";
        check_stop = true, b = beta, α = α_local, dist = typeDistribution), 1:niter);
results_AT2I = getindex.(data_AT2I, 1)
# Length of each run's first output is used as its stopping time.
stop_time_AT2I = length.(getindex.(results_AT2I, 1))
mean_st_AT2I_local = mean(stop_time_AT2I);
std_st_AT2I_local = sqrt(var(stop_time_AT2I)/niter);
estimated_BA_AT2I = getindex.(results_AT2I, 4);
# Fraction of runs recommending a wrong arm: divide by the number of runs, not T.
error_frac_AT2I = count(!=(best), estimated_BA_AT2I) / length(estimated_BA_AT2I);
run_time_AT2I = getindex.(results_AT2I, 5);
mean_rt_AT2I = mean(run_time_AT2I);
std_rt_AT2I = sqrt(var(run_time_AT2I)/niter);

push!(Alg, "AT2I");
push!(MeanST, round(mean_st_AT2I_local, digits = 2));
push!(Sd_ST, round(std_st_AT2I_local, digits = 2));
push!(RunTime, round(mean_rt_AT2I, digits = 2));
push!(Sd_RT, round(std_rt_AT2I, digits = 2));

# TCB: summarize stopping time, error rate and run time over `niter` runs.
println("\n ~~~ TCB ~~~")
data_TCB = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCB";
        check_stop = true, b = beta, α = α_local, dist = typeDistribution), 1:niter);
results_TCB = getindex.(data_TCB, 1)
# Length of each run's first output is used as its stopping time.
stop_time_TCB = length.(getindex.(results_TCB, 1))
mean_st_TCB_local = mean(stop_time_TCB);
std_st_TCB_local = sqrt(var(stop_time_TCB)/niter);
estimated_BA_TCB = getindex.(results_TCB, 4);
# Fraction of runs recommending a wrong arm: divide by the number of runs, not T.
error_frac_TCB = count(!=(best), estimated_BA_TCB) / length(estimated_BA_TCB);
run_time_TCB = getindex.(results_TCB, 5);
mean_rt_TCB = mean(run_time_TCB);
std_rt_TCB = sqrt(var(run_time_TCB)/niter);

push!(Alg, "TCB");
push!(MeanST, round(mean_st_TCB_local, digits = 2));
push!(Sd_ST, round(std_st_TCB_local, digits = 2));
push!(RunTime, round(mean_rt_TCB, digits = 2));
push!(Sd_RT, round(std_rt_TCB, digits = 2));


# TCBI: summarize stopping time, error rate and run time over `niter` runs.
println("\n ~~~ TCBI ~~~")
data_TCBI = pmap(i -> non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCBI";
        check_stop = true, b = beta, α = α_local, dist = typeDistribution), 1:niter);
results_TCBI = getindex.(data_TCBI, 1)
# Length of each run's first output is used as its stopping time.
stop_time_TCBI = length.(getindex.(results_TCBI, 1))
mean_st_TCBI_local = mean(stop_time_TCBI);
std_st_TCBI_local = sqrt(var(stop_time_TCBI)/niter);
estimated_BA_TCBI = getindex.(results_TCBI, 4);
# Fraction of runs recommending a wrong arm: divide by the number of runs, not T.
error_frac_TCBI = count(!=(best), estimated_BA_TCBI) / length(estimated_BA_TCBI);
run_time_TCBI = getindex.(results_TCBI, 5);
mean_rt_TCBI = mean(run_time_TCBI);
std_rt_TCBI = sqrt(var(run_time_TCBI)/niter);

push!(Alg, "TCBI");
push!(MeanST, round(mean_st_TCBI_local, digits = 2));
push!(Sd_ST, round(std_st_TCBI_local, digits = 2));
push!(RunTime, round(mean_rt_TCBI, digits = 2));
push!(Sd_RT, round(std_rt_TCBI, digits = 2));


################## PRINT DATA ##################
# Assemble the collected per-algorithm columns into a table and print it.
df = DataFrame(AlgoName = Alg,
               AvgST = MeanST,
               StdDevST = Sd_ST,
               RunTimeMicSec = RunTime,
               StdDevRT = Sd_RT)
print(df)
println(" ")