# Rscript to generate figures
# in "The bias of the sample mean in multi-armed bandits can be positive or negative"

# Install packages if they are not installed.
# install.packages(c("progress", "ggplot2", "reshape2", "grid", "latex2exp"))

# Load libraries and basic functions
# The working directory must be set to the directory containing this script.
# Also make sure a "figs" folder exists under the working directory.

library(progress)
library(ggplot2)
library(reshape2)
library(grid)
library(latex2exp)

source("./adap_data_collectors.R")

# Set a random seed.
set.seed(1)

# Sampling functions
# Greedy rule: pick the arm with the largest current sample mean.
#
# mu_hat_vec: numeric vector of per-arm sample means.
# N_vec, t, ...: accepted so that every sampling rule shares one signature;
#                they are not used here.
# Returns the index of the first maximal entry of mu_hat_vec.
greedy_sampling <- function(mu_hat_vec, N_vec, t, ...) {
  which.max(mu_hat_vec)
}

# UCB rule: pick the arm maximising sample mean + exploration bonus, where the
# bonus for an arm pulled n times is sqrt(2 * log(1 / delta) / n).
#
# mu_hat_vec: numeric vector of per-arm sample means.
# N_vec:      numeric vector of per-arm pull counts (same length).
# t:          current time step (unused; kept for the shared signature).
# delta:      confidence parameter of the bonus.
# ...:        ignored.
# Returns the index of the first arm with the largest upper confidence bound
# (integer(0) when the inputs are empty).
UCB_sampling <- function(mu_hat_vec, N_vec, t,
                         delta = 0.1, ...) {
  # Arithmetic is already vectorised over N_vec, so no sapply() is needed;
  # this also keeps the result numeric for zero-length input.
  bonus <- sqrt(2 * log(1 / delta) / N_vec)
  which.max(mu_hat_vec + bonus)
}

# Thompson sampling rule: draw one value per arm from that arm's Gaussian
# posterior and pick the arm with the largest draw.
#
# The posterior formulas correspond to a N(0, sig_0^2) prior on each arm mean
# combined with unit-variance Gaussian observations:
#   precision = 1 / sig_0^2 + N,  mean = N * mu_hat / precision,
#   sd = sqrt(1 / precision).
#
# mu_hat_vec: numeric vector of per-arm sample means.
# N_vec:      numeric vector of per-arm pull counts (same length).
# t:          current time step (unused; kept for the shared signature).
# sig_0:      prior standard deviation.
# ...:        ignored.
# Returns the index of the arm with the largest posterior draw.
#
# NOTE: each arm gets its own independent draw and the posterior *standard
# deviation* (not the variance) scales it. This consumes length(mu_hat_vec)
# normal variates per call, so seeded runs differ from a version that used a
# single shared draw.
TS_sampling <- function(mu_hat_vec, N_vec, t,
                        sig_0 = 1, ...) {
  prior_precision <- 1 / sig_0 ^ 2
  post_precision <- prior_precision + N_vec
  post_mean <- N_vec * mu_hat_vec / post_precision
  post_sd <- sqrt(1 / post_precision)
  draws <- post_mean + post_sd * rnorm(length(post_mean))
  which.max(draws)
}

# Simulation for the example 1
# Three arms (labelled N(1,1), N(2,1), N(3,1)) are sampled up to T_max steps
# under each of the sampling rules defined above.
num_repeat <- 10000L
T_max <- 200L
mu_vec <- c(1, 2, 3)

sampling_list <- list(
  greedy = greedy_sampling,
  UCB = UCB_sampling,
  TS = TS_sampling
)

# One result slot per sampling rule, in the same order as sampling_list.
simul_1_results <- list(greedy = list(), UCB = list(), TS = list())

for (i in seq_along(sampling_list)) {
  # Collector bound to the i-th sampling rule with the fixed stopping rule.
  collector_normal_fixed <- function(mu_vec, ...) {
    data_collector(mu_vec,
                   T_max,
                   noise_generator = normal_noise,
                   sampling = sampling_list[[i]],
                   stopping = fixed_stopping,
                   ...)
  }
  simul_1_results[[i]] <- simulator(
    num_repeat,
    collector_normal_fixed,
    mu_vec,
    arm_names = c("N(1,1)", "N(2,1)", "N(3,1)")
  )
}

