# Basic functions
# data collector for 1-dimensional arms
# Simulate one bandit experiment with 1-dimensional (scalar-mean) arms.
#
# Every arm is first pulled `num_initial` times in round-robin order
# (1, 2, ..., K, 1, 2, ...). After that, at each time `i` the arm returned by
# `sampling(mu_hat_vec, N_vec, i, ...)` is pulled. The experiment ends as soon
# as `stopping(i, T_max, mu_hat_vec, N_vec, ...)` returns TRUE, or after
# `T_max` pulls in total.
#
# Arguments:
#   mu_vec          - true mean of each arm; K = length(mu_vec).
#   T_max           - maximum number of pulls (length of the Y and A vectors).
#   noise_generator - zero-argument function whose value is added to the
#                     pulled arm's true mean.
#   sampling        - function(mu_hat_vec, N_vec, t, ...) returning the index
#                     of the arm to pull at time t.
#   stopping        - function(t, T_max, mu_hat_vec, N_vec, ...) returning TRUE
#                     to stop after time t.
#   num_initial     - number of forced round-robin pulls per arm (may be 0).
#   ...             - forwarded to both `sampling` and `stopping`.
#
# Returns a list with
#   Y       - observed rewards (NA after the stopping time),
#   A       - pulled arm indices (NA after the stopping time),
#   mu_hat  - final sample mean of each arm (NA for arms never pulled),
#   N       - number of pulls of each arm,
#   true_mu - `mu_vec`.
data_collector <- function(mu_vec,
                           T_max,
                           noise_generator,
                           sampling,
                           stopping,
                           num_initial = 1,
                           ...) {
  # Allocate memory.
  K <- length(mu_vec)
  N_vec <- rep(0, K)
  Y_vec <- rep(NA, T_max)
  A_vec <- rep(NA, T_max)
  mu_hat_vec <- rep(NA, K)

  # The initialization phase can never exceed the horizon.
  n_init <- min(K * num_initial, T_max)

  for (i in seq_len(T_max)) {
    if (i <= n_init) {
      # Initialization: collect data from each arm in round-robin order.
      A_t <- (i - 1) %% K + 1
    } else {
      # Collect data based on the sampling rule.
      A_t <- sampling(mu_hat_vec, N_vec, i, ...)
    }
    A_vec[i] <- A_t
    Y_vec[i] <- mu_vec[A_t] + noise_generator()
    N_vec[A_t] <- N_vec[A_t] + 1

    # The first observation of an arm initializes its mean. The NA placeholder
    # cannot be updated incrementally. Later observations use the running-mean
    # update.
    if (N_vec[A_t] == 1) {
      mu_hat_vec[A_t] <- Y_vec[i]
    } else {
      mu_hat_vec[A_t] <-
        mu_hat_vec[A_t] + (Y_vec[i] - mu_hat_vec[A_t]) / N_vec[A_t]
    }

    # The stopping rule is only consulted after the initialization phase.
    if (i > n_init && stopping(i, T_max, mu_hat_vec, N_vec, ...)) {
      break
    }
  }

  list(
    Y = Y_vec,
    A = A_vec,
    mu_hat = mu_hat_vec,
    N = N_vec,
    true_mu = mu_vec
  )
}

