##1. Package loading
##----------------------------------------------------------------------------------
library(PLNet)
library(orthopolynom)
library(Rcpp)
library(CVXR)
library(parallel)
library(Seurat)
library(SeuratData)
library(GMPR)
library(PLNmodels)
##----------------------------------------------------------------------------------


##2. Data loading and pre-processing
##----------------------------------------------------------------------------------
##2.1 Install and load the ifnb dataset
SeuratData::InstallData("ifnb")
data("ifnb")
## Pull out the RNA count matrix and the two per-cell metadata vectors
## (cell annotation and original identity) that drive the cell selection below.
count_mat <- ifnb@assays$RNA@counts
anno_vec <- ifnb$seurat_annotations
anno_vec_2 <- ifnb$orig.ident

##2.2 Choose the cells of CD14 celltype
## The selection is positional: the first entry of table(anno_vec) is taken as
## the cell type and the second entry of table(anno_vec_2) as the sample group.
## NOTE(review): this presumes those positions are the CD14 cells and the
## stimulated group -- confirm against the dataset's labels.
target_celltype <- names(table(anno_vec))[1]
target_group <- names(table(anno_vec_2))[2]
CD14_stim <- which(anno_vec == target_celltype & anno_vec_2 == target_group)
count_CD14_stim <- count_mat[, CD14_stim]

##2.3 Choose 200 highly variable genes as the gene set of interest
## A temporary Seurat object is built from the selected cells, log-normalized,
## and used only to obtain the names of the 200 variable features.
P3se_ifnb_stim <- CreateSeuratObject(counts = count_CD14_stim, min.cells = 3)
P3se_ifnb_stim <- NormalizeData(
  P3se_ifnb_stim,
  normalization.method = "LogNormalize",
  scale.factor = 10000
)
P3se_ifnb_stim <- FindVariableFeatures(P3se_ifnb_stim, nfeatures = 200)
variable_gene_stim <- Seurat::VariableFeatures(P3se_ifnb_stim)

## Restrict the raw (un-normalized) counts to the selected genes.
is_variable_gene <- row.names(count_CD14_stim) %in% variable_gene_stim
count_CD14_stim_ifnb <- count_CD14_stim[which(is_variable_gene), ]

## Keep only genes whose count exceeds 1 in at least one cell, then convert
## the result to a dense matrix (genes in rows, cells in columns).
n_cells_above_one <- rowSums(as.matrix(count_CD14_stim_ifnb) > 1)
datafinal.stim <- count_CD14_stim_ifnb[which(n_cells_above_one >= 1), ]
datafinal.stim <- as.matrix(datafinal.stim)
out_smooth <- datafinal.stim

##2.4 Estimate the library size with the GMPR method
## GMPR is given the transposed matrix (cells in rows, genes in columns).
S_depth <- GMPR(t(out_smooth), 1, 1)
##----------------------------------------------------------------------------------

##3. Run PLNet
##----------------------------------------------------------------------------------
##3.1 Estimate the covariance matrix by MLE estimator
## mle_newton receives the cell-by-gene count matrix and the GMPR size factors.
## NOTE(review): core_num = 20 is hard-coded; confirm the machine has that many
## cores available before running.
k_max <- 10
time_1 <- Sys.time()
cov_input <- mle_newton(
  data_use = t(as.matrix(out_smooth)),
  S_depth = S_depth,
  k_max = k_max,
  core_num = 20
)
time_2 <- Sys.time()
## Elapsed time of the estimation step only.
time_2 - time_1
## Saved after the timer is stopped so that writing the file is not counted in
## the reported estimation time (same convention as the later sections).
save(cov_input, file = "./cov_input.Rdata")


##3.2 Estimate the precision matrix by Dtrace loss
## PLNet_main is run once with and once without a penalty on the diagonal;
## both results are collected in a list keyed by the setting used.
time_1 <- Sys.time()
PLNet_res_list <- list()
for (if_penalize.diagonal in c(TRUE, FALSE)) {
  PLNet_res <- PLNet_main(
    obs_mat = t(out_smooth),
    Sd_est = S_depth,
    n_lambda = 100,
    penalize.diagonal = if_penalize.diagonal,
    cov_input = cov_input$mlesigmahat[[2]],
    weight_mat = NULL,
    zero_mat = NULL,
    core_num = 1
  )
  res_name <- if (if_penalize.diagonal) {
    "penalize.diagonal"
  } else {
    "not penalize.diagonal"
  }
  PLNet_res_list[[res_name]] <- PLNet_res
}
time_2 <- Sys.time()
## Elapsed time for both PLNet runs.
time_2 - time_1

save(PLNet_res_list, file = "./PLNet_res_list.Rdata")

##----------------------------------------------------------------------------------

##4. Run VPLN
##----------------------------------------------------------------------------------
time_1 <- Sys.time()
VPLN_res_list <- list()

## The model input does not depend on the penalize_diagonal setting, so it is
## prepared once here instead of being rebuilt on every loop iteration.
## data_1 is the cell-by-gene count table; Covariate is a single all-zero
## placeholder column with one row per cell.
original_list <- list(
  data_1 = as.data.frame(t(out_smooth)),
  Covariate = as.data.frame(matrix(0, nrow = ncol(out_smooth), ncol = 1))
)
rownames(original_list$Covariate) <- rownames(original_list$data_1)
## The GMPR size factors are supplied as user-defined offsets.
pre_data <- prepare_data(
  counts = original_list$data_1,
  covariates = original_list$Covariate,
  offset = S_depth
)
names(pre_data)[2] <- "covariates"

## Fit the network model once with and once without a penalty on the diagonal;
## both fits are stored in a list keyed by the setting used.
for (if_penalize.diagonal in c(TRUE, FALSE)) {
  fits <- PLNnetwork(
    Abundance ~ 1 + offset(log(Offset)),
    data = pre_data,
    control_init = list(nPenalties = 100, min.ratio = 1e-4),
    control_main = list(xtol_rel = 1e-2, penalize_diagonal = if_penalize.diagonal)
  )

  fit_name <- if (if_penalize.diagonal) {
    "penalize.diagonal"
  } else {
    "not penalize.diagonal"
  }
  VPLN_res_list[[fit_name]] <- fits
}
time_2 <- Sys.time()
## Elapsed time for data preparation plus both fits.
time_2 - time_1

save(VPLN_res_list, file = "./VPLN_res_list_GMPR.Rdata")
##----------------------------------------------------------------------------------