# Each figure is written by opening a PDF device, calling plot_simulation(),
# and closing the device.
# NOTE(review): presumably plot_simulation() draws on the open device and
# returns the plot object -- confirm in adap_data_collectors.R.

##### (a) Greedy algorithm.
pdf("./figs/simul_1_greedy.pdf", width = 6, height = 4)
p1 <- plot_simulation(
  simul_1_results[["greedy"]],
  title = "Greedy algorithm",
  case_name = "Arms"
)
dev.off()

##### (b) UCB. ($\delta = 0.1$)
pdf("./figs/simul_1_UCB.pdf", width = 6, height = 4)
p2 <- plot_simulation(
  simul_1_results[["UCB"]],
  title = "UCB algorithm",
  case_name = "Arms"
)
dev.off()

##### (c) Thompson sampling. (standard normal prior)
pdf("./figs/simul_1_TS.pdf", width = 6, height = 4)
p3 <- plot_simulation(
  simul_1_results[["TS"]],
  title = "Thompson sampling",
  case_name = "Arms"
)
dev.off()

#### Together
# Put the UCB and Thompson sampling panels on shared axis limits and move the
# legend inside the panel before drawing them side by side.
ranges <- get_common_range(list(p2, p3))

p2 <- p2 +
  theme(legend.position = c(0.2, 0.75)) +
  scale_x_continuous(limits = ranges$x_range) +
  scale_y_continuous(expand = c(0, 0), limits = ranges$y_range)

p3 <- p3 +
  theme(legend.position = c(0.2, 0.75)) +
  scale_x_continuous(limits = ranges$x_range) +
  scale_y_continuous(expand = c(0, 0), limits = ranges$y_range)

pdf("./figs/simul_1_all.pdf", width = 9, height = 4)
multiplot(p2, p3, cols = 2)
dev.off()


# SLRT stopping function
# Return TRUE once the stopping time is reached.
#
# Two-arm, balanced-design sequential test. At even t the rule stops when
#   mu_hat_1 - mu_hat_2 >
#     (2 * sig / t) * sqrt((t + 2w) * log(sqrt((t + 2w) / (2w)) / (2 * alpha) + 1)),
# and it always stops once t >= M.
#
# t:          current time step (total number of samples so far).
# T_max:      accepted for the common stopping-rule signature; not used here
#             (the hard cap inside this rule is M).
# mu_hat_vec: length-2 numeric vector of sample means.
# N_vec:      length-2 vector of per-arm sample counts; must be balanced at
#             even t.
# alpha:      level parameter of the threshold.
# M:          maximal number of samples; the rule returns TRUE when t >= M.
# sig:        noise standard deviation used in the threshold.
# w:          mixture parameter of the threshold.
# ...:        ignored.
# Raises an error when there are not exactly two arms, or when the design is
# unbalanced at an even time step.
SLRT_stopping <- function(t,
                          T_max,
                          mu_hat_vec,
                          N_vec,
                          alpha = 0.1,
                          M = 200,
                          sig = 1,
                          w = 10,
                          ...) {
  if (length(N_vec) != 2) {
    stop("SLRT works only for the two sampling testing problem.")
  }

  # Hard cap on the sample size.
  if (t >= M) {
    return(TRUE)
  }

  # The test is only evaluated after a complete pair of observations.
  if (t %% 2 != 0) {
    return(FALSE)
  }

  if (N_vec[1] != N_vec[2]) {
    stop("SLRT works only for the balanced design")
  }

  # The ratio is (t + 2w) / (2w); the denominator needs explicit parentheses
  # because `/` and `*` associate left to right in R.
  ratio <- (t + 2 * w) / (2 * w)
  sqrt_inner <- (t + 2 * w) * log(sqrt(ratio) / (2 * alpha) + 1)
  thres <- 2 * sig / t * sqrt(sqrt_inner)
  mu_hat_vec[1] - mu_hat_vec[2] > thres
}