# Repeat a data-collection experiment and summarize the sample-mean error.
#
# Arguments:
#   num_repeat       - number of independent experiments.
#   collector        - function(mu_vec, ...) returning a list with `mu_hat`
#                      (final sample means) and `N` (pull counts), one entry
#                      per arm.
#   mu_vec           - true arm means; K = length(mu_vec).
#   choosing         - optional function applied to each experiment's row of
#                      sample means (or counts, see below). It returns the
#                      index of the arm to report as "Chosen".
#   arm_names        - optional arm labels; defaults to "Arm 1", ..., "Arm K".
#   is_progress_bar  - if TRUE, show a progress::progress_bar.
#   choosing_base_mu - if TRUE, `choosing` is applied to the sample means.
#   choosing_base_N  - otherwise, if TRUE, `choosing` is applied to the counts.
#   ...              - forwarded to `collector`.
#
# Returns a list with:
#   mu_hat, N - the raw num_repeat x K matrices;
#   true_mu   - `mu_vec`;
#   obs_bias  - per-experiment errors, one column per arm plus "Chosen" when
#               `choosing` is given;
#   summ      - summary table with bias, sd, MSE, effective sample size
#               (E[1/N])^{-1}, expected N, and Monte Carlo error
#               2 * sd / sqrt(num_repeat).
simulator <- function(num_repeat,
                      collector,
                      mu_vec,
                      choosing = NULL,
                      arm_names = NULL,
                      is_progress_bar = TRUE,
                      choosing_base_mu = TRUE,
                      choosing_base_N = FALSE,
                      ...) {
  K <- length(mu_vec)
  default_names <- paste("Arm", seq_len(K))

  # If arm_names is not specified or incorrectly given,
  # then we name each arm by its index.
  if (is.null(arm_names)) {
    arm_names <- default_names
  } else if (K != length(arm_names)) {
    warning(
      "Number of arm_names is not matched with the number of arms. Arms are named by their index"
    )
    arm_names <- default_names
  }

  # Allocate memory.
  mu_hat_mat <- matrix(NA, nrow = num_repeat, ncol = K)
  N_mat <- matrix(NA, nrow = num_repeat, ncol = K)

  # Only create the progress bar (and need the progress package) on request.
  pb <- if (is_progress_bar) {
    progress::progress_bar$new(total = num_repeat)
  } else {
    NULL
  }

  # Collect data.
  for (i in seq_len(num_repeat)) {
    dat <- collector(mu_vec, ...)
    mu_hat_mat[i, ] <- dat$mu_hat
    N_mat[i, ] <- dat$N
    if (!is.null(pb)) {
      pb$tick()
    }
  }

  # Get observed difference between sample mean and true mean.
  obs_bias <- data.frame(sweep(mu_hat_mat, 2, mu_vec))

  # Calculate bias, sd, MSE and effective sample size (E[1/N])^{-1}.
  # `sd_vec` is deliberately not named `sd`, to avoid shadowing stats::sd.
  bias <- colMeans(obs_bias)
  sd_vec <- vapply(obs_bias, sd, numeric(1))
  MSE <- colMeans(obs_bias ^ 2)
  eff_N <- 1 / colMeans(1 / N_mat)
  expect_N <- colMeans(N_mat)

  if (!is.null(choosing)) {
    # Choose target mu hat.
    if (choosing_base_mu) {
      choice_vec <- apply(mu_hat_mat, 1, choosing)
    } else if (choosing_base_N) {
      choice_vec <- apply(N_mat, 1, choosing)
    } else {
      stop("Set choosing_base_mu or choosing_base_N to TRUE when choosing is given.")
    }
    # seq_along (not seq) so a single repetition indexes only row 1.
    ind <- cbind(seq_along(choice_vec), choice_vec)

    # Get difference between chosen sample mean and true mean.
    obs_bias$Chosen <- mu_hat_mat[ind] - mu_vec[choice_vec]
    N_chosen <- N_mat[ind]

    # Calculate bias, sd, MSE and effective sample size of the chosen arm.
    bias[K + 1] <- mean(obs_bias$Chosen)
    sd_vec[K + 1] <- sd(obs_bias$Chosen)
    MSE[K + 1] <- mean(obs_bias$Chosen ^ 2)
    eff_N[K + 1] <- 1 / mean(1 / N_chosen)
    expect_N[K + 1] <- mean(N_chosen)
    arm_names[K + 1] <- "Chosen"
  }

  # Make a summary table.
  # Monte Carlo simulation error: 2 * sd / sqrt(num_repeat).
  MC_error <- 2 * sd_vec / sqrt(num_repeat)
  summ <- data.frame(
    bias = bias,
    sd = sd_vec,
    MSE = MSE,
    eff_N = eff_N,
    expect_N = expect_N,
    MC_error = MC_error
  )

  # Put arm names on the result.
  colnames(obs_bias) <- arm_names
  rownames(summ) <- arm_names

  list(
    mu_hat = mu_hat_mat,
    N = N_mat,
    true_mu = mu_vec,
    obs_bias = obs_bias,
    summ = summ
  )
}

