# `@everywhere` is exported by Distributed, so the package has to be loaded before the
# macro is first used. With Distributed loaded the lines below also work when julia is
# started without extra worker processes (no `-p N`): they then run on the main process only.
using Distributed;
using DataFrames;

# Load the algorithm implementations on every process.
@everywhere include("Top2Algos.jl")
# beta(t, n, d) = log((log(t) + 1) / d); the argument `n` is accepted but not used.
@everywhere beta(t,n,d)=log((log(t)+1)/d);

# Set default params
dist = "Bdd"
niter = 32;  # Number of simulation runs to perform
trunc_digits = 2; # Number of digits after dec.
T = 500; # Time for which the dynamics should evolve

"""
    cli_overrides(args, dist, niter, T)

Return the tuple `(dist, niter, T)` to use for this run.

When `args` holds at least three entries, the first three replace the defaults:
`args[1]` is taken as `dist`, while `args[2]` and `args[3]` are parsed as `Int64`
for `niter` and `T`. With fewer than three entries the defaults passed in are
returned unchanged, and a warning is logged if `args` was not empty.

# Throws
- `ArgumentError` (from `parse`) if `args[2]` or `args[3]` is not an integer literal.
"""
function cli_overrides(args, dist, niter, T)
    # All three positional arguments are required together: reading args[3] with
    # only two arguments present would raise a BoundsError.
    if length(args) >= 3
        return args[1], parse(Int64, args[2]), parse(Int64, args[3])
    end
    isempty(args) ||
        @warn "Expected 3 positional arguments (dist, niter, T); got $(length(args)). Using defaults."
    return dist, niter, T
end

dist, niter, T = cli_overrides(ARGS, dist, niter, T);

# Default bandit instance: four arms, each a Beta distribution
arm1 = Beta(1.5, 1)
arm2 = Beta(2, 6)
arm3 = Beta(1, 1.5)
arm4 = Beta(1, 7)

global MAB = (arm1, arm2, arm3, arm4)
println(MAB)

# Vector of arm means, the number of arms, and the position of the largest mean
mu = [mean(arm) for arm in MAB]
K = length(mu)
a_star = last(findmax(mu))

# Problem parameters
δ = 0.0001
N = [1, 1, 1, 1]  # Initial number of samples to each arm

seed = 1123
quant = 0.001

# Collect data. Every run i is started with seed + i.

# The fluid system has no randomness, so a single iteration is enough
# (and, per the original note, no exploration is needed).
data_fluid_exp = pmap(1:1) do i
    fluid_top2(MAB, δ, T, quant, false, seed + i;
               check_stop=false, InitialSamples = N)
end

# niter iterations of the non-fluid AT2 algorithm
data_non_fluid = pmap(1:niter) do i
    non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "AT2";
                                 check_stop=false, b=-1, InitialSamples = N,
                                 β = 0.5, α=0.5, αTCB=1, dist = dist)
end

# niter iterations of the non-fluid TCB algorithm
data_TCB = pmap(1:niter) do i
    non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "TCB";
                                 check_stop=false, b=-1, InitialSamples = N,
                                 β = 0.5, α=0.5, αTCB=1, dist = dist)
end

# niter iterations of the non-fluid BetaEBTCB algorithm
data_BetaEBTCB = pmap(1:niter) do i
    non_fluid_top2_algorithms_np(MAB, δ, T, seed + i, "BetaEBTCB";
                                 check_stop=false, b=-1, InitialSamples = N,
                                 β = 0.5, α=0.5, αTCB=1, dist = dist)
end


# Visualize the collected data
# collect return values: 
# 1 - number of samples till time t to each arm, 
# 2 - index for each arm at time t, 
# 3 - sum_ratio - 1 at time t.

using Plots;
I=1 # Pick a sample path to plot
save_dir = "output/"
# Make sure the output directory exists before any figure is written into it;
# mkpath is a no-op when the directory is already there.
mkpath(save_dir);

################## CLEAN DATA FOR ALGOS. ##################

# Every element of a `data_*` vector is the result of one simulation run. The code
# below only reads `run[1][field]`, where (see the list of return values above)
#   field = 1 : number of samples given to each arm up to time t,
#   field = 2 : index of each arm at time t,
#   field = 3 : (ratio-sum - 1) at time t.

"""
    run_history(data, field)

Return a vector with one entry per run in `data`: the history `run[1][field]`.
"""
run_history(data, field) = [run[1][field] for run in data]

"""
    arm_series(histories, l)

For every per-run history in `histories`, pick entry `l` at each time step, giving
the time series of arm `l` in each run.
"""
arm_series(histories, l) = [getindex.(h, l) for h in histories]