# Simulation for the example 2
# Two arms are sampled with deterministic_sampling and stopped by
# SLRT_stopping, once with equal means (H0) and once with mu = (1, 0) (H1).
# T_max is not reassigned in this section; it keeps the value set earlier in
# the script.
num_repeat <- 10000L

mu_vec_list <- list(
  H0 = c(0, 0),
  H1 = c(1, 0)
)

simul_2_results <- list(H0 = list(), H1 = list())

for (i in seq_along(mu_vec_list)) {
  collector_normal_fixed <- function(mu_vec, ...) {
    data_collector(mu_vec,
                   T_max,
                   noise_generator = normal_noise,
                   sampling = deterministic_sampling,
                   stopping = SLRT_stopping,
                   ...)
  }
  simul_2_results[[i]] <- simulator(
    num_repeat,
    collector_normal_fixed,
    mu_vec_list[[i]]
  )
}

# Each figure is written by opening a PDF device, calling plot_simulation(),
# and closing the device.
# NOTE(review): presumably plot_simulation() draws on the open device and
# returns the plot object -- confirm in adap_data_collectors.R.

##### (a) $\mu_1 = \mu_2 = 0$.
pdf("./figs/simul_2_null.pdf", width = 6, height = 4)
simul2_p1 <- plot_simulation(
  simul_2_results[["H0"]],
  title = "Under the null hypothesis",
  arm_names = c("mu1 = 0", "mu2 = 0"),
  case_name = "Arms"
)
dev.off()

# Same plot again on a much larger page.
pdf("./figs/simul_2_null_zoomed.pdf", width = 24, height = 16)
simul2_p1 <- plot_simulation(
  simul_2_results[["H0"]],
  title = "Under the null hypothesis",
  arm_names = c("mu1 = 0", "mu2 = 0"),
  case_name = "Arms"
)
dev.off()

##### (b) $\mu_1 =1, \mu_2 = 0$.
pdf("./figs/simul_2_alternative.pdf", width = 6, height = 4)
simul2_p2 <- plot_simulation(
  simul_2_results[["H1"]],
  title = "Under the alternative hypothesis",
  arm_names = c("mu1 = 1", "mu2 = 0"),
  case_name = "Arms"
)
dev.off()

#### Together
# Put both panels on shared axis limits and move the legend inside the panel
# before drawing them side by side.
ranges <- get_common_range(list(simul2_p1, simul2_p2))

simul2_p1 <- simul2_p1 +
  theme(legend.position = c(0.2, 0.75)) +
  scale_x_continuous(limits = ranges$x_range) +
  scale_y_continuous(expand = c(0, 0), limits = ranges$y_range)

simul2_p2 <- simul2_p2 +
  theme(legend.position = c(0.2, 0.75)) +
  scale_x_continuous(limits = ranges$x_range) +
  scale_y_continuous(expand = c(0, 0), limits = ranges$y_range)

pdf("./figs/simul_2_all.pdf", width = 9, height = 4)
multiplot(simul2_p1, simul2_p2, cols = 2)
dev.off()


# Stopping function for finding the largest mean.
# Return TRUE once the stopping time is reached.
#
# The rule is only evaluated when t is a multiple of the number of arms; it
# stops when the gap between the largest and second-largest sample means
# exceeds delta.
#
# t:          current time step.
# T_max, N_vec, M: accepted for the common stopping-rule signature; this
#             function does not read them.
# mu_hat_vec: numeric vector of per-arm sample means.
# delta:      required gap between the two largest sample means.
finding_largest_stopping <- function(t,
                                     T_max,
                                     mu_hat_vec,
                                     N_vec,
                                     delta = 0.1,
                                     M = 200) {
  n_arms <- length(mu_hat_vec)
  if (t %% n_arms != 0) {
    return(FALSE)
  }
  ranked <- sort(mu_hat_vec, decreasing = TRUE)
  top_gap <- ranked[1] - ranked[2]
  return(top_gap > delta)
}