# Plot the density of (sample mean - true mean) for every column of a
# simulation result. Each column's bias is marked with a dashed vertical line,
# and 0 is marked with a solid vertical line.
#
# Arguments:
#   simul_result  - list with `obs_bias` (data.frame, one column per case) and
#                   `summ` (data.frame with a `bias` column, whose row names
#                   label the cases), e.g. the output of simulator().
#   title         - plot title.
#   arm_names     - optional replacement names for the columns of `obs_bias`
#                   and the rows of `summ`.
#   case_name     - name of the grouping variable, also used as legend title.
#   show_subtitle - if TRUE, list the rounded biases in the subtitle.
#   is_plot       - if TRUE, print the plot.
#
# Returns the ggplot object invisibly. With is_plot = TRUE, a visible return
# would make top-level calls (and knitr chunks) draw the plot a second time.
plot_simulation <- function(simul_result,
                            title,
                            arm_names = NULL,
                            case_name = "Case",
                            show_subtitle = TRUE,
                            is_plot = TRUE) {
  # Get the simulation result.
  obs_bias <- simul_result$obs_bias
  summ <- simul_result$summ

  if (!is.null(arm_names)) {
    colnames(obs_bias) <- arm_names
    rownames(summ) <- arm_names
  }

  # Convert into long format: one row per (case, observed difference).
  obs_bias_long <- reshape2::melt(
    obs_bias,
    id.vars = NULL,
    variable.name = case_name,
    value.name = "obs_bias"
  )

  bias_long <- data.frame(arms = rownames(summ),
                          bias = summ$bias)

  # Make a ggplot object. `.data[[case_name]]` maps the column whose name is
  # stored in `case_name`.
  g <-
    ggplot(obs_bias_long,
           aes(x = .data[["obs_bias"]],
               fill = .data[[case_name]],
               colour = .data[[case_name]])) +
    geom_density(alpha = 0.1) +
    geom_vline(data = bias_long,
               aes(xintercept = bias, colour = arms),
               linetype = "dashed") +
    geom_vline(xintercept = 0, size = 0.5) +
    guides(fill = guide_legend(title = case_name),
           colour = guide_legend(title = case_name)) +
    theme(
      # Remove panel border
      panel.border = element_blank(),
      # Remove panel grid lines
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      # Remove panel background
      panel.background = element_blank(),
      # Add axis line
      axis.line = element_line(colour = "grey")
    ) +
    # Adjust y axis
    scale_y_continuous(expand = c(0, 0))

  # If show_subtitle == TRUE, add a summary of the result as a subtitle.
  x_label <- "difference between sample and true means"
  if (show_subtitle) {
    bias_char <- paste(round(summ$bias, 3), collapse = ", ")
    g <- g + labs(x = x_label,
                  title = title,
                  subtitle = paste0("Bias = (", bias_char, ")"))
  } else {
    g <- g + labs(x = x_label, title = title)
  }

  # If is_plot == TRUE, print the plot.
  if (is_plot) {
    print(g)
  }

  invisible(g)
}

# noise generators
# Draw a single N(0, sig^2) noise value; `sig` is the standard deviation.
# rnorm()'s first argument is the number of draws, so `sig` must be passed as
# `sd`. Passing it positionally would return `sig` standard-normal draws.
normal_noise <- function(sig = 1) {
  rnorm(1, mean = 0, sd = sig)
}