"""
    mean_and_stderr(series, digits)

Return `(avg, se)` for a vector `series` of equally long time series (one per run):
the element-wise mean across runs and the element-wise standard error of that mean,
`sqrt(var / number_of_runs)`, both rounded to `digits` decimal places.
"""
function mean_and_stderr(series, digits)
    nruns = length(series)
    avg = round.(mean(series), digits=digits)
    se = round.(sqrt.(var(series) ./ nruns), digits=digits)
    return avg, se
end

"""
    arm_statistics(data, field, K, digits)

Return `(means, stderrs)`, each a `Vector{Vector{Float64}}` of length `K`; entry `l`
is the result of `mean_and_stderr` for arm `l` of the quantity stored in `field`.
"""
function arm_statistics(data, field, K, digits)
    # Extract the per-run histories once instead of once per arm and per run.
    histories = run_history(data, field)
    means = Vector{Vector{Float64}}()
    stderrs = Vector{Vector{Float64}}()
    for l in 1:K
        avg, se = mean_and_stderr(arm_series(histories, l), digits)
        push!(means, avg)
        push!(stderrs, se)
    end
    return means, stderrs
end

# NOTE: the `std_*` variables below keep their original names, but what they hold is
# the standard error of the mean across runs, sqrt(var / niter), not the std. dev.

# ---- non-fluid AT2 ----
# Time series of index values / number of samples of each arm in the fixed sample path I
index_nfl_exp = Vector{Float64}[getindex.(data_non_fluid[I][1][2], l) for l in 1:K];
samples_nfl = Vector{Int64}[getindex.(data_non_fluid[I][1][1], l) for l in 1:K];

# Time series of mean and std. error (across runs) of index values and of # of samples
mean_index, std_index = arm_statistics(data_non_fluid, 2, K, trunc_digits);
mean_samples, std_samples = arm_statistics(data_non_fluid, 1, K, trunc_digits);

# Time series of mean and std. error of the ratio-sum - 1 value in non-fluid AT2
mean_g_nfl, std_g_nfl = mean_and_stderr(run_history(data_non_fluid, 3), trunc_digits);


# ---- non-fluid TCB ----
# Time series of mean and std. error (across runs) of index values of each arm
mean_index_TCB, std_index_TCB = arm_statistics(data_TCB, 2, K, trunc_digits);

# Time series of mean and std. error of the ratio-sum - 1 value in non-fluid TCB
mean_g_TCB, std_g_TCB = mean_and_stderr(run_history(data_TCB, 3), trunc_digits);


# ---- non-fluid BetaEBTCB ----
# Time series of mean and std. error (across runs) of index values of each arm
mean_index_BetaEBTCB, std_index_BetaEBTCB = arm_statistics(data_BetaEBTCB, 2, K, trunc_digits);

# Time series of mean and std. error of the ratio-sum - 1 value in non-fluid BetaEBTCB
mean_g_BetaEBTCB, std_g_BetaEBTCB = mean_and_stderr(run_history(data_BetaEBTCB, 3), trunc_digits);

################## ALGOS COMPLETED ##################

#=
################## CLEAN DATA FOR FLUID ##################
# Get time-series of index values, # of samples of each arm in fluid
index_fl_exp = Vector{Vector{Float64}}();
samples_fl_exp = Vector{Vector{Float64}}();
for l in 1:K
    push!(index_fl_exp,getindex.(getindex.(getindex.(data_fluid_exp,1),2)[I],l));
    push!(samples_fl_exp,getindex.(getindex.(getindex.(data_fluid_exp,1),1)[I],l));
end

samples1_fl_exp = samples_fl_exp[1];
samples2_fl_exp = samples_fl_exp[2];
samples3_fl_exp = samples_fl_exp[3];
samples4_fl_exp = samples_fl_exp[4];

# get time series of ratio-sum - 1 value in fluid
g_fl_exp = getindex.(getindex.(data_fluid_exp,1),3)[I];

# Since fluid allocates quant samples at each time step, 
# 1 time step in algo corresponds to 1/quant time step in fluid
# We need to sub-sample observations at these points for fluid and store

# create vectors to store these sub-sampled values for fluid
g_fl_exp_subsampled = zeros(length(index2_nfl));
samples1_fl_exp_subsampled = zeros(length(index2_nfl));

index2_fl_exp_subsampled = zeros(length(index2_nfl));
samples2_fl_exp_subsampled = zeros(length(index2_nfl));

index3_fl_exp_subsampled = zeros(length(index2_nfl));
samples3_fl_exp_subsampled = zeros(length(index2_nfl));