# Simulation for the example 3
# Three arms with means (g, 0, -g) for each gap g are sampled with
# deterministic_sampling and stopped by finding_largest_stopping with
# threshold delta = g * delta_ratio.
num_repeat <- 10000L
T_max <- 3000L
gaps <- c(1, 3, 5)
mu_vec_list <- list(
  gap1 = c(gaps[1], 0, -gaps[1]),
  gap2 = c(gaps[2], 0, -gaps[2]),
  gap3 = c(gaps[3], 0, -gaps[3])
)
delta_ratio <- 0.8

simul_3_results <- list(gap1 = list(), gap2 = list(), gap3 = list())

for (i in seq_along(mu_vec_list)) {
  # Stopping rule specialised to the i-th gap.
  stopping_delta <- function(t, T_max,
                             mu_hat_vec, N_vec) {
    finding_largest_stopping(t,
                             T_max,
                             mu_hat_vec,
                             N_vec,
                             delta = gaps[i] * delta_ratio,
                             M = T_max)
  }
  collector_normal_fixed <- function(mu_vec, ...) {
    data_collector(mu_vec,
                   T_max,
                   noise_generator = normal_noise,
                   sampling = deterministic_sampling,
                   stopping = stopping_delta,
                   ...)
  }
  simul_3_results[[i]] <- simulator(
    num_repeat,
    collector_normal_fixed,
    mu_vec_list[[i]],
    choosing = best_arm_choosing
  )
}

# Legend labels such as "Small gap (g=1)"; paste0 recycles the closing ")".
chosen_names <- paste0(
  c("Small gap (g=", "Medium gap (g=", "Large gap (g="),
  gaps,
  ")"
)
simul_3_chosen <- results_for_chosen_arms(
  simul_3_results,
  chosen_names = chosen_names
)

pdf("./figs/simul_3.pdf", width = 6, height = 4)
plot_simulation(simul_3_chosen, title = "Estimate largest mean")
dev.off()


# lil UCB simulation

# sampling
# Compute the lil-UCB delta parameter from nu and eps:
#   c_eps = (2 + eps) / eps * (1 / log(1 + eps))^(1 + eps)
#   delta = (sqrt(1 + nu / 2) - 1)^2 / (4 * c_eps)
#
# nu, eps: numeric scalars.
# Returns delta as a numeric scalar.
cal_lil_UCB_delta <- function(nu, eps) {
  inv_log <- 1 / log(1 + eps)
  c_eps <- (2 + eps) / eps * inv_log ^ (1 + eps)
  root_term <- sqrt(1 + nu / 2) - 1
  root_term ^ 2 / c_eps / 4
}

# lil-UCB sampling rule: pick the arm maximising
#   mu_hat + (1 + beta) * (1 + sqrt(eps)) *
#     sqrt(2 * sig^2 * (1 + eps) * log(log((1 + eps) * N) / delta) / N).
#
# mu_hat_vec: numeric vector of per-arm sample means.
# N_vec:      numeric vector of per-arm pull counts (same length).
# t:          current time step (unused; kept for the shared signature).
# sig:        noise scale used in the confidence radius.
# delta:      confidence parameter; when NULL it is derived from nu and eps
#             via cal_lil_UCB_delta().
# nu, eps, beta: tuning parameters of the rule.
# lambda:     accepted for symmetry with lil_UCB_stopping(); not used here.
# ...:        ignored.
# Returns the index of the first arm with the largest index value.
lil_UCB_sampling <- function(mu_hat_vec,
                             N_vec,
                             t,
                             sig = 1,
                             delta = NULL,
                             nu = 0.1,
                             eps = 0.01,
                             beta = 1,
                             lambda = (2 + beta) ^ 2 / beta ^ 2,
                             ...) {
  if (is.null(delta)) {
    delta <- cal_lil_UCB_delta(nu, eps)
  }

  iterated_log <- log(log((1 + eps) * N_vec) / delta)
  radius_sq <- 2 * sig ^ 2 * (1 + eps) * iterated_log / N_vec
  scale <- (1 + beta) * (1 + sqrt(eps))
  which.max(mu_hat_vec + scale * sqrt(radius_sq))
}

