# Script dependencies. data.table is attached last, so its first() masks
# dplyr::first(); both return the first element of a vector, which is all the
# functions below need.
# The workspace is intentionally not cleared here (no rm(list = ls())): doing
# so would wipe the caller's global environment whenever this file is sourced.
library(rosqp)
library(Matrix)
library(dplyr)
library(data.table)

getOsqpDesign <- function(fullData) {
  # Function that returns the design matrices corresponding to the optimization
  # problem
  #   Min 0.5x'Px + x'q s.t. l <= Ax <= u,
  # where x holds one p_ij per row of `fullData`.
  # See https://cran.r-project.org/web/packages/osqp/osqp.pdf for further details.
  #
  # With M the (producer x row) indicator matrix, the objective
  #   sum_i (htildeFinal_i - sum_j p_ij)^2
  # equals x'M'Mx - 2 * htildeFinal'x + const, hence P = 2 M'M and
  # q = -2 * htildeFinal (one entry per row).
  #
  # Args:
  #   fullData: Data frame with one row per (producer, consumer) pair and the
  #     columns producerId, consumerId, htildeFinal, lj, uj, lowerBound,
  #     upperBound. htildeFinal is expected to be constant within a producer;
  #     lj/uj are taken from the first row of each consumer.
  #
  # Returns:
  #   A list with the objective matrix P, the objective vector q, the
  #   constraint matrix A and the lower and upper bound vectors l, u.
  #   The first nrow(fullData) constraints are the box bounds
  #   lowerBound <= p_ij <= upperBound; the remaining ones are one per consumer
  #   (in the order of the consumer factor levels): lj <= sum_i p_ij <= uj.

  P <- 2 * crossprod(fac2sparse(fullData$producerId))
  q <- -2 * fullData$htildeFinal

  # Build the consumer grouping once and use it both for the rows of A and for
  # the per-consumer bounds, so the two always line up. (Grouping separately
  # via factor() and dplyr::group_by() can order character ids differently.)
  consumers <- factor(fullData$consumerId)
  A <- rbind(Diagonal(nrow(fullData)), fac2sparse(consumers))

  firstRowOfConsumer <- match(levels(consumers), consumers)
  l <- c(fullData$lowerBound, fullData$lj[firstRowOfConsumer])
  u <- c(fullData$upperBound, fullData$uj[firstRowOfConsumer])

  list(P = P, q = q, A = A, l = l, u = u)
}

getHTildeFinal <- function(finalData, jSet) {
  # Extracts the rows belonging to the consumers in `jSet` and attaches, per
  # producer, the target used by the grouped optimization problem:
  #   htildeFinal = htilde (first row of the producer) + sum of pijCurrent over
  #                 the producer's rows within the chunk.
  # If htilde = h - sum(pijCurrent) over all consumers (as the solver keeps it),
  # this is the part of h left for the chunk with all other p_ij held fixed.
  #
  # Args:
  #   finalData: data.table with (at least) the columns producerId, consumerId,
  #     htilde and pijCurrent. Must be a data.table because `:=` is used.
  #   jSet: Vector of consumer ids forming the chunk.
  #
  # Returns:
  #   A new data.table holding only the chunk's rows plus the column
  #   htildeFinal. Subsetting creates a copy, so finalData is left untouched.
  keepRows <- which(finalData$consumerId %in% jSet)
  chunk <- finalData[keepRows]
  chunk[, htildeFinal := first(htilde) + sum(pijCurrent), by = producerId]
  chunk
}

getObjValue <- function(data, isFull = TRUE) {
  # Function that generates the (unnormalized) objective value: the sum over
  # producers of the squared residual between the producer's target and its
  # total pij. groupIterativeQPSolver() divides this by the number of
  # producers to report an average.
  #
  # Args:
  #   data: Data frame with the columns producerId and pijCurrent, plus h when
  #     isFull = TRUE or htildeFinal when isFull = FALSE. The target is read
  #     from the first row of each producer.
  #   isFull: A boolean indicator that differentiates if we are using the full
  #     data (target h) or the prodSet data (target htildeFinal).
  #
  # Returns:
  #   The objective value:
  #       sum_{producerIds} (h_i - \sum_{j} p_{ij})^2            if isFull = TRUE
  #       sum_{producerIds} (htildeFinal_i - \sum_{j} p_{ij})^2  if isFull = FALSE
  #
  # Raises:
  #   An error if the target column is missing from `data`.

  targetColumn <- if (isFull) "h" else "htildeFinal"
  # `[[` matches exactly; `$` would partially match e.g. "h" to "htilde" on a
  # plain data.frame.
  target <- data[[targetColumn]]
  if (is.null(target)) {
    stop("getObjValue: column '", targetColumn, "' not found in data",
         call. = FALSE)
  }

  # rowsum(reorder = FALSE) and !duplicated() both enumerate producers in order
  # of first appearance, so the two vectors below are aligned.
  isFirstRow <- !duplicated(data$producerId)
  pijTotals <- rowsum(data$pijCurrent, data$producerId, reorder = FALSE)[, 1L]
  residuals <- target[isFirstRow] - pijTotals
  sum(residuals^2)
}

