# Setup ----
library(tidyr)
library(dplyr)
library(ggplot2)
library(gridExtra)
library(xtable)
library(rstudioapi)

# Resolve the relative data/ and plots/ paths against this script's own
# directory. getActiveDocumentContext() only works inside RStudio, and it
# returns an empty path for an unsaved document. In either of those cases the
# current working directory is kept (e.g. when run with Rscript from the
# script's folder).
if (rstudioapi::isAvailable()) {
  current_path <- getActiveDocumentContext()$path
  if (nzchar(current_path)) {
    setwd(dirname(current_path))
  }
}
# Show which directory the relative paths below are resolved against.
print(getwd())

# Simulation settings ----
nClusters <- 10 # number of clusters in the network
clusterSize <- 5000 # number of nodes per cluster
gammaSeq <- c(0, 0.5)
treatmentLevels <- paste0("T(", gammaSeq, ")") # "T(0)", "T(0.5)"

# Average degrees of the simulated networks whose estimates are combined.
avgDegreeSeq <- c(21, 55, 120)

# Load the object named `objectName` from
#   data/networkSize=<N>/avgDegree=<degree>/<subDir>/<objectName>.Rdata
# and store `degree` in its avgDegree column. load() writes into a private
# environment, so the loaded object does not silently land in (or overwrite)
# the global workspace.
loadEstimates <- function(degree, subDir, objectName) {
  path <- file.path(
    "data",
    paste0("networkSize=", clusterSize * nClusters),
    paste0("avgDegree=", degree),
    subDir,
    paste0(objectName, ".Rdata")
  )
  env <- new.env()
  load(path, envir = env)
  if (!exists(objectName, envir = env, inherits = FALSE)) {
    stop("Object '", objectName, "' not found in ", path, call. = FALSE)
  }
  # .env$ forces the function argument to be used, even if the loaded table
  # already has a column named `degree`.
  env[[objectName]] %>% mutate(avgDegree = .env$degree)
}

# do.call(rbind, ...) stacks the per-degree tables in avgDegreeSeq order, the
# same result as appending them one by one with rbind().
oasis <- do.call(rbind, lapply(avgDegreeSeq, loadEstimates,
                               subDir = "estimates",
                               objectName = "allEstimates"))
clusterBased <- do.call(rbind, lapply(avgDegreeSeq, loadEstimates,
                                      subDir = "clusterBasedEstimates",
                                      objectName = "allClusterBasedEstimates"))

# Combine OASIS and cluster-based (CB) estimates ----
# The cluster-based results index treatments by position in gammaSeq. Keep the
# indices that exist there and relabel them with the matching "T(gamma)"
# labels, so they line up with the OASIS estimates.
allClusterBasedEstimates <- clusterBased %>%
  filter(treatment <= length(gammaSeq) & beta == 1) %>%
  mutate(treatment = factor(treatmentLevels[treatment], levels = treatmentLevels))

# NOTE(review): left_join() is called without `by`, so dplyr joins on every
# column the two tables have in common; confirm that is the intended key.
# OASIS rows with no cluster-based match get NA in the cluster-based columns.
allEstimates <- oasis %>%
  filter(treatment %in% treatmentLevels & beta == 1) %>%
  mutate(treatment = factor(treatment, levels = treatmentLevels)) %>%
  left_join(allClusterBasedEstimates) %>%
  select(delta, avgDegree, experimentId, treatment, trueAvgResponse,
         afterAdjustment, afterAdjustmentBootVar,
         clusterBased, clusterBasedBootVar)

# Pair every experiment's second treatment level (suffix .x) with its first
# treatment level (suffix .y). The differences computed below are .x minus .y.
stopifnot(length(treatmentLevels) >= 2)
allEstimates <- inner_join(
  allEstimates %>% filter(treatment == treatmentLevels[2]),
  allEstimates %>% filter(treatment == treatmentLevels[1]),
  by = c("delta", "avgDegree", "experimentId")
)

# Compute the true effect, plus each estimator's estimate, error and variance.
# Each variance adds the two levels' bootstrap variances, i.e. it ignores any
# covariance between them.
allEstimates <- allEstimates %>%
  mutate(
    treatmentEffect = trueAvgResponse.x - trueAvgResponse.y,
    oasis = afterAdjustment.x - afterAdjustment.y,
    oasisError = oasis - treatmentEffect,
    oasisVar = afterAdjustmentBootVar.x + afterAdjustmentBootVar.y,
    clusterBased = clusterBased.x - clusterBased.y,
    clusterBasedError = clusterBased - treatmentEffect,
    clusterBasedVar = clusterBasedBootVar.x + clusterBasedBootVar.y
  )
  
# Empirical coverage of nominal 95% normal-approximation intervals ----
# An interval covers when |estimate - true effect| <= z_{0.975} * sqrt(var).
zCrit <- qnorm(0.975)
confIntResults <- allEstimates %>%
  group_by(delta, avgDegree) %>%
  summarise(
    trueAvgResponse.x = first(trueAvgResponse.x),
    trueAvgResponse.y = first(trueAvgResponse.y),
    treatmentEffect = first(treatmentEffect),
    # summarise() evaluates its arguments in order, so `treatmentEffect` below
    # is the per-group first() value computed on the line above. This assumes
    # the true effect is constant within a (delta, avgDegree) group; confirm.
    oasisCoverage = mean(abs(oasis - treatmentEffect) / sqrt(oasisVar) <= zCrit),
    clusterBasedCoverage = mean(abs(clusterBased - treatmentEffect) /
                                  sqrt(clusterBasedVar) <= zCrit)
  ) %>%
  # summarise() keeps the result grouped by `delta`; drop it explicitly.
  ungroup()
print.xtable(xtable(confIntResults, digits = 3), include.rownames = FALSE)

# Estimation-error boxplots ----
# Long format: one row per (experiment, estimator). `type` holds the source
# column name ("oasisError" / "clusterBasedError") and `error` its value.
allEstimates <- allEstimates %>%
  pivot_longer(cols = c(oasisError, clusterBasedError),
               names_to = "type", values_to = "error")

# Linetype legend labels follow the avgDegree values actually present.
# factor(avgDegree) orders its levels the same way as sort(unique(.)).
degreeLevels <- sort(unique(allEstimates$avgDegree))

errorPlot <- ggplot(allEstimates, aes(x = factor(delta),
                                      y = error,
                                      fill = type,
                                      linetype = factor(avgDegree))) +
  geom_hline(yintercept = 0, color = "grey") +
  # Draw whisker end caps under the boxes.
  stat_boxplot(geom = "errorbar") +
  geom_boxplot(outlier.shape = NA) +
  # Zoom without dropping data, so box statistics still use every observation.
  coord_cartesian(ylim = c(-0.1, 0.1)) +
  theme_bw() +
  theme(plot.margin = margin(0, 0, 0, 0, "cm"),
        axis.text.x = element_text(size = 18),
        axis.text.y = element_text(size = 18),
        axis.title.y = element_text(size = 18),
        axis.title.x = element_blank(),
        strip.text.x = element_text(size = 18),
        legend.position = "bottom",
        legend.title = element_blank(),
        legend.text = element_text(size = 18)) +
  # Label each delta break as the plotmath expression `delta * " = <value>"`.
  scale_x_discrete(labels = function(breaks) {
    parse(text = paste0("delta * \" = ", breaks, "\""))
  }) +
  # Pair labels with breaks explicitly instead of relying on level order.
  scale_fill_discrete(breaks = c("clusterBasedError", "oasisError"),
                      labels = c(" CB ", " OASIS ")) +
  scale_linetype_discrete(labels = paste0(" avgDegree=", degreeLevels, " "))

dir.create("plots", showWarnings = FALSE)
pdf(file.path("plots", "estimation_error.pdf"), width = 9.25, height = 3)
# print() explicitly: ggplot objects are not auto-printed when the script is
# run with source(), which would otherwise leave the PDF empty. `finally`
# closes the device even if rendering fails.
invisible(tryCatch(print(errorPlot), finally = dev.off()))