# lil-UCB stopping function
# Return TRUE once the stopping time is reached.
#
# Stops when the most-pulled arm has at least 1 + lambda * (pulls of all other
# arms combined) pulls.
#
# t, T_max, mu_hat_vec: accepted for the common stopping-rule signature; this
#             function does not read them.
# N_vec:      numeric vector of per-arm pull counts.
# beta:       only used to build the default lambda.
# lambda:     multiplier on the other arms' total pull count.
# ...:        ignored.
lil_UCB_stopping <- function(t,
                             T_max,
                             mu_hat_vec,
                             N_vec,
                             beta = 1,
                             lambda = (2 + beta) ^ 2 / beta ^ 2,
                             ...) {
  lead_arm <- which.max(N_vec)
  lead_count <- N_vec[lead_arm]
  other_count <- sum(N_vec) - lead_count
  return(lead_count >= (1 + lambda * other_count))
}

# Choosing functions
# lil-UCB chooses the arm that has been pulled most often.
#
# N_vec: numeric vector of per-arm pull counts.
# Returns the index of the first maximal entry of N_vec.
lil_UCB_choosing <- function(N_vec) {
  most_pulled <- which.max(N_vec)
  return(most_pulled)
}

# Simulation for the lil-UCB example
# Three arms with means (g, 0, -g) for each gap g are sampled with the lil-UCB
# rule (confidence parameter nu) and stopped by lil_UCB_stopping; the chosen
# arm is the most-pulled one.
num_repeat <- 10000L
T_max <- 5000L
gaps <- c(1, 3, 5)
mu_vec_list <- list(
  gap1 = c(gaps[1], 0, -gaps[1]),
  gap2 = c(gaps[2], 0, -gaps[2]),
  gap3 = c(gaps[3], 0, -gaps[3])
)
nu <- 0.1

simul_4_results <- list(gap1 = list(),
                        gap2 = list(),
                        gap3 = list())

for (i in seq_along(mu_vec_list)) {
  # lil-UCB sampling rule bound to the script-level nu. It is passed to
  # data_collector() below so that changing `nu` above actually takes effect.
  # Extra arguments are accepted and ignored.
  sampling_nu <- function(mu_hat_vec, N_vec, t, ...) {
    lil_UCB_sampling(mu_hat_vec,
                     N_vec,
                     t,
                     sig = 1,
                     delta = NULL,
                     nu = nu)
  }
  collector_normal_lil_UCB <- function(mu_vec, ...) {
    data_collector(
      mu_vec,
      T_max,
      noise_generator = normal_noise,
      sampling = sampling_nu,
      stopping = lil_UCB_stopping,
      ...
    )
  }
  simul_4_results[[i]] <- simulator(
    num_repeat,
    collector_normal_lil_UCB,
    mu_vec_list[[i]],
    choosing = lil_UCB_choosing,
    choosing_base_mu = FALSE,
    choosing_base_N = TRUE
  )

  # Sanity check on the total number of samples per run: abort if any run
  # used exactly T_max samples or exactly K samples.
  # NOTE(review): presumably a total of T_max means the run hit the horizon
  # before lil_UCB_stopping fired -- confirm against data_collector().
  K <- length(mu_vec_list[[i]])
  tab <- table(apply(simul_4_results[[i]]$N, 1, sum))
  chk <- any(as.character(c(T_max, K)) %in% names(tab))
  if (chk) {
    # Report the failing configuration by loop index and name (the previous
    # message referenced a variable that is never defined in this script).
    stop("T_max is too small at iteration ", i,
         " (", names(mu_vec_list)[i], ").",
         " Need to increase T_max or mu gaps.")
  }
}

# Legend labels such as "Small gap (g=1)".
chosen_names <-
  paste0(c("Small gap (g=", "Medium gap (g=", "Large gap (g="), gaps, rep(")", 3))
simul_4_chosen <- results_for_chosen_arms(simul_4_results,
                                          chosen_names = chosen_names)

pdf("./figs/simul_4.pdf", width = 6, height = 4)
plot_simulation(simul_4_chosen, title = "Estimate largest mean")
dev.off()

# Same plot again on a much larger page.
pdf("./figs/simul_4_zoomed.pdf", width = 24, height = 16)
plot_simulation(simul_4_chosen, title = "Estimate largest mean")
dev.off()