index4_fl_exp_subsampled = zeros(length(index2_nfl));
samples4_fl_exp_subsampled = zeros(length(index2_nfl));

# Compute the number of time-steps to plot
iter_bound = Int(floor(min(length(index2_fl_exp) * quant, length(index2_nfl))));

for iter_index in 1:iter_bound
    g_fl_exp_subsampled[iter_index] = g_fl_exp[Int(iter_index/quant)];    
    samples1_fl_exp_subsampled[iter_index] = samples1_fl_exp[Int(iter_index/quant)]*quant;
    index2_fl_exp_subsampled[iter_index] = index2_fl_exp[Int(iter_index/quant)];
    samples2_fl_exp_subsampled[iter_index] = samples2_fl_exp[Int(iter_index/quant)]*quant;
    index3_fl_exp_subsampled[iter_index] = index3_fl_exp[Int(iter_index/quant)];
    samples3_fl_exp_subsampled[iter_index] = samples3_fl_exp[Int(iter_index/quant)]*quant;
    index4_fl_exp_subsampled[iter_index] = index4_fl_exp[Int(iter_index/quant)];
    samples4_fl_exp_subsampled[iter_index] = samples4_fl_exp[Int(iter_index/quant)]*quant;
    local iter_index += 1;
end

################## FLUID COMPLETED ##################
=#


################## FINALLY, GENERATE PLOTS ##################

# One colour per arm, looked up by arm number
colors = [ :red, :blue, :orange, :green, :purple, :black, :brown, :yellow, :cyan, :pink]

######### non-fluid AT2 indexes #########
# Mean index of every arm except a_star, with a ribbon of 2 * std_index on either side
p_alg_exp = plot(title = "AT2 Index Values", xlabel="Time", ylabel="Index Values")
for l in 1:K
    l == a_star && continue
    halfwidth = 2 .* std_index[l]
    plot!(p_alg_exp, mean_index[l];
          ribbon = (halfwidth, halfwidth), fillalpha = 0.2,
          color = colors[l], line = (:solid, 2),
          label = string("Arm ",l), xlabel = "Time", ylabel = "Index Values")
end

savefig(p_alg_exp, string(save_dir, K, "AT2Indexes_",dist,"_",T,"_",niter,".pdf"))

######### non-fluid TCB indexes #########
# Mean index of every arm except a_star, with a ribbon of 2 * std_index_TCB on either side
p_TCB = plot(title = "TCB Index Values", xlabel="Time", ylabel="Index Values")
for l in 1:K
    l == a_star && continue
    halfwidth = 2 .* std_index_TCB[l]
    plot!(p_TCB, mean_index_TCB[l];
          ribbon = (halfwidth, halfwidth), fillalpha = 0.2,
          color = colors[l], line = (:dashdotdot, 2),
          label = string("Arm ",l), xlabel = "Time", ylabel = "Index Values")
end

savefig(p_TCB, string(save_dir, K, "TCBIndexes_",dist,"_",T,"_",niter,".pdf"))

######### non-fluid BetaEBTCB indexes #########
# Mean index of every arm except a_star, with a ribbon of 2 * std_index_BetaEBTCB on either side
p_BetaEBTCB = plot(title = "Beta-EB-TCB Index Values", xlabel="Time", ylabel="Index Values")
for l in 1:K
    l == a_star && continue
    halfwidth = 2 .* std_index_BetaEBTCB[l]
    plot!(p_BetaEBTCB, mean_index_BetaEBTCB[l];
          ribbon = (halfwidth, halfwidth), fillalpha = 0.2,
          color = colors[l], line = (:dashdotdot, 2),
          label = string("Arm ",l), xlabel = "Time", ylabel = "Index Values")
end

savefig(p_BetaEBTCB, string(save_dir, K, "BetaEBTCBIndexes_",dist,"_",T,"_",niter,".pdf"))

######### Combined ratio-sum-1 condition for all algos #########
# Start the figure with the AT2 curve, then add TCB and Beta-EB-TCB to the same plot object.
g_comb = plot(mean_g_nfl;
              ribbon = (2 .* std_g_nfl, 2 .* std_g_nfl), fillalpha = 0.2,
              color = :green, line = (:dashdotdot,2),
              title = "Anchor function value", label = "AT2")

plot!(g_comb, mean_g_TCB;
      ribbon = (2 .* std_g_TCB, 2 .* std_g_TCB), fillalpha = 0.2,
      color = :blue, line = (:dashdotdot,2), label = "TCB")

