# Data analysis and plots for NeurIPS submission
# NOTE: before running this script, set working directory to the one where this file is located,
# e.g. with Session > Set Working Directory > To Source File Location in RStudio

library(tidyverse)
library(RColorBrewer)

# Nazari 10 cities

# Three 11-city RL result sets: the raw run log, the highly parallel run and
# the final top-level runs. The highly parallel file's `distance` column is
# renamed to best_distance_so_far, the column name rl_11 uses.
rl_11 <- read_csv("paper_data/cvrp_nazari_cities_11_rl.csv")
rl_11_hp <- rename(
  read_csv("paper_data/highly_parallel_11_cities.csv"),
  best_distance_so_far = distance
)
rl_11_final <- read_csv("paper_data/rl_best/toplevel/11_cities_full.csv")

# Best distance reached on each instance, summarized across instances.
rl_11 %>%
  group_by(name) %>%
  summarize(best_distance_so_far = min(best_distance_so_far)) %>%
  summary()

# OR-Tools reference solutions for the 11-city instances. The instance id after
# "_number_" in the original name is turned into "instance_<id>" so it can be
# joined with the RL result files. This must run before the gap table below,
# which joins against it.
or_tools_11 <- read_csv("paper_data/rl_best/or-tools/cities_11_capacity_20_instances_1000_result.csv") %>%
  separate(name, c("name.1", "name.2"), sep = "_number_") %>%
  mutate(name = paste0("instance_", name.2))

# Mean distance and mean relative gap to OR-Tools at iteration 50, per
# hidden-layer size.
rl_11_final %>%
  left_join(or_tools_11 %>%
              select(name, distance) %>%
              rename(or_tools = distance),
            by = "name") %>%
  mutate(gap = (distance - or_tools) / or_tools) %>%
  filter(iteration == 50) %>%
  group_by(`hidden-nodes`) %>%
  summarize(distance = mean(distance), gap = mean(gap))

# NOTE(review): the gap computations use the OR-Tools `distance` column;
# confirm best_distance_so_far is the intended column here.
mean(or_tools_11$best_distance_so_far)

# Convergence trace for a single instance (instance_5).
rl_11 %>%
  filter(name == "instance_5") %>%
  ggplot(aes(x = iteration, y = distance)) +
  geom_line()

# Highly parallel run: mean gap to OR-Tools per iteration. The OR-Tools table
# is re-keyed here by row position ("instance_0", "instance_1", ...) before
# the join; after the join `distance` is the OR-Tools reference distance.
rl_11_hp %>%
  left_join(
    or_tools_11 %>%
      mutate(name = paste0("instance_", row_number() - 1)) %>%
      select(name, distance)
  ) %>%
  mutate(gap = (best_distance_so_far - distance) / distance) %>%
  group_by(iteration) %>%
  summarize(
    best_distance_so_far = mean(best_distance_so_far),
    gap = mean(gap)
  ) %>%
  ggplot(aes(x = iteration, y = gap)) +
  geom_line() +
  geom_point()

# Nazari 20 cities

rl_21 <- read_csv("paper_data/cvrp_nazari_cities_21_rl_999.csv")

# Best distance reached on each instance, summarized across instances.
rl_21 %>%
  group_by(name) %>%
  summarize(distance = min(distance)) %>%
  summary()

rl_21_final <- read_csv("paper_data/rl_best/toplevel/cities_21_restricted_arch_full.csv")

# OR-Tools reference solutions, keyed "instance_<id>" like the RL results.
# This must run before the gap table below, which joins against it.
or_tools_21 <- read_csv("paper_data/rl_best/or-tools/cities_21_capacity_30_instances_1000_result.csv") %>%
  separate(name, c("name.1", "name.2"), sep = "_number_") %>%
  mutate(name = paste0("instance_", name.2))

# Mean gap and distance to OR-Tools at iteration 250, per hidden-layer size.
rl_21_final %>%
  left_join(or_tools_21 %>%
              select(name, distance) %>%
              rename(or_tools = distance),
            by = "name") %>%
  mutate(gap = (distance - or_tools) / or_tools) %>%
  filter(iteration == 250) %>%
  group_by(`hidden-nodes`) %>%
  summarize(meangap = mean(gap), meandist = mean(distance), count = n())

# NOTE(review): the gap computations use the OR-Tools `distance` column;
# confirm best_distance_so_far is the intended column here.
summary(or_tools_21$best_distance_so_far)

# 51 cities

# OR-Tools reference distance per instance, stored in column `or_tools`.
or_tools_51 <- read_csv("paper_data/rl_best/or-tools/cities_51_capacity_40_instances_1000_result.csv") %>%
  select(name, distance) %>%
  rename(or_tools = distance)

rl_51 <- read_csv("paper_data/rl_best/toplevel/cities_51_restricted_arch_scrape.csv")

# Best distance and last iteration reached per (instance, architecture), then
# the mean best distance per hidden-layer size. view() returns its input, so
# the pipeline continues after it.
rl_51 %>%
  group_by(name, `hidden-nodes`) %>%
  summarize(iter = max(iteration), distance = min(distance)) %>%
  view() %>%
  ungroup() %>%
  group_by(`hidden-nodes`) %>%
  summarize(meandist = mean(distance))

# Mean gap to OR-Tools per iteration and hidden-layer size. The ribbon is
# +/- one standard error (sd / sqrt(count)).
rl_51 %>%
  mutate(hidden_nodes = as.factor(`hidden-nodes`)) %>%
  left_join(or_tools_51, by = "name") %>%
  mutate(gap = (distance - or_tools) / or_tools) %>%
  group_by(iteration, hidden_nodes) %>%
  summarize(meandist = mean(distance), stddist = sd(distance), count = n(),
            meangap = mean(gap), stdgap = sd(gap)) %>%
  ggplot() +
  aes(x = iteration, y = meangap,
      ymin = meangap - stdgap / sqrt(count),
      ymax = meangap + stdgap / sqrt(count),
      color = hidden_nodes, group = hidden_nodes,
      fill = hidden_nodes, linetype = hidden_nodes) +
  geom_line() +
  geom_ribbon(alpha = 0.2)
  
# summarize comparison study
# `comparison` has one row per (method, problem size) with the mean (meandist)
# and standard deviation (stddist) of the route distance and the number of
# instances (count). The Nazari et al. and Kool et al. rows are hard-coded;
# the Kool et al. rows carry stddist = 0.0 (NOTE(review): presumably no
# spread was available for them).

# Summary row for a numeric vector of per-instance distances.
comparison_row <- function(values, method, size) {
  tibble(meandist = mean(values), stddist = sd(values),
         count = length(values), name = method, size = size)
}

# Best (minimum) distance per instance over the 16-hidden-node RL runs.
best_rl_distances <- function(runs) {
  runs %>%
    filter(`hidden-nodes` == 16) %>%
    group_by(name) %>%
    summarize(distance = min(distance)) %>%
    pull(distance)
}

comparison <- bind_rows(
  # NOTE(review): the 20- and 50-city OR-Tools rows use `distance`; confirm
  # best_distance_so_far is the intended column for 10 cities.
  comparison_row(pull(or_tools_11, best_distance_so_far), "or-tools", 10),
  comparison_row(pull(or_tools_21, distance), "or-tools", 20),
  comparison_row(best_rl_distances(rl_11_final), "rl_tf_opt", 10),
  comparison_row(best_rl_distances(rl_21_final), "rl_tf_opt", 20),
  # The "greedy" rows are the iteration-0 distances of the RL runs.
  comparison_row(pull(filter(rl_11, iteration == 0), distance), "greedy", 10),
  comparison_row(pull(filter(rl_21, iteration == 0), distance), "greedy", 20),
  tibble(meandist = 4.68, stddist = 0.82, count = 1000, name = "nazari", size = 10),
  tibble(meandist = 6.40, stddist = 0.86, count = 1000, name = "nazari", size = 20),
  tibble(meandist = 6.25, stddist = 0.0, count = 10000, name = "kool", size = 20),
  tibble(meandist = 11.15, stddist = 1.28, count = 1000, name = "nazari", size = 50),
  tibble(meandist = 10.62, stddist = 0.0, count = 1000, name = "kool", size = 50),
  comparison_row(pull(filter(rl_51, is.na(`hidden-nodes`), iteration == 0), distance),
                 "greedy", 50),
  comparison_row(pull(or_tools_51, or_tools), "or-tools", 50),
  comparison_row(best_rl_distances(rl_51), "rl_tf_opt", 50)
)

# Comparison plot: mean distance +/- one standard error (stddist / sqrt(count))
# for each method, dodged within each problem size.
comparison %>%
  mutate(size = as.factor(size)) %>%
  mutate(stderr = stddist / sqrt(count)) %>%
  mutate(name = factor(name, levels = c("greedy", "nazari", "kool", "rl_tf_opt", "or-tools"))) %>%
  ggplot() +
  aes(x = size, y = meandist, ymin = meandist - stderr, ymax = meandist + stderr,
      group = name, color = name) +
  geom_errorbar(position = position_dodge(width = 0.9), width = 0.5, size = 2) +
  geom_point(position = position_dodge(width = 0.9), size = 3) +
  scale_color_brewer(palette = "Set2", name = "Method",
                     labels = c("Greedy", "Nazari et al.", "Kool et al.", "RLCA", "OR-Tools")) +
  theme(legend.position = c(0.25, 0.7), legend.text = element_text(size = 28),
        legend.title = element_text(size = 28), axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  # `size` has three levels (10, 20, 50), so exactly three labels are needed.
  # The axis shows the node counts used in the data file names
  # (cities_11 / cities_21 / cities_51).
  scale_x_discrete(labels = c("11", "21", "51"), name = "Number of cities") +
  labs(y = "Average distance")
ggsave("paper_plots/comparison.png")

# Data requirement study (11 and 21 cities): RL runs with a varying number of
# sample paths per iteration (num_evals). Each row is joined to the OR-Tools
# distance (`ortools`) to get the relative gap.

data_requirements_11 <- read_csv("paper_data/rl_best/data-requirements/cities_11_data_requirements_scrape.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(select(or_tools_11, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools)

data_requirements_21 <- read_csv("paper_data/rl_best/data-requirements/cities_21_data_requirements_scrape.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(select(or_tools_21, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools)

# Number of rows at iteration 30 for each num_evals value.
data_requirements_21 %>%
  filter(iteration == 30) %>%
  count(num_evals, name = "count")

# Gap to OR-Tools at iteration 30, by sample paths per iteration (num_evals)
# and data decay rate (shown as "Retention factor"). Error bars are +/- one
# standard error, using the actual number of rows in each group (count).
data_requirements_21 %>%
  filter(iteration == 30) %>%
  mutate(data_decay_rate = as.factor(data_decay_rate),
         num_evals = as.factor(num_evals)) %>%
  group_by(num_evals, keep_data, data_decay_rate) %>%
  summarize(distance = mean(distance), meangap = mean(gap), stdgap = sd(gap),
            count = n()) %>%
  ggplot() +
  aes(x = num_evals, y = meangap, group = data_decay_rate, fill = data_decay_rate) +
  scale_fill_brewer(palette = "OrRd") +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  theme(legend.position = c(0.8, 0.7), legend.text = element_text(size = 28),
        legend.title = element_text(size = 28), axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")
        ) +
  labs(fill = "Retention factor", x = "Sample paths per iteration (N)", y = "Average gap with OR-Tools") +
  #scale_y_continuous(breaks = c(0, 0.01, 0.02, 0.03), labels = c("0%", "1%", "2%", "3%")) +
  scale_y_continuous(breaks = c(0, 0.05, 0.10, 0.15), labels = c("0%", "5%", "10%", "15%")) +
  geom_errorbar(aes(y = meangap,
                    ymin = meangap - stdgap / sqrt(count),
                    ymax = meangap + stdgap / sqrt(count)),
                position = position_dodge(width = 0.9), width = 0.5)
ggsave("paper_plots/num_evals_21.png")  # optional size: units = "in", width = 9, height = 7.5

# Gap to OR-Tools over the first 30 policy iterations (data_decay_rate 1.0),
# one line per num_evals. Ribbons are +/- one standard error, using the actual
# number of rows in each group (count).
data_requirements_21 %>%
  filter(iteration <= 30) %>%
  filter(data_decay_rate == 1.0) %>%
  mutate(data_decay_rate = as.factor(data_decay_rate),
         num_evals = as.factor(num_evals)) %>%
  group_by(num_evals, iteration) %>%
  summarize(distance = mean(distance), meangap = mean(gap), stdgap = sd(gap),
            count = n()) %>%
  ggplot() +
  aes(x = iteration, y = meangap, group = num_evals, color = num_evals, fill = num_evals,
      linetype = num_evals) +
  scale_color_brewer(palette = "Set2") +
  scale_fill_brewer(palette = "Set2") +
  geom_line(size = 2) +
  theme(legend.position = c(0.8, 0.7), legend.text = element_text(size = 28),
        legend.title = element_text(size = 28), axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")
  ) +
  labs(x = "Policy iteration", fill = "Sample paths\nper iteration", y = "Average gap with OR-Tools",
       color = "Sample paths\nper iteration", linetype = "Sample paths\nper iteration") +
  #scale_y_continuous(breaks = c(0, 0.04, 0.08), labels = c("0%", "4%", "8%")) +
  scale_y_continuous(breaks = c(0, 0.05, 0.10, 0.15), labels = c("0%", "5%", "10%", "15%"),
                     limits = c(0, 0.18)) +
  geom_ribbon(aes(y = meangap,
                  ymin = meangap - stdgap / sqrt(count),
                  ymax = meangap + stdgap / sqrt(count)),
              alpha = 0.2)
ggsave("paper_plots/parallelism_21.png")

# progression of best distance with cumulative number of evals
# data_requirements %>%
#   mutate(total_evals = iteration * num_evals) %>%
#   mutate_at(c("data_decay_rate", "num_evals"), as.factor) %>%
#   filter(data_decay_rate == 1.0) %>%
#   group_by(total_evals, num_evals) %>%
#   summarize(distance = mean(distance)) %>%
#   ggplot() +
#   aes(x = total_evals, y=distance, color=num_evals, group=num_evals) +
#   geom_line()

# Policy-iteration study (11 and 21 cities): each row is joined to the OR-Tools
# distance (`ortools`) to get the relative gap.

# Previously used input: paper_data/cities_11_policy_iterations.csv
policy_iteration_11 <- read_csv("paper_data/rl_best/policy_iterations/cities_11_policy_iteration_ablation.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(select(or_tools_11, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools)

policy_iteration_21 <- read_csv("paper_data/rl_best/policy_iterations/cities_21_policy_iterations.csv") %>%
  rename(num_evals = `num-evals`) %>%
  left_join(select(or_tools_21, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools)

# distance plot
# policy_iteration %>%
#   mutate(total_evals = iteration * num_evals) %>%
#   mutate_at(c("num_evals"), as.factor) %>%
#   group_by(total_evals, num_evals) %>%
#   summarize(distance = mean(distance)) %>%
#   ggplot() +
#   aes(x = total_evals, y=distance, color=num_evals, group=num_evals) +
#   geom_line()

# gap plot
# Mean gap to OR-Tools against the cumulative number of sample paths evaluated
# (iteration * num_evals), one line per num_evals. Ribbons are +/- one standard
# error, using the actual number of rows in each group (count).
policy_iteration_11 %>%
  mutate(total_evals = iteration * num_evals) %>%
  mutate(num_evals = as.factor(num_evals)) %>%
  group_by(total_evals, num_evals) %>%
  # stdgap must be computed before `gap` is overwritten by its mean.
  summarize(stdgap = sd(gap), gap = mean(gap), count = n()) %>%
  ggplot() +
  aes(x = total_evals, y = gap, color = num_evals, group = num_evals, linetype = num_evals,
      fill = num_evals,
      ymin = gap - stdgap / sqrt(count), ymax = gap + stdgap / sqrt(count)) +
  geom_line(size = 2) +
  geom_ribbon(alpha = 0.2) +
  scale_color_brewer(palette = "Set2") +
  scale_fill_brewer(palette = "Set2") +
  theme(legend.position = c(0.8, 0.7), legend.text = element_text(size = 28),
        legend.title = element_text(size = 28), axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  labs(color = "Sample paths\nper iteration (N)", linetype = "Sample paths\nper iteration (N)",
       fill = "Sample paths\nper iteration (N)",
       x = "Total sample paths evaluated", y = "Average gap with OR-Tools") +
  # scale_y_continuous(breaks = c(0.0, 0.02, 0.04, 0.06, 0.08, 0.1),
  #                    labels = c("0%", "2%", "4%", "6%", "8%", "10%")) +
  scale_y_continuous(breaks = c(0.0, 0.04, 0.08, 0.12, 0.16),
                     labels = c("0%", "4%", "8%", "12%", "16%"))

ggsave("paper_plots/cumulative_11.png")  # optional size: units = "in", width = 9, height = 7.5

# average gap
# Best (minimum) gap reached on each instance, averaged per num_evals.
# group_by(num_evals) replaces the leftover `name` grouping from summarize().
policy_iteration_21 %>%
  group_by(name, num_evals) %>%
  summarize(gap = min(gap)) %>%
  group_by(num_evals) %>%
  summarize(gap = mean(gap)) %>%
  head()

# Per-instance best distance vs. OR-Tools for runs with num_evals == 10.
summary(
  policy_iteration_21 %>%
    filter(num_evals == 10) %>%
    group_by(name) %>%
    summarize(distance = min(distance), ortools = min(ortools))
)

# L1/training study
# The 11- and 21-city tables go through the same pipeline, shared by
# summarize_l1_reg():
# - parse learning_rate from `description` (the text between "rate_" and "_l1");
# - join the OR-Tools distance and compute the relative gap;
# - keep iteration 50 and summarize gap, mse and failed training attempts per
#   (learning_rate, lasso) pair.
summarize_l1_reg <- function(path, or_tools) {
  read_csv(path) %>%
    rename(num_evals = `num-evals`, lasso = `lasso-weight`) %>%
    mutate(description_split = description) %>%
    separate(description_split, c("trash.1", "description_split"), sep = "rate_") %>%
    separate(description_split, c("learning_rate", "trash.2"), sep = "_l1") %>%
    mutate(trash.1 = NULL, trash.2 = NULL, learning_rate = as.numeric(learning_rate)) %>%
    left_join(or_tools %>%
                select(name, distance) %>%
                rename(ortools = distance)) %>%
    mutate(gap = (distance - ortools) / ortools) %>%
    filter(iteration == 50) %>%
    mutate(lasso = as.factor(lasso), learning_rate = as.factor(learning_rate)) %>%
    group_by(learning_rate, lasso) %>%
    summarize(meangap = mean(gap), meanmse = mean(mse),
              failures = mean(total_failed_training_attempts),
              stdgap = sd(gap), stdmse = sd(mse), count = n())
}

l1_reg_11 <- summarize_l1_reg("paper_data/rl_best/training/cities_11_l1_reg.csv", or_tools_11)

l1_reg_21 <- summarize_l1_reg("paper_data/rl_best/training/cities_21_l1_reg.csv", or_tools_21)

# gap plot
# 21 cities: mean gap to OR-Tools (+/- one standard error) per lasso weight,
# dodged by learning rate.
# NOTE(review): the legend labels assume three learning rates whose factor
# levels sort as 5e-4, 1e-3, 5e-3; confirm against l1_reg_21.
l1_reg_21 %>%
  ggplot(aes(x = lasso, y = meangap, fill = learning_rate, group = learning_rate)) +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  geom_errorbar(
    aes(y = meangap,
        ymin = meangap - stdgap / sqrt(count),
        ymax = meangap + stdgap / sqrt(count)),
    position = position_dodge(width = 0.9),
    width = 0.5
  ) +
  labs(x = "Lasso parameter", y = "Average gap with OR-Tools") +
  scale_fill_brewer(name = "Learning rate", palette = "OrRd",
                    labels = c("5e-4", "1e-3", "5e-3")) +
  scale_y_continuous(breaks = c(0.0, 0.02, 0.04, 0.06, 0.08),
                     labels = c("0%", "2%", "4%", "6%", "8%")) +
  theme(legend.position = c(0.5, 0.75),
        legend.text = element_text(size = 28),
        legend.title = element_text(size = 28),
        axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line"))
ggsave("paper_plots/l1_gap_21.png")

# 21 cities: mean MSE (+/- one standard error) per lasso weight, dodged by
# learning rate. The legend labels make the same level-order assumption as the
# gap plot (5e-4, 1e-3, 5e-3).
l1_reg_21 %>%
  ggplot(aes(x = lasso, y = meanmse, fill = learning_rate, group = learning_rate)) +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  geom_errorbar(
    aes(y = meanmse,
        ymin = meanmse - stdmse / sqrt(count),
        ymax = meanmse + stdmse / sqrt(count)),
    position = position_dodge(width = 0.9),
    width = 0.5
  ) +
  labs(x = "Lasso parameter", y = "MSE (in-sample)") +
  scale_fill_brewer(name = "Learning rate", palette = "OrRd",
                    labels = c("5e-4", "1e-3", "5e-3")) +
  theme(legend.position = c(0.5, 0.8),
        legend.text = element_text(size = 28),
        legend.title = element_text(size = 28),
        axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line"))
ggsave("paper_plots/l1_mse_21.png")

# Ablation study - zeroing
# The zeroing value is parsed from `description` (the text between "zeroing_"
# and "_l1"). Gap, mse and failure statistics are taken at iteration 50 per
# (zeroing, lasso) pair.
zero_11 <- read_csv("paper_data/cities_11_zeroing_incomplete.csv") %>%
  rename(num_evals = `num-evals`, lasso = `lasso-weight`) %>%
  mutate(description_split = description) %>%
  separate(description_split, c("trash.1", "description_split"), sep = "zeroing_") %>%
  separate(description_split, c("zeroing", "trash.2"), sep = "_l1") %>%
  select(-trash.1, -trash.2) %>%
  mutate(zeroing = as.numeric(zeroing)) %>%
  left_join(select(or_tools_11, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools) %>%
  filter(iteration == 50) %>%
  mutate(lasso = as.factor(lasso), zeroing = as.factor(zeroing)) %>%
  group_by(zeroing, lasso) %>%
  summarize(meangap = mean(gap), meanmse = mean(mse),
            failures = mean(total_failed_training_attempts),
            stdgap = sd(gap), stdmse = sd(mse), count = n())

# Mean gap (+/- one standard error) per zeroing value, dodged by lasso weight.
zero_11 %>%
  ggplot(aes(x = zeroing, y = meangap, fill = lasso, group = lasso)) +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  geom_errorbar(
    aes(y = meangap,
        ymin = meangap - stdgap / sqrt(count),
        ymax = meangap + stdgap / sqrt(count)),
    position = position_dodge(width = 0.9),
    width = 0.5
  ) +
  scale_fill_brewer(palette = "OrRd")

# Ablation study - MIP solve time
mip_11 <- read_csv("paper_data/cities_11_mip_bench.csv")
mip_51 <- read_csv("paper_data/cities_51_mip_bench_600s.csv")

# Runs with Runtime above 600: mean `Gap` (+/- one standard error) per cut
# combination (TspCut_NNCut) and bound tightener.
# NOTE(review): the x labels assume the united Cut values sort as
# no-cut, NN-only, TSP-only, both; confirm the TspCut/NNCut encoding.
mip_51 %>%
  filter(Runtime > 600) %>%
  group_by(TspCut, NNCut, BoundTightener) %>%
  summarize(meantime = mean(Runtime), stdtime = sd(Runtime), count = n(),
            meanobj = mean(Objective), stdobj = sd(Objective),
            meangap = mean(Gap), stdgap = sd(Gap)) %>%
  ungroup() %>%
  unite(Cut, c("TspCut", "NNCut"), sep = "_") %>%
  ggplot(aes(x = Cut, y = meangap, fill = BoundTightener)) +
  geom_bar(stat = "identity", position = position_dodge()) +
  geom_errorbar(
    aes(y = meangap,
        ymin = meangap - stdgap / sqrt(count),
        ymax = meangap + stdgap / sqrt(count)),
    position = position_dodge(width = 0.9),
    width = 0.5
  ) +
  scale_x_discrete(labels = c("no cuts", "nn", "tsp", "tsp + nn"))

# Number of runs with Runtime below 600 per configuration.
mip_51 %>%
  filter(Runtime < 600) %>%
  count(TspCut, NNCut, BoundTightener, name = "count") %>%
  unite(Cut, c("TspCut", "NNCut"), sep = "_") %>%
  ggplot(aes(x = Cut, y = count, fill = BoundTightener)) +
  geom_bar(stat = "identity", position = position_dodge()) +
  scale_x_discrete(labels = c("no cuts", "nn", "tsp", "tsp + nn"))

# Runtime distribution of the sub-600 runs with interval-arithmetic tightening.
mip_51 %>%
  filter(Runtime < 600, BoundTightener == "interval_arithmetic") %>%
  unite(Cut, c("TspCut", "NNCut"), sep = "_") %>%
  ggplot(aes(x = Cut, y = Runtime)) +
  geom_boxplot() +
  scale_x_discrete(labels = c("no cuts", "nn", "tsp", "tsp + nn"))

# Ablation 6 - architecture
# Gap to OR-Tools for different hidden-layer sizes and lasso weights.

arch_11 <- read_csv("paper_data/rl_best/arch/cities_11_arch.csv") %>%
  left_join(select(or_tools_11, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools) %>%
  rename(hidden_nodes = `hidden-nodes`, lasso = `lasso-weight`)

arch_21 <- read_csv("paper_data/rl_best/arch/cities_21_arch_scrape.csv") %>%
  left_join(select(or_tools_21, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools) %>%
  rename(hidden_nodes = `hidden-nodes`, lasso = `lasso-weight`)

# Number of rows at iteration 230.
arch_21 %>%
  filter(iteration == 230) %>%
  count(name = "count")

# Mean gap at iteration 230 (+/- one standard error) per hidden-layer size,
# dodged by lasso weight. (Optional restriction: filter(lasso == 0).)
# NOTE(review): the last x label "0 (LR)" assumes the linear model's
# hidden_nodes value sorts after 32 (e.g. NA); confirm the encoding.
arch_21 %>%
  filter(iteration == 230) %>%
  group_by(hidden_nodes, lasso) %>%
  summarize(meangap = mean(gap), stdgap = sd(gap), count = n(), meandist = mean(distance)) %>%
  ungroup() %>%
  mutate(hidden_nodes = as.factor(hidden_nodes), lasso = as.factor(lasso)) %>%
  ggplot(aes(x = hidden_nodes, y = meangap, fill = lasso)) +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  geom_errorbar(
    aes(y = meangap,
        ymin = meangap - stdgap / sqrt(count),
        ymax = meangap + stdgap / sqrt(count)),
    position = position_dodge(width = 0.9),
    width = 0.5
  ) +
  scale_fill_brewer(palette = "OrRd", name = "Lasso\nparameter") +
  scale_x_discrete(name = "Hidden ReLU nodes",
                   labels = c("4", "8", "16", "32", "0 (LR)")) +
  scale_y_continuous(breaks = c(0, 0.005, 0.01, 0.015, 0.02),
                     labels = c("0%", "0.5%", "1%", "1.5%", "2%"),
                     name = "Average gap with OR-Tools") +
  theme(legend.position = c(0.5, 0.75),
        legend.text = element_text(size = 28),
        legend.title = element_text(size = 28),
        axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line"))
ggsave("paper_plots/arch_21.png")


# Ablation 7 - starter
# The starter multiplier is parsed from `description` (the text after "mult_",
# cut at "-r").

starter_11 <- read_csv("paper_data/cities_11_starter.csv") %>%
  left_join(select(or_tools_11, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools) %>%
  separate(description, c("description", "starter_mult"), sep = "mult_") %>%
  separate(starter_mult, c("starter_mult", "right.trash"), sep = "-r") %>%
  mutate(starter_mult = as.numeric(starter_mult)) %>%
  rename(num_evals = `num-evals`)

starter_21 <- read_csv("paper_data/cities_21_starter_incomplete.csv") %>%
  left_join(select(or_tools_21, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools) %>%
  separate(description, c("description", "starter_mult"), sep = "mult_") %>%
  separate(starter_mult, c("starter_mult", "right.trash"), sep = "-r") %>%
  mutate(starter_mult = as.numeric(starter_mult)) %>%
  rename(num_evals = `num-evals`)

# Best gap per instance, then the mean (+/- one standard error) per
# num_evals_starter_mult pattern, dodged by data decay rate.
starter_21 %>%
  group_by(name, num_evals, data_decay_rate, starter_mult) %>%
  summarize(gap = min(gap), distance = min(distance)) %>%
  ungroup() %>%
  mutate(num_evals = as.factor(num_evals),
         data_decay_rate = as.factor(data_decay_rate),
         starter_mult = as.factor(starter_mult)) %>%
  unite(eval_pattern, c("num_evals", "starter_mult"), sep = "_") %>%
  group_by(eval_pattern, data_decay_rate) %>%
  summarize(meangap = mean(gap), stdgap = sd(gap), count = n()) %>%
  ggplot(aes(x = eval_pattern, y = meangap, fill = data_decay_rate)) +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  geom_errorbar(
    aes(y = meangap,
        ymin = meangap - stdgap / sqrt(count),
        ymax = meangap + stdgap / sqrt(count)),
    position = position_dodge(width = 0.9),
    width = 0.5
  ) +
  scale_fill_brewer(palette = "Set2") +
  theme_minimal()

# Mean gap per iteration for num_evals == 10 and starter_mult == 1000, one line
# per data decay rate. (Optional: filter(data_decay_rate == 1.0), or add
# standard-error bars with geom_errorbar.)
starter_21 %>%
  filter(num_evals == 10) %>%
  mutate(starter_mult = as.factor(starter_mult),
         data_decay_rate = as.factor(data_decay_rate)) %>%
  filter(starter_mult == 1000) %>%
  group_by(iteration, data_decay_rate) %>%
  summarize(meangap = mean(gap), stdgap = sd(gap), count = n()) %>%
  ggplot(aes(x = iteration, y = meangap, color = data_decay_rate, group = data_decay_rate)) +
  geom_line(size = 1) +
  geom_point() +
  theme_minimal() +
  scale_color_brewer(palette = "Set1")

# gurobi/scip analysis
# Stack every results file from the gurobi/ and scip/ folders into one table,
# `gurobi`. Each row is tagged with its file name (`options`) and solver
# (`solver`).

# Read `files` from directory prefix `dir` and bind them row-wise.
read_solver_results <- function(dir, files, solver_name) {
  bind_rows(lapply(files, function(file) {
    read_csv(paste0(dir, file)) %>%
      # .env$ so a CSV column named `file` cannot shadow the loop variable.
      mutate(options = .env$file, solver = solver_name)
  }))
}

gurobi_files <- list.files("paper_data/rl_best/gurobi/")
# Listed separately: scip/ need not contain the same file names as gurobi/.
scip_files <- list.files("paper_data/rl_best/scip/")

gurobi <- bind_rows(
  read_solver_results("paper_data/rl_best/gurobi/", gurobi_files, "gurobi"),
  read_solver_results("paper_data/rl_best/scip/", scip_files, "scip")
)

# 51-city option files: mean Runtime (+/- one standard error) for runs with
# Runtime < 180, per solver/NNCut combination and option file.
# fixed = TRUE matches the literal text. As regexes, "0.1" would also match
# names such as "021".
gurobi %>%
  filter(grepl("51", options, fixed = TRUE)) %>%
  filter(Runtime < 180) %>%
  mutate(NNCut = as.factor(NNCut)) %>%
  unite(solver, c("solver", "NNCut"), sep = "_") %>%
  group_by(solver, options) %>%
  summarize(meantime = mean(Runtime), stdtime = sd(Runtime),
            meangap = mean(Gap), stdgap = sd(Gap), count = n()) %>%
  ggplot() +
  aes(x = options, y = meantime, fill = solver) +
  geom_bar(stat = "identity", position = position_dodge(), color = "black") +
  geom_errorbar(aes(y = meantime,
                    ymin = meantime - stdtime / sqrt(count),
                    ymax = meantime + stdtime / sqrt(count)),
                position = position_dodge(width = 0.9), width = 0.5) +
  scale_fill_brewer(palette = "Set2") +
  theme_minimal()


# Runtime/Gap statistics with standard error for NNCut == 1 runs on the
# 51-city files, excluding option files whose name contains "0.1".
gurobi %>%
  filter(grepl("51", options, fixed = TRUE)) %>%
  filter(!grepl("0.1", options, fixed = TRUE)) %>%
  filter(NNCut == 1) %>%
  mutate(NNCut = as.factor(NNCut)) %>%
  unite(solver, c("solver", "NNCut"), sep = "_") %>%
  group_by(solver, options) %>%
  summarize(meantime = mean(Runtime), stdtime = sd(Runtime),
            meangap = mean(Gap), stdgap = sd(Gap), count = n()) %>%
  mutate(stderr = stdtime / sqrt(count)) %>%
  view()

# Final ablation - lower bounds
# Combinatorial lower-bound runs, each row joined to the OR-Tools distance
# (`ortools`) to get the relative gap.

comblb_11 <- read_csv("paper_data/rl_best/comb_lb/cities_11_comb_lb.csv") %>%
  left_join(select(or_tools_11, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools)

comblb_21 <- read_csv("paper_data/rl_best/comb_lb/cities_21_comb_lb_scrape.csv") %>%
  left_join(select(or_tools_21, name, ortools = distance)) %>%
  mutate(gap = (distance - ortools) / ortools)

# Number of rows per iteration.
comblb_21 %>%
  count(iteration, name = "count") %>%
  view()

# Mean gap to OR-Tools over policy iterations (< 160), one line per
# lb-vehicles/lb-train combination. Ribbons are +/- one standard error.
# NOTE(review): the legend labels assume the united lower_bounds values sort as
# none, eval only, eval + training; confirm the column encoding.
comblb_21 %>%
  filter(iteration < 160) %>%
  unite(lower_bounds, c("lb-vehicles", "lb-train"), sep = "_") %>%
  group_by(iteration, lower_bounds) %>%
  summarize(meangap = mean(gap), stdgap = sd(gap), count = n()) %>%
  ggplot(aes(x = iteration, y = meangap, group = lower_bounds,
             color = lower_bounds, linetype = lower_bounds, fill = lower_bounds)) +
  geom_line(size = 2) +
  geom_ribbon(
    aes(y = meangap,
        ymin = meangap - stdgap / sqrt(count),
        ymax = meangap + stdgap / sqrt(count)),
    alpha = 0.2
  ) +
  labs(color = "Combinatorial\nlower bounds",
       linetype = "Combinatorial\nlower bounds",
       fill = "Combinatorial\nlower bounds",
       x = "Policy iteration", y = "Average gap with OR-Tools") +
  scale_fill_brewer(palette = "Set2", labels = c("None", "Eval. only", "Eval. + training")) +
  scale_color_brewer(palette = "Set2", labels = c("None", "Eval. only", "Eval. + training")) +
  scale_linetype_discrete(labels = c("None", "Eval. only", "Eval. + training")) +
  scale_y_continuous(breaks = c(0.0, 0.04, 0.08, 0.12, 0.16),
                     labels = c("0%", "4%", "8%", "12%", "16%")) +
  theme(legend.position = c(0.7, 0.7),
        legend.text = element_text(size = 28),
        legend.title = element_text(size = 28),
        axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line"))
ggsave("paper_plots/comblb_21.png")

# cvrplib

cvrplib <- read_csv("paper_data/rl_best/cvrplib/restricted_arch_cvrplib_scrape.csv")

# Inspect the instance ordering file without treating its first row as a
# header (readr's argument is `col_names`, not `headers`).
read_csv("paper_data/rl_best/cvrplib/cvrplib_instance_ordering.csv", col_names = FALSE)

# Row k (0-based) of the ordering file is keyed "instance_k". The row index
# comes from row_number(); the previous rownames(names) inside the definition
# of `names` did not refer to this table.
# NOTE(review): this read uses the first row as a header; confirm the ordering
# file has one, or rows will be offset by one.
names <- read_csv("paper_data/rl_best/cvrplib/cvrplib_instance_ordering.csv") %>%
  mutate(name = paste("instance", row_number() - 1, sep = "_"))

# OR-Tools reference (300 s guided local search, per the file name) and the
# fast OR-Tools run.
cvrplib_or_tools <- read_csv("paper_data/rl_best/cvrplib/cvrplib_ab_result_gls_300.csv") %>%
  select(name, distance) %>%
  rename(or_tools = distance)

cvrplib_or_tools_fast <- read_csv("paper_data/rl_best/cvrplib/cvrplib_ab_result_fast.csv") %>%
  select(name, distance) %>%
  rename(or_tools_fast = distance)

# Natural joins on shared columns, then the RL gap and the fast-OR-Tools gap,
# both relative to the 300 s OR-Tools distance.
# NOTE(review): confirm which column links `names` to the OR-Tools tables.
cvrplib <- cvrplib %>%
  left_join(names) %>%
  left_join(cvrplib_or_tools) %>%
  left_join(cvrplib_or_tools_fast) %>%
  mutate(gap = (distance - or_tools) / or_tools,
         gap_fast = (or_tools_fast - or_tools) / or_tools)

# Distribution of the best gap reached on each CVRPLIB instance.
cvrplib %>%
  group_by(name) %>%
  summarize(gap = min(gap)) %>%
  ggplot(aes(y = gap)) +
  geom_boxplot()

# Per (instance, hidden-layer size): best gap and last iteration reached; then
# the mean best gap per (hidden-layer size, last iteration).
# (Optional: filter(iteration == 25) after the first summarize.)
cvrplib %>%
  group_by(name, `hidden-nodes`) %>%
  summarize(distance = min(distance), iteration = max(iteration),
            gap = min(gap)) %>%
  group_by(`hidden-nodes`, iteration) %>%
  summarize(meangap = mean(gap), count = n()) %>%
  ggplot(aes(x = iteration, y = meangap, color = `hidden-nodes`)) +
  geom_point()

# Mean gap to the 300 s OR-Tools solution per iteration, for 16 hidden nodes
# and for runs with no hidden-nodes value (plotted as 0). `stdgap` here is
# already the standard error, sd / sqrt(n).
cvrplib %>%
  mutate(hidden_nodes = `hidden-nodes`) %>%
  filter(hidden_nodes == 16 | is.na(hidden_nodes)) %>%
  mutate(hidden_nodes = as.factor(ifelse(is.na(hidden_nodes), 0, hidden_nodes))) %>%
  group_by(iteration, hidden_nodes) %>%
  summarize(meangap = mean(gap), stdgap = sd(gap) / sqrt(n())) %>%
  ggplot() +
  aes(x = iteration, y = meangap, group = hidden_nodes,
      color = hidden_nodes, fill = hidden_nodes,
      linetype = hidden_nodes,
      ymin = meangap - stdgap, ymax = meangap + stdgap) +
  geom_line(size = 2) +
  geom_ribbon(alpha = 0.2) +
  theme(legend.position = c(0.75, 0.75), legend.text = element_text(size = 28),
        legend.title = element_text(size = 28), axis.title = element_text(size = 28),
        axis.text = element_text(size = 28),
        legend.background = element_blank(),
        legend.box.background = element_rect(color = "black"),
        panel.border = element_blank(),
        panel.background = element_blank(),
        panel.grid.major = element_line(colour = "gray", linetype = "dashed"),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        legend.spacing = unit(2, "line"),
        legend.key.height = unit(2, "line"),
        legend.key.width = unit(3, "line")) +
  # Explicit breaks so the four percent labels always line up with them.
  scale_y_continuous(limits = c(0, 0.32), breaks = c(0, 0.1, 0.2, 0.3),
                     labels = c("0%", "10%", "20%", "30%"),
                     name = "Average gap with OR-Tools (300s)") +
  scale_fill_brewer(palette = "Set2", name = "Hidden nodes") +
  scale_color_brewer(palette = "Set2") +
  labs(color = "Hidden nodes", linetype = "Hidden nodes", group = "Hidden nodes")
ggsave("paper_plots/cvrplib.png")

# Number of rows per iteration.
cvrplib %>%
  group_by(iteration) %>%
  summarize(count = n()) %>%
  view()