# Sampling
# Round-robin sampling rule: pick arm (t mod K), mapping a remainder of 0 to
# arm K, where K = length(mu_hat_vec). The estimates and counts are not used.
deterministic_sampling <- function(mu_hat_vec, N_vec, t, ...) {
  num_arms <- length(mu_hat_vec)
  remainder <- t %% num_arms
  if (remainder == 0) num_arms else remainder
}

# Stopping functions
# Return TRUE once the stopping time is reached.
# Fixed-horizon rule: TRUE once the current time t has reached T_max.
fixed_stopping <- function(t, T_max, ...) {
  t >= T_max
}

# Sequential likelihood-ratio-type stopping rule for two arms sampled in a
# balanced design.
#
# Arguments:
#   t          - current time (total number of pulls so far).
#   T_max      - unused; accepted to match the common stopping interface.
#   mu_hat_vec - current sample means of the two arms.
#   N_vec      - current pull counts; must have length 2 and be equal at
#                every even t.
#   alpha      - level parameter of the boundary.
#   M          - hard cap: stop unconditionally once t >= M.
#   sig        - noise standard deviation used in the boundary.
#   w          - tuning parameter of the boundary (enters as 2 * w).
#
# Returns TRUE to stop. At odd t it returns FALSE without testing.
# Errors if there are not exactly two arms, or if the design is unbalanced at
# an even t.
SLRT_stopping <- function(t,
                          T_max,
                          mu_hat_vec,
                          N_vec,
                          alpha = 0.1,
                          M = 200,
                          sig = 1,
                          w = 10,
                          ...) {
  if (length(N_vec) != 2) {
    stop("SLRT works only for the two sampling testing problem.")
  }

  # Hard cap on the number of samples.
  if (t >= M) {
    return(TRUE)
  }

  # Only test after a complete pair of pulls.
  if (t %% 2 != 0) {
    return(FALSE)
  }

  if (N_vec[1] != N_vec[2]) {
    stop("SLRT works only for the balanced design")
  }

  # The ratio (t + 2w) / (2w) is parenthesized explicitly. Written as
  # `(t + 2 * w) / 2 * w`, R computes ((t + 2w) / 2) * w instead.
  sqrt_inner <-
    (t + 2 * w) * log(sqrt((t + 2 * w) / (2 * w)) / (2 * alpha) + 1)
  thres <- 2 * sig / t * sqrt(sqrt_inner)

  # NOTE(review): one-sided comparison. It stops only when arm 1's mean
  # exceeds arm 2's by more than the threshold. Confirm abs() was not intended.
  mu_hat_vec[1] - mu_hat_vec[2] > thres
}

# Choosing functions
# Choosing rule: index of the largest sample mean. which.max() returns the
# first index among ties and ignores NA entries.
best_arm_choosing <- function(mu_hat_vec) {
  best_index <- which.max(mu_hat_vec)
  best_index
}


# Convert list of simulation results into results of chosen arms.
# Combine the "Chosen" arm of several simulator() results into one result
# that plot_simulation() can consume.
#
# Arguments:
#   results_list - list of simulation results, each with an `obs_bias`
#                  data.frame containing a "Chosen" column and a `summ`
#                  data.frame containing a "Chosen" row.
#   chosen_names - optional labels, one per result; defaults to
#                  "Chosen 1", ..., "Chosen K".
#
# Returns list(obs_bias, summ). `obs_bias` has one column per result, and
# `summ` stacks the "Chosen" summary rows.
results_for_chosen_arms <-
  function(results_list, chosen_names = NULL) {
    K <- length(results_list)
    # NOTE(review): a single result is returned unchanged (still wrapped in
    # the input list), not in the list(obs_bias, summ) shape returned below.
    if (K == 1)
      return(results_list)

    default_names <- paste("Chosen", seq_len(K))
    if (is.null(chosen_names)) {
      chosen_names <- default_names
    } else if (K != length(chosen_names)) {
      warning(
        "Number of chosen_names is not matched with the number of arms. Arms are named by their index"
      )
      chosen_names <- default_names
    }

    # `[["Chosen"]]` avoids `$`'s partial name matching on data frames.
    obs_bias_list <- lapply(results_list, function(res) res$obs_bias[["Chosen"]])
    summ_rows <- lapply(results_list, function(res) res$summ["Chosen", ])
    names(obs_bias_list) <- chosen_names
    names(summ_rows) <- chosen_names

    # check.names = FALSE keeps labels such as "Chosen 1" unmangled.
    obs_bias <- data.frame(obs_bias_list, check.names = FALSE)
    summ <- do.call(rbind, summ_rows)

    list(obs_bias = obs_bias,
         summ = summ)
  }