plot!(g_comb, mean_g_BetaEBTCB;
      ribbon = (2 .* std_g_BetaEBTCB, 2 .* std_g_BetaEBTCB), fillalpha = 0.2,
      color = :red, line = (:dashdotdot,2), label = "Beta-EB-TCB")

savefig(g_comb, string(save_dir, K, "Arms_Algos_combined_g_",dist,"_",T,"_",niter, ".pdf"))

# The AT2 curve on its own, drawn with a solid line
g = plot(mean_g_nfl;
         ribbon = (2 .* std_g_nfl, 2 .* std_g_nfl), fillalpha = 0.2,
         color = :green, line = (:solid,2),
         title = "Anchor function value", label = "AT2")
savefig(g, string(save_dir, K, "AT2_g_",dist,"_",T,"_",niter, ".pdf"))


######### ADDITIONAL PLOTS, CURRENTLY NOT IN USE #########

#=

######### fluid indexes #########
p1 = plot([index2_fl_exp_subsampled index3_fl_exp_subsampled index4_fl_exp_subsampled], title="Indexes", xlabel="Time", ylabel="Index Values",  color=[:blue :orange :green], label = ["Arm 2: Fluid Limit" "Arm 3: Fluid Limit" "Arm 4: Fluid Limit"]);

savefig(p1, string(save_dir,"Indexes_Fluid_",T,".pdf"));

######### fluid and non-fluid AT2 indexes on same plot #########

plot([index2_fl_exp_subsampled index3_fl_exp_subsampled index4_fl_exp_subsampled], color=[:blue :orange :green], label = ["Arm 2: Fluid Limit" "Arm 3: Fluid Limit" "Arm 4: Fluid Limit"]);

plot!(mean_index2, ribbon =(2 .* std_index2,2 .* std_index2), fillalpha= 0.2, color=:blue, line=(:dashdotdot, 1), title = "Indexes", label = "Arm 2: Non-fluid", xlabel="Time", ylabel="Index Values");

plot!(mean_index3, ribbon =(2 .* std_index3,2 .* std_index3), fillalpha= 0.2, color=:orange, line=(:dashdotdot, 1), label = "Arm 3: Non-fluid");

p_comb_fl_exp = plot!(mean_index4, ribbon =(2 .* std_index4,2 .* std_index4), fillalpha= 0.2, color=:green, line=(:dashdotdot, 1), label = "Arm 4: Non-fluid");

savefig(p_comb_fl_exp,string(save_dir, "Indexes_",T,".pdf"));

######### fluid and non-fluid AT2 ratio-sum - 1 value on same plot #########

plot([g_fl_exp_subsampled], line=(:line, 1), color=:blue, xlabel="Time", ylabel="Anchor Function Value", title="Anchor Function Value" ,label="Fluid Limit");

p_g_combined = plot!(mean_g_nfl, ribbon = (2 .* std_g_nfl, 2 .* std_g_nfl), fillalpha = 0.2, color=:red, line=(:dashdotdot,1), label="Non Fluid");

savefig(p_g_combined, string(save_dir, "g_",T, ".pdf"));

######### fluid and non-fluid AT2 number of samples on same plot #########

plot([samples1_fl_exp_subsampled samples2_fl_exp_subsampled samples3_fl_exp_subsampled samples4_fl_exp_subsampled], color=[:red :blue :orange :green], label = ["Arm 1: Fluid Limit" "Arm 2: Fluid Limit" "Arm 3: Fluid Limit" "Arm 4: Fluid Limit"]);

plot!(samples_nfl[1],ribbon =(2 .* std_samples[1],2 .* std_samples[1]), fillalpha= 0.2,  color=:red, line=(:dashdotdot, 1), label = "Arm 1: Non-fluid");

plot!(samples_nfl[2],ribbon =(2 .* std_samples[2],2 .* std_samples[2]), fillalpha= 0.2,  color=:blue, line=(:dashdotdot, 1), label = "Arm 2: Non-fluid");

plot!(samples_nfl[3],ribbon =(2 .* std_samples[3],2 .* std_samples[3]), fillalpha= 0.2,  color=:orange, line=(:dashdotdot, 1), label = "Arm 3: Non-fluid");

p_comb_fl_exp_samples = plot!(samples_nfl[4],ribbon =(2 .* std_samples[4],2 .* std_samples[4]), fillalpha= 0.2,  color=:green, line=(:dashdotdot, 1), title = "Number of Samples", label = "Arm 4: Non-fluid", legend=:topleft, xlabel ="Time", ylabel="Number of Samples");

savefig(p_comb_fl_exp_samples,string(save_dir, "Samples_",T,".pdf"));

=#
