# Run this script with the working directory set to this file's location:
# every data path below is relative to it.

# Packages used throughout the script: readr (read_csv), dplyr (data verbs and
# the %>% pipe), tidyr (separate, unite) and ggplot2 (figures).
library(readr)
library(dplyr)
library(tidyr)
library(ggplot2)

######################################################
# Table 1
# Updates:
# 1) Add row for CPSAT (optimal)
# 2) Report 1-sigma confidence interval instead of standard deviation,
#    i.e. the standard error sd / sqrt(number of instances)
######################################################

###################
# n = 11

# OR-tools. Instance names are split on "_number_" and rebuilt as
# "instance_<suffix>" so they can be joined to the other result files by `name`.
or_tools_11 = read_csv("paper_data/rl_best/or-tools/cities_11_capacity_20_instances_1000_result.csv") %>%
  separate(name, c("name.1", "name.2"), sep = "_number_") %>%
  mutate(name = paste0("instance_", name.2))

# Mean distance and its standard error, dividing by the number of instances
# actually read rather than a hard-coded 1000.
mean(or_tools_11$best_distance_so_far)
sd(or_tools_11$best_distance_so_far) / sqrt(nrow(or_tools_11))

# RLCA and Greedy: distance and gap relative to OR-Tools on the same instance.
rl_11_final = read_csv("paper_data/rl_best/toplevel/11_cities_full.csv")
rl_11_final %>%
  left_join(or_tools_11 %>%
              select(name, distance) %>%
              rename(or_tools = distance),
            by = "name") %>%
  mutate(gap = (distance - or_tools) / or_tools) %>%
  filter(iteration %in% c(0, 50)) %>% # 0 for greedy, 50 for RLCA
  group_by(`hidden-nodes`, iteration) %>%
  summarize(meandistance = mean(distance), gap = mean(gap),
            count = n(), stderr = sd(distance) / sqrt(count))

# Optimal (CPSAT): same instance-name normalization as OR-Tools.
cpsat_11 = read_csv("paper_data/camera_ready_data/cpsat_nazari_11.csv") %>%
  separate(name, c("name.1", "name.2"), sep = "_number_") %>%
  mutate(name = paste0("instance_", name.2))
mean(cpsat_11$distance)
sd(cpsat_11$distance) / sqrt(nrow(cpsat_11))

###################
# n = 21

# OR-Tools (same instance-name normalization as for n = 11)
or_tools_21 = read_csv("paper_data/rl_best/or-tools/cities_21_capacity_30_instances_1000_result.csv") %>%
  separate(name, c("name.1", "name.2"), sep = "_number_") %>%
  mutate(name = paste0("instance_", name.2))

# Mean distance and its standard error over the instances actually read.
mean(or_tools_21$best_distance_so_far)
sd(or_tools_21$best_distance_so_far) / sqrt(nrow(or_tools_21))

# RLCA and greedy: gap relative to OR-Tools on the same instance, at
# iterations 0 and 250.
rl_21_final = read_csv("paper_data/rl_best/toplevel/cities_21_restricted_arch_full.csv")
rl_21_final %>%
  left_join(or_tools_21 %>%
              select(name, distance) %>%
              rename(or_tools = distance),
            by = "name") %>%
  mutate(gap = (distance - or_tools) / or_tools) %>%
  filter(iteration %in% c(0, 250)) %>%
  group_by(iteration, `hidden-nodes`) %>%
  summarize(meangap = mean(gap), meandist = mean(distance),
            count = n(), stderr = sd(distance) / sqrt(count))

# Optimal (CPSAT)
cpsat_21 = read_csv("paper_data/camera_ready_data/cpsat_nazari_21.csv") %>%
  separate(name, c("name.1", "name.2"), sep = "_number_") %>%
  mutate(name = paste0("instance_", name.2))
mean(cpsat_21$distance)
sd(cpsat_21$distance) / sqrt(nrow(cpsat_21))

###################
# n = 51

# OR-Tools: keep only the instance name and its distance (the gap reference).
or_tools_51 = read_csv("paper_data/rl_best/or-tools/cities_51_capacity_40_instances_1000_result.csv") %>%
  select(name, distance) %>%
  rename(or_tools = distance)
mean(or_tools_51$or_tools)
sd(or_tools_51$or_tools) / sqrt(nrow(or_tools_51))

# RLCA and Greedy. A missing `hidden-nodes` value is recoded as 0.
rl_51 = read_csv("paper_data/camera_ready_data/cr3_nazari51_h16_result.csv") %>%
  mutate(`hidden-nodes` = ifelse(is.na(`hidden-nodes`), 0, `hidden-nodes`))

# Iterations 0 and 100, summarized per iteration. Report standard errors
# (1-sigma CI, per update 2 above) rather than standard deviations, matching
# the n = 11 and n = 21 tables.
rl_51 %>%
  left_join(or_tools_51, by = "name") %>%
  mutate(gap = (distance - or_tools) / or_tools) %>%
  filter(iteration %in% c(0, 100)) %>%
  group_by(iteration) %>%
  summarize(meandist = mean(distance), count = n(),
            stderr = sd(distance) / sqrt(count),
            meangap = mean(gap), stderr_gap = sd(gap) / sqrt(count)) %>%
  print()

# CP-SAT for n = 51: summarize the returned distance and the solver's
# `best-bound`, then count the instances whose bound is >= the distance
# (presumably the instances CP-SAT proved optimal — confirm; this is an exact
# float comparison).
cpsat_51 = read_csv("paper_data/camera_ready_data/cpsat_nazari_51.csv") %>%
  separate(name, c("name.1", "name.2"), sep = "_number_") %>%
  mutate(name = paste0("instance_", name.2))

cpsat_51 %>%
  select(distance, `best-bound`) %>%
  summary()

cpsat_51 %>%
  filter(`best-bound` >= distance) %>%
  nrow()

##############
# A question: what is the gap between CP-SAT and OR-Tools on n = 11 and n = 21?
# Each pipeline prints the number of instances on which the OR-Tools distance
# exceeds the CP-SAT distance by more than 1e-8. An instance with no CP-SAT row
# gets an NA difference, which filter() drops, so it is not counted.
# n = 11
or_tools_11 %>%
  select(name, or_tools = distance) %>%
  left_join(cpsat_11 %>% select(name, cpsat = distance), by = "name") %>%
  mutate(diff = or_tools - cpsat) %>%
  filter(diff > 1e-8) %>%
  nrow() # number of instances where difference exceeds 1e-8

# n = 21
or_tools_21 %>%
  select(name, or_tools = distance) %>%
  left_join(cpsat_21 %>% select(name, cpsat = distance), by = "name") %>%
  mutate(diff = or_tools - cpsat) %>%
  filter(diff > 1e-8) %>%
  nrow() # number of instances where difference exceeds 1e-8

######################################################
# Figure 1
######################################################

# RL results on the CVRPLIB instances from the two camera-ready runs
# (cr1: linear model, cr2: 16 hidden nodes, going by the file names).
cvrplib = read_csv("paper_data/camera_ready_data/cr1_cvrplib_linear_result.csv") %>%
  rbind(read_csv("paper_data/camera_ready_data/cr2_cvrplib_h16_result.csv"))

# One row per CVRPLIB instance file, in file order. Each row gets the
# zero-based id "instance_<row>", and the n and k fields are parsed out of the
# file name (e.g. a name of the form "<X>-n<N>-k<K>.vrp.bi..." gives n = N,
# k = K).
# row_number() is used for the id because referring to the object being
# assigned (the old rownames(names)) picks up base::names in a fresh session,
# which gave every row the same name "instance_".
cvrplib_instances = read.csv("paper_data/rl_best/cvrplib/cvrplib_instance_ordering.csv", header = FALSE) %>%
  mutate(name = paste0("instance_", row_number() - 1)) %>%
  separate(V1, sep = ".vrp.bi", into = c("V1", "tmp")) %>%
  select(-c(tmp, V2)) %>%
  separate(V1, sep = "cvrplib/", into = c("tmp", "V1")) %>%
  select(-tmp) %>%
  separate(V1, sep = "n", into = c("tmp", "V1")) %>%
  select(-tmp) %>%
  separate(V1, sep = "-k", into = c("n", "k")) %>%
  mutate_at(c("n", "k"), as.numeric)

# OR-Tools reference distances on CVRPLIB: a "gls_300" run (the gap reference)
# and a "fast" run.
cvrplib_or_tools = read_csv("paper_data/rl_best/cvrplib/cvrplib_ab_result_gls_300.csv") %>%
  select(name, distance) %>%
  rename(or_tools = distance)

cvrplib_or_tools_fast = read_csv("paper_data/rl_best/cvrplib/cvrplib_ab_result_fast.csv") %>%
  select(name, distance) %>%
  rename(or_tools_fast = distance)

# Attach the instance metadata and both OR-Tools distances. `gap` is the RL
# distance relative to the gls_300 run; `gap_fast` is the fast OR-Tools run
# relative to that same reference.
cvrplib  = cvrplib %>%
  left_join(cvrplib_instances) %>%
  left_join(cvrplib_or_tools) %>%
  left_join(cvrplib_or_tools_fast) %>%
  mutate(gap = (distance - or_tools)/or_tools,
         gap_fast = (or_tools_fast - or_tools)/or_tools)

# Exploratory scatter: per-instance gap against the instance's n field, for
# the 16-hidden-node model at iteration 100.
cvrplib %>%
  filter(iteration == 100, `hidden-nodes` == 16) %>%
  ggplot(aes(x = n, y = gap)) +
  geom_point()

# Exploratory: iteration-100 rows with a negative gap, i.e. a shorter distance
# than the OR-Tools reference, opened in the data viewer.
cvrplib %>%
  filter(iteration == 100, gap < 0) %>%
  View()

# Figure 1: mean gap vs. OR-Tools per policy iteration for the 16-hidden-node
# model, with a +/- 1 standard error ribbon. Rows with a missing hidden-node
# count are relabelled 0 before summarizing; the hidden_nodes == 16 filter
# then leaves them out of the plot.
cvrplib %>%
  mutate(hidden_nodes = `hidden-nodes`) %>%
  filter(hidden_nodes == 16 | is.na(hidden_nodes)) %>%
  mutate(hidden_nodes = as.factor(ifelse(is.na(hidden_nodes), 0, hidden_nodes))) %>%
  group_by(iteration, hidden_nodes) %>%
  summarize(meangap = mean(gap), segap = sd(gap) / sqrt(n())) %>%
  ungroup() %>%
  filter(hidden_nodes == 16) %>%
  #rbind(best_cvrplib) %>%
  ggplot() +
  aes(x = iteration, y = meangap, group = hidden_nodes,
      color = hidden_nodes, fill = hidden_nodes,
      linetype = hidden_nodes,
      ymin = meangap - segap, ymax = meangap + segap) +
  geom_line(size = 2) +
  geom_ribbon(alpha = 0.2) +
  theme(legend.position = c(0.75, 0.75), legend.text = element_text(size=24),
        legend.title = element_text(size=24), axis.title = element_text(size=24), 
        axis.text = element_text(size=24),
        legend.background = element_blank(),
        legend.box.background = element_rect(color="black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype="dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  # Spell out `limits` (no partial matching) and give breaks explicitly so
  # they always line up one-to-one with the four labels.
  scale_y_continuous(limits = c(0, 0.32), breaks = c(0, 0.1, 0.2, 0.3),
                     labels = c("0%", "10%", "20%", "30%"),
                     name = "Average gap with OR-Tools") +
  scale_fill_brewer(palette = "Set2", name = "Hidden nodes") +
  scale_color_brewer(palette = "Set2") +
  labs(color = "Hidden nodes", linetype = "Hidden nodes", group = "Hidden nodes",
       x = "Policy iteration")
ggsave("paper_plots/cr/cvrplib-talk.png")

######################################################
# Figure 2 a/b
######################################################

# Data-requirement ablation: per-instance optimality gap relative to the
# CP-SAT distance, for n = 11 and n = 21.
data_requirements_11 = read_csv("paper_data/rl_best/data-requirements/cities_11_data_requirements_scrape.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(cpsat_11 %>%
              select(name, distance) %>%
              rename(cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat)

data_requirements_21 = read_csv("paper_data/rl_best/data-requirements/cities_21_data_requirements_scrape.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(cpsat_21 %>%
              select(name, distance) %>%
              rename(cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat)

# Mean gap at iteration 30 by sample paths per iteration (N) and retention
# factor, with +/- 1 standard error bars. The standard error uses each group's
# actual row count instead of a hard-coded 50.
# NOTE(review): keep_data is a grouping column but is not mapped to any
# aesthetic; if it takes more than one value, those bars overlap — confirm.
data_requirements_11 %>%
  filter(iteration == 30) %>%
  mutate_at(c("data_decay_rate", "num_evals"), as.factor) %>%
  group_by(num_evals, keep_data, data_decay_rate) %>%
  summarize(distance = mean(distance), meangap = mean(gap), stdgap = sd(gap),
            count = n()) %>%
  ggplot() +
  aes(x = num_evals, y=meangap, group=data_decay_rate, fill=data_decay_rate) +
  scale_fill_brewer(palette = "OrRd") +
  geom_bar(stat = "identity", position = position_dodge(), color="black") +
  theme(legend.position = c(0.8, 0.7), legend.text = element_text(size=28),
        legend.title = element_text(size=28), axis.title = element_text(size=28), 
        axis.text = element_text(size=28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color="black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype="dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  labs(fill = "Retention factor", x = "Sample paths per iteration (N)", y = "Optimality gap") +
  # y-axis for n = 11 (active); for n = 21, swap in the commented line below.
  scale_y_continuous(breaks = c(0, 0.01, 0.02, 0.03), labels = c("0%", "1%", "2%", "3%")) +
  #scale_y_continuous(breaks = c(0, 0.05, 0.10, 0.15), labels = c("0%", "5%", "10%", "15%")) +
  geom_errorbar(aes(y = meangap, ymin = meangap - stdgap / sqrt(count),
                    ymax = meangap + stdgap / sqrt(count)),
                position = position_dodge(width = 0.9), width=0.5)
ggsave("paper_plots/cr/num_evals_11.png")

######################################################
# Figure 2 c/d
######################################################
# Policy-iteration ablation: per-instance optimality gap relative to the
# CP-SAT distance, for n = 11 and n = 21.
policy_iteration_11 = read_csv("paper_data/rl_best/policy_iterations/cities_11_policy_iteration_ablation.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(cpsat_11 %>%
              select(name, distance) %>%
              rename(cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat)

policy_iteration_21 = read_csv("paper_data/rl_best/policy_iterations/cities_21_policy_iterations.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(cpsat_21 %>%
              select(name, distance) %>%
              rename(cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat)

# Mean gap vs. total sample paths evaluated (iteration * N), one line per N,
# with a +/- 1 standard error ribbon. The standard error uses each group's
# actual row count instead of a hard-coded 50.
policy_iteration_21 %>%
  # total_evals must be computed while num_evals is still numeric.
  mutate(total_evals = iteration * num_evals) %>%
  mutate_at(c("num_evals"), as.factor) %>%
  group_by(total_evals, num_evals) %>%
  # stdgap and count are computed before `gap` is overwritten by its mean.
  summarize(stdgap = sd(gap), count = n(), gap = mean(gap)) %>%
  ggplot() +
  aes(x = total_evals, y=gap, color=num_evals, group=num_evals, linetype=num_evals, fill=num_evals,
      ymin = gap - stdgap / sqrt(count), ymax = gap + stdgap / sqrt(count)) +
  geom_line(size = 2) +
  geom_ribbon(alpha = 0.2) +
  scale_color_brewer(palette = "Set2") +
  scale_fill_brewer(palette = "Set2") +
  theme(legend.position = c(0.8, 0.7), legend.text = element_text(size=28),
        legend.title = element_text(size=28), axis.title = element_text(size=28), 
        axis.text = element_text(size=28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color="black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype="dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  labs(color = "Sample paths\nper iteration (N)", linetype = "Sample paths\nper iteration (N)",
       fill = "Sample paths\nper iteration (N)",
       x = "Total sample paths evaluated", y = "Optimality gap") +
  scale_y_continuous(breaks = c(0.0, 0.04, 0.08, 0.12, 0.16),
                     labels = c("0%", "4%", "8%", "12%", "16%"))
ggsave("paper_plots/cr/cumulative_21.png")

######################################################
# Figure 3
######################################################
# Read one L1-regularization sweep and attach each instance's CP-SAT distance.
# Then summarize iteration 50 by (learning rate, lasso weight): mean/sd of the
# optimality gap and of `mse`, mean failed training attempts, and row count.
# The learning rate is parsed from `description`: the text after "rate_" and
# before "_l1".
#   path          - CSV of per-instance results for the sweep
#   cpsat_results - CP-SAT table with `name` and `distance` columns
summarize_l1_reg = function(path, cpsat_results) {
  read_csv(path) %>%
    rename(num_evals = `num-evals`, lasso = `lasso-weight`) %>%
    mutate(description_split = description) %>%
    separate(description_split, c("trash.1", "description_split"), sep = "rate_") %>%
    separate(description_split, c("learning_rate", "trash.2"), sep = "_l1") %>%
    mutate(trash.1 = NULL, trash.2 = NULL, learning_rate = as.numeric(learning_rate)) %>%
    left_join(cpsat_results %>%
                select(name, distance) %>%
                rename(cpsat = distance)) %>%
    mutate(gap = (distance - cpsat) / cpsat) %>%
    filter(iteration == 50) %>%
    mutate_at(c("lasso", "learning_rate"), as.factor) %>%
    group_by(learning_rate, lasso) %>%
    summarize(meangap = mean(gap), meanmse = mean(mse),
              failures = mean(total_failed_training_attempts),
              stdgap = sd(gap), stdmse = sd(mse), count = n())
}

l1_reg_11 = summarize_l1_reg("paper_data/rl_best/training/cities_11_l1_reg.csv", cpsat_11)
l1_reg_21 = summarize_l1_reg("paper_data/rl_best/training/cities_21_l1_reg.csv", cpsat_21)

# gap plot: mean gap (+/- 1 standard error) by lasso weight and learning rate.
# NOTE(review): the fill labels are matched positionally to the sorted
# learning-rate levels; they assume exactly 5e-4, 1e-3 and 5e-3 — confirm.
l1_reg_21 %>%
  ggplot() +
  aes(x = lasso, y = meangap, fill = learning_rate, group = learning_rate) +
  geom_bar(stat="identity", position = position_dodge(), color="black") +
  geom_errorbar(aes(y = meangap, ymin = meangap - stdgap/sqrt(count), ymax = meangap + stdgap/sqrt(count)),
                position = position_dodge(width = 0.9), width=0.5) +
  scale_fill_brewer(name = "Learning rate", palette = "OrRd", labels = c("5e-4", "1e-3", "5e-3")) +
  labs(x = "Lasso parameter", y = "Optimality gap") +
  scale_y_continuous(breaks = c(0.0, 0.02, 0.04, 0.06, 0.08), labels = c("0%", "2%", "4%", "6%", "8%")) +
  theme(legend.position = c(0.5, 0.75), legend.text = element_text(size=28),
        legend.title = element_text(size=28), axis.title = element_text(size=28), 
        axis.text = element_text(size=28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color="black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype="dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line"))
ggsave("paper_plots/cr/l1_gap_21.png")


######################################################
# Figure 4
######################################################
# Architecture ablation: attach each instance's CP-SAT distance, compute the
# optimality gap, and switch the hyphenated column names to snake_case.
arch_11 = read_csv("paper_data/rl_best/arch/cities_11_arch.csv") %>%
  left_join(select(cpsat_11, name, cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat) %>%
  rename(hidden_nodes = `hidden-nodes`, lasso = `lasso-weight`)

arch_21 = read_csv("paper_data/rl_best/arch/cities_21_arch_scrape.csv") %>%
  left_join(select(cpsat_21, name, cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat) %>%
  rename(hidden_nodes = `hidden-nodes`, lasso = `lasso-weight`)

# Bar chart of the mean gap (+/- 1 standard error) by hidden-node count and
# lasso weight at iteration 230 (use iteration 50 for n = 11).
# NOTE(review): the x labels are matched positionally to the hidden_nodes
# levels; "0 (LR)" in last place presumes the linear model sorts after 32 —
# confirm.
arch_21 %>%
  # filter(iteration == 50) %>% # for n = 11
  filter(iteration == 230) %>%
  #filter(lasso == 0) %>%
  group_by(hidden_nodes, lasso) %>%
  summarize(meangap = mean(gap), stdgap = sd(gap), count = n(),
            meandist = mean(distance)) %>%
  ungroup() %>%
  mutate_at(c("hidden_nodes", "lasso"), as.factor) %>%
  mutate(err_lo = meangap - stdgap / sqrt(count),
         err_hi = meangap + stdgap / sqrt(count)) %>%
  ggplot(aes(x = hidden_nodes, y = meangap, fill = lasso)) +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  geom_errorbar(aes(ymin = err_lo, ymax = err_hi),
                position = position_dodge(width = 0.9), width = 0.5) +
  scale_fill_brewer(palette = "OrRd", name = "Lasso\nparameter") +
  theme(legend.position = c(0.5, 0.75),
        legend.text = element_text(size = 28),
        legend.title = element_text(size = 28),
        axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  scale_y_continuous(name = "Optimality gap",
                     breaks = c(0, 0.005, 0.01, 0.015, 0.02),
                     labels = c("0%", "0.5%", "1%", "1.5%", "2%")) +
  scale_x_discrete(name = "Hidden ReLU nodes",
                   labels = c("4", "8", "16", "32", "0 (LR)"))
ggsave("paper_plots/cr/arch_21.png")

######################################################
# Figure 5
######################################################
# Combinatorial lower-bound ablation: per-instance optimality gap relative to
# the CP-SAT distance, for n = 11 and n = 21.
comblb_11 = read_csv("paper_data/rl_best/comb_lb/cities_11_comb_lb.csv") %>%
  left_join(select(cpsat_11, name, cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat)

comblb_21 = read_csv("paper_data/rl_best/comb_lb/cities_21_comb_lb_scrape.csv") %>%
  left_join(select(cpsat_21, name, cpsat = distance)) %>%
  mutate(gap = (distance - cpsat) / cpsat)

# Mean gap per policy iteration (< 160) with a +/- 1 standard error ribbon,
# one line per "<lb-vehicles>_<lb-train>" combination.
# NOTE(review): the legend labels are matched positionally to the sorted
# combinations (None, eval only, eval + training) — confirm.
comblb_21 %>%
  filter(iteration < 160) %>%
  unite(lower_bounds, c("lb-vehicles", "lb-train"), sep = "_") %>%
  group_by(iteration, lower_bounds) %>%
  summarize(meangap = mean(gap), stdgap = sd(gap), count = n()) %>%
  mutate(se = stdgap / sqrt(count)) %>%
  ggplot(aes(x = iteration, y = meangap, group = lower_bounds,
             color = lower_bounds, linetype = lower_bounds,
             fill = lower_bounds)) +
  geom_line(size = 2) +
  geom_ribbon(aes(ymin = meangap - se, ymax = meangap + se), alpha = 0.2) +
  scale_fill_brewer(palette = "Set2", labels = c("None", "Eval. only", "Eval. + training")) +
  scale_color_brewer(palette = "Set2", labels = c("None", "Eval. only", "Eval. + training")) +
  scale_linetype_discrete(labels = c("None", "Eval. only", "Eval. + training")) +
  theme(legend.position = c(0.7, 0.7),
        legend.text = element_text(size = 28),
        legend.title = element_text(size = 28),
        axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  labs(x = "Policy iteration", y = "Optimality gap",
       color = "Combinatorial\nlower bounds",
       linetype = "Combinatorial\nlower bounds",
       fill = "Combinatorial\nlower bounds") +
  scale_y_continuous(breaks = c(0.0, 0.04, 0.08, 0.12, 0.16),
                     labels = c("0%", "4%", "8%", "12%", "16%"))
ggsave("paper_plots/cr/comblb_21.png")