groupIterativeQPSolver <- function(finalData, groupSize, maxIter,
                                   convergenceLimit = 1e-3) {
  # Function for solving a large-scale network QP problem by block-coordinate
  # descent over consumers.
  #
  # Every outer iteration shuffles the consumers and splits them into chunks of
  # at most `groupSize`. For each chunk, the pij values of that chunk's
  # consumers are re-optimized with OSQP while all other pij values are held
  # fixed. Iterations continue while all of the following hold:
  #   * fewer than `maxIter` iterations have been run, and
  #   * the percentage change of the pij values over the last iteration,
  #   * the objective value (averaged over producers), and
  #   * the percentage change of that objective over the last iteration
  #     are all above `convergenceLimit`.
  #
  # Args:
  #   finalData: data.table (data.frames are converted) with one row per
  #     (producer, consumer) pair and at least the columns producerId,
  #     consumerId, h, pijBase, lj, uj, lowerBound and upperBound. It is
  #     copied first, so the caller's object is never modified.
  #   groupSize: The max size of the consumer set for each run of the QP.
  #   maxIter: The maximum number of outer iterations.
  #   convergenceLimit: Threshold shared by the stopping criteria above.
  #
  # Returns:
  #   A list with
  #     optimalPij: producerId, consumerId, pijBase and the optimized pijCurrent;
  #     nIter: the number of outer iterations run;
  #     runtime: total setup + QP + update time in elapsed seconds;
  #     objVal: the final objective value averaged over producers.

  # Percentage (L1) change from `old` to `new`. An unchanged all-zero baseline
  # yields 0 instead of 0/0 = NaN, which would break the `while` condition.
  percentChange <- function(old, new) {
    delta <- sum(abs(new - old))
    if (isTRUE(delta == 0)) {
      return(0)
    }
    100 * delta / sum(abs(old))
  }

  finalData <- copy(as.data.table(finalData))
  finalData$pijOld <- finalData$pijBase
  finalData$pijCurrent <- finalData$pijBase
  finalData$index <- seq_len(nrow(finalData))
  # Factor levels in order of first appearance avoid unwanted sorting when rows
  # are grouped by producer/consumer (e.g. in dplyr::summarize()).
  finalData$producerId <- factor(finalData$producerId,
                                 levels = unique(finalData$producerId))
  finalData$consumerId <- factor(finalData$consumerId,
                                 levels = unique(finalData$consumerId))
  # Initialize the per-producer residual exactly as it is maintained after
  # every chunk below, so the chunk targets (getHTildeFinal) and the
  # incremental objective updates are consistent from the first chunk on.
  finalData[, htilde := first(h) - sum(pijCurrent), by = producerId]

  numProducers <- length(unique(finalData$producerId))
  percentageChangeInPijVal <- 1 + convergenceLimit
  percentageChangeInObjVal <- 1 + convergenceLimit
  objVal <- 1 + convergenceLimit
  count <- 0

  setup.time <- 0
  qp.time <- 0
  update.time <- 0

  ##################### Starting the Execution ###############################

  while ((count < maxIter) && (percentageChangeInPijVal > convergenceLimit) &&
         (objVal > convergenceLimit) &&
         (percentageChangeInObjVal > convergenceLimit)) {

    # Setting Consumer Chunks
    consumerSet <- sample(unique(finalData$consumerId), replace = FALSE)
    consumerSetChunks <- split(consumerSet,
                               ceiling(seq_along(consumerSet) / groupSize))
    chunkLen <- length(consumerSetChunks)

    # Exact objective at the start of the sweep; within the sweep it is updated
    # incrementally from the change on each chunk.
    objVal <- getObjValue(finalData) / numProducers
    objValIterStart <- objVal

    if (count == 0) {
      print(paste("Initial objective value:", objVal))
    }

    for (ct in seq_len(chunkLen)) {
      jSet <- consumerSetChunks[[ct]]

      startTime <- proc.time()
      prodSet <- getHTildeFinal(finalData, jSet)
      prodSetValueBefore <- getObjValue(prodSet, FALSE)
      osqpDesign <- getOsqpDesign(prodSet)
      setup.time <- setup.time + (proc.time() - startTime)[3]

      startTime <- proc.time()
      results <- solve_osqp(osqpDesign$P, osqpDesign$q, osqpDesign$A,
                            osqpDesign$l, osqpDesign$u,
                            osqpSettings(eps_abs = 1e-6, eps_rel = 1e-6,
                                         verbose = FALSE))
      qp.time <- qp.time + (proc.time() - startTime)[3]

      startTime <- proc.time()
      prodSet$pijCurrent <- results$x
      prodSetValueAfter <- getObjValue(prodSet, FALSE)
      # Only the chunk's producers changed, so the objective moves by the change
      # in their contribution.
      objVal <- objVal + (prodSetValueAfter - prodSetValueBefore) / numProducers
      print(paste0("Objective value after iteration ", count, " + (", ct,
                   " / ", chunkLen, "): ", objVal))

      finalData$pijCurrent[prodSet$index] <- results$x
      finalData[, htilde := first(h) - sum(pijCurrent), by = producerId]
      update.time <- update.time + (proc.time() - startTime)[3]
      print(paste("Total Setup Time:", setup.time, ", Total QP time:", qp.time,
                  ", Total Update Time:", update.time))
      print(paste("Total Elapsed Time: ", setup.time + qp.time + update.time))
    }

    percentageChangeInPijVal <- percentChange(finalData$pijOld,
                                              finalData$pijCurrent)
    # Measure the change over the whole sweep, not just over its last chunk.
    percentageChangeInObjVal <- percentChange(objValIterStart, objVal)
    print(paste("Percentage change in pij values: ", percentageChangeInPijVal))
    print(paste("Percentage change in the objective value: ",
                percentageChangeInObjVal))

    finalData$pijOld <- finalData$pijCurrent
    count <- count + 1
  }

  print(paste("Final objective value:", objVal))
  output <- subset(finalData,
                   select = c("producerId", "consumerId", "pijBase", "pijCurrent"))

  list(optimalPij = output, nIter = count,
       runtime = setup.time + qp.time + update.time, objVal = objVal)
}