# ETC

# Multiple plot function
# Ref: http://www.cookbook-r.com/Graphs/Multiple_graphs_on_one_page_(ggplot2)/
# ggplot objects can be passed in ..., or to plotlist (as a list of ggplot objects)
# - cols:   Number of columns in layout
# - layout: A matrix specifying the layout. If present, 'cols' is ignored.
#
# If the layout is something like matrix(c(1,2,3,3), nrow=2, byrow=TRUE),
# then plot 1 will go in the upper left, 2 will go in the upper right, and
# 3 will go all the way across the bottom.
#
# Draw several ggplot objects on one page using a grid layout.
#
# Arguments:
#   ...      - ggplot objects.
#   plotlist - optional list of additional ggplot objects.
#   file     - accepted for compatibility; not used.
#   cols     - number of layout columns when `layout` is NULL.
#   layout   - optional matrix of plot indices. Plot i fills every cell equal
#              to i. If given, `cols` is ignored.
#
# With no plots it draws nothing and returns NULL invisibly. With one plot it
# just prints it.
multiplot <-
  function(...,
           plotlist = NULL,
           file,
           cols = 1,
           layout = NULL) {
    # Make a list from the ... arguments and plotlist.
    plots <- c(list(...), plotlist)
    num_plots <- length(plots)

    # Nothing to draw. Also avoids seq(1, 0) / 1:0 producing bogus indices.
    if (num_plots == 0) {
      return(invisible(NULL))
    }

    # If layout is NULL, use 'cols' to build a column-major layout.
    if (is.null(layout)) {
      n_rows <- ceiling(num_plots / cols)
      layout <- matrix(seq_len(cols * n_rows),
                       ncol = cols,
                       nrow = n_rows)
    }

    if (num_plots == 1) {
      print(plots[[1]])
    } else {
      # Set up the page.
      grid.newpage()
      pushViewport(viewport(layout = grid.layout(nrow(layout), ncol(layout))))

      # Draw each plot in the layout cells labelled with its index.
      for (i in seq_len(num_plots)) {
        matchidx <- as.data.frame(which(layout == i, arr.ind = TRUE))

        print(plots[[i]],
              vp = viewport(
                layout.pos.row = matchidx$row,
                layout.pos.col = matchidx$col
              ))
      }
    }
  }

# Extract common limits from list of ggplots
# Combine the x and y ranges of several ggplots into common limits.
#
# For each plot, ranges are read from `layer_scales(plot)$x$range$range` and
# `$y$range$range`. The result is the smallest lower bound and the largest
# upper bound over all plots, returned as list(x_range, y_range).
# Errors if `list_plots` is empty.
get_common_range <- function(list_plots) {
  K <- length(list_plots)
  if (K == 0) {
    stop("list_plots must contain at least one plot.")
  }
  y_range <- matrix(NA, nrow = K, ncol = 2)
  x_range <- matrix(NA, nrow = K, ncol = 2)
  for (i in seq_len(K)) {
    # Build the scales once per plot rather than once per axis.
    scales_i <- layer_scales(list_plots[[i]])
    y_range[i, ] <- scales_i$y$range$range
    x_range[i, ] <- scales_i$x$range$range
  }
  list(x_range = c(min(x_range[, 1]), max(x_range[, 2])),
       y_range = c(min(y_range[, 1]), max(y_range[, 2])))
}
