## Benchmark analysis on real data (PLNet vs. VPLN)
library(SeuratData)
##----------------------------------------------------------------------------------


## 1. Install and load the `ifnb` dataset
##----------------------------------------------------------------------------------
# Install the dataset package through SeuratData, then load the `ifnb` object.
SeuratData::InstallData("ifnb")
utils::data("ifnb")

# Raw counts of the RNA assay: genes in rows, cells in columns.
count_mat <- ifnb@assays[["RNA"]]@counts

# Per-cell labels: cell-type annotation and sample identity
# (the latter is used below to separate control from stimulated cells).
anno_vec <- ifnb$seurat_annotations
anno_vec_2 <- ifnb$orig.ident
##----------------------------------------------------------------------------------

## 2. Choose the cells of the CD14 cell type (stimulated and control)
##----------------------------------------------------------------------------------
# The cell type is the first entry of table(anno_vec) (assumed to be the CD14
# monocyte label -- TODO confirm). Entries 2 and 1 of table(anno_vec_2) are
# treated as stimulated and control respectively (TODO confirm label order).
CD14_stim <- local({
  is_cd14 <- anno_vec == names(table(anno_vec))[1L]
  which(is_cd14 & anno_vec_2 == names(table(anno_vec_2))[2L])
})
CD14_control <- local({
  is_cd14 <- anno_vec == names(table(anno_vec))[1L]
  which(is_cd14 & anno_vec_2 == names(table(anno_vec_2))[1L])
})

# Column (cell) subsets of the raw count matrix.
count_CD14_stim <- count_mat[, CD14_stim]
count_CD14_control <- count_mat[, CD14_control]
##----------------------------------------------------------------------------------

## 3. Gene set of interest: all TFs among the top 500 highly variable genes
##    plus the top 200 highly variable genes (TFs and non-TFs)
##----------------------------------------------------------------------------------
# Provides TF_list_all; its first element is used as the set of TF symbols.
load("./TF_list_all.Rdata")
tf_symbols <- TF_list_all[[1]]

# Only SeuratData is attached at the top of the script, so the Seurat functions
# are namespace-qualified (as VariableFeatures already was) instead of relying
# on Seurat having been attached elsewhere.
P3se_ifnb_stim <- Seurat::CreateSeuratObject(counts = count_CD14_stim,
                                             min.cells = 3)
P3se_ifnb_stim <- Seurat::NormalizeData(P3se_ifnb_stim,
                                        normalization.method = "LogNormalize",
                                        scale.factor = 10000)
P3se_ifnb_stim <- Seurat::FindVariableFeatures(P3se_ifnb_stim,
                                               nfeatures = 500)
variable_gene_stim <- Seurat::VariableFeatures(P3se_ifnb_stim)

# head() avoids NA padding if fewer than 200 variable genes are returned.
gene_names <- unique(c(variable_gene_stim[variable_gene_stim %in% tf_symbols],
                       head(variable_gene_stim, 200)))

# Put TFs first: downstream code (gold-standard rows, p_TF) assumes the TFs
# occupy the first length(TF_in) positions of gene_names.
is_tf <- gene_names %in% tf_symbols
gene_names <- c(gene_names[is_tf], gene_names[!is_tf])
count_CD14_stim_ifnb <- count_CD14_stim[gene_names, ]

# Keep genes with a count > 1 in at least one cell. drop = FALSE keeps a
# matrix (and its rownames) even if only a single gene survives.
keep_gene <- rowSums(as.matrix(count_CD14_stim_ifnb) > 1) >= 1
datafinal.stim <- as.matrix(count_CD14_stim_ifnb[keep_gene, , drop = FALSE])
gene_names <- rownames(datafinal.stim)
TF_in <- gene_names[gene_names %in% tf_symbols]
##----------------------------------------------------------------------------------

## 4. Load the hTFtarget database as the gold standard
##----------------------------------------------------------------------------------
# Binary adjacency matrix over gene_names (same order as the rows of
# datafinal.stim, TFs first).
gold_standard_network <- matrix(0, nrow = length(gene_names),
                                ncol = length(gene_names))

# One file per TF in ./htftarget. The TF symbol is the part of the file name
# before the first literal ".t" (e.g. "STAT1.txt" -> "STAT1"). fixed = TRUE is
# required because, as a regex, "." matches any character. vapply (instead of
# Reduce("rbind", ...)[, 1]) also works when the folder holds a single file.
file_names <- dir("./htftarget")
TF_use_htftarget <- vapply(strsplit(file_names, ".t", fixed = TRUE),
                           function(parts) parts[1], character(1))

for (tf_index in seq_along(TF_use_htftarget)) {
  # TFs occupy the first length(TF_in) rows, so the position in TF_in is
  # also the row index in gold_standard_network.
  tf_row <- which(TF_in == TF_use_htftarget[tf_index])
  # A TF outside the selected gene set contributes no edges; skip its file.
  if (length(tf_row) == 0) {
    next
  }
  data_htftarget <- read.table(file.path("./htftarget", file_names[tf_index]),
                               quote = "", fill = TRUE, header = TRUE,
                               row.names = NULL)
  # Column 3 is taken as the target-gene symbols
  # (TODO confirm against the hTFtarget file layout).
  target_ori <- as.vector(as.matrix(data_htftarget[, 3]))
  index_in <- which(gene_names %in% target_ori)
  gold_standard_network[tf_row, index_in] <- 1
}

# Symmetrize: a pair is an edge if either gene is listed as a target of the other.
gold_standard_network <- gold_standard_network + t(gold_standard_network)
gold_standard_network <- ifelse(gold_standard_network > 0, 1, 0)

# Counts and proportions of edges / non-edges over all distinct gene pairs.
(table(gold_standard_network[upper.tri(gold_standard_network)]))
prop.table(table(gold_standard_network[upper.tri(gold_standard_network)]))
##----------------------------------------------------------------------------------

## 5. Number of true edges found by the two methods at different density levels
##----------------------------------------------------------------------------------
# Convert a precision matrix into partial correlations:
#   pcor[i, j] = -omega[i, j] / (sqrt(omega[i, i]) * sqrt(omega[j, j])).
# The diagonal comes out as (approximately) -1; callers reset it where needed.
precision_to_pcor <- function(omega) {
  omega <- as.matrix(omega)
  sd_vec <- sqrt(diag(omega))
  -omega / outer(sd_vec, sd_vec)
}

# Number of non-zero entries of part_of_interest(pcor_mat * truth), i.e. the
# gold-standard edges that the estimated network also contains.
# part_of_interest() is not defined in this script (presumably sourced from
# the project's helper code, restricting to the TF-related entries -- TODO
# confirm). NA entries are not counted.
count_true_edges <- function(pcor_mat, truth, n_tf) {
  hits <- part_of_interest(pcor_mat * truth, p_TF = n_tf)
  sum(hits != 0, na.rm = TRUE)
}

## 5.1 Load the estimated networks of PLNet and VPLN
load("./PLNet_res_list_benchmark.Rdata")
load("./VPLN_res_list_benchmark.Rdata")

# Number of TFs (they occupy the first p_TF rows/columns, see step 3).
# Previously p_TF was used in the loop below without ever being defined.
p_TF <- length(TF_in)

## PLNet: the selected model (Omega_chooseB)
pcor_PLNet <- precision_to_pcor(PLNet_res_list$penalize.diagonal$Omega_chooseB)
diag(pcor_PLNet) <- 1
Eval_PLNet <- Eval_fun(weight_mat = pcor_PLNet,
                       adjoint_true = gold_standard_network,
                       p_TF = p_TF)

## VPLN: the model with the largest criteria$BIC value
bic_vec <- VPLN_res_list$penalize.diagonal$criteria$BIC
precision_VPLN <- VPLN_res_list$penalize.diagonal$models[[which.max(bic_vec)]]$model_par$Omega
pcor_VPLN <- precision_to_pcor(precision_VPLN)
diag(pcor_VPLN) <- 1
Eval_VPLN <- Eval_fun(weight_mat = pcor_VPLN,
                      adjoint_true = gold_standard_network,
                      p_TF = p_TF)

## 5.2 True edges at target densities 0.01, 0.02, ..., 1
# density_PLNet / density_VPLN are not created in this script; they must hold
# one density value per fitted model (aligned with Omega_est / models).
if (!exists("density_PLNet") || !exists("density_VPLN")) {
  stop("density_PLNet and density_VPLN must be defined before step 5.2 ",
       "(one density value per fitted PLNet / VPLN model).", call. = FALSE)
}

density_choose_vec <- seq(from = 0, to = 1, length.out = 101)[-1]
PLNet_vec <- integer(length(density_choose_vec))
VPLN_vec <- integer(length(density_choose_vec))
for (density_index in seq_along(density_choose_vec)) {
  density_use <- density_choose_vec[density_index]

  # For each method, use the fitted model whose density is closest to the target.
  density_choose_PLNet <- which.min(abs(density_PLNet - density_use))
  pcor_PLNet <- precision_to_pcor(
    PLNet_res_list$penalize.diagonal$Omega_est[[density_choose_PLNet]]
  )

  density_choose_VPLN <- which.min(abs(density_VPLN - density_use))
  precision_VPLN <- VPLN_res_list$penalize.diagonal$models[[density_choose_VPLN]]$model_par$Omega
  pcor_VPLN <- precision_to_pcor(precision_VPLN)

  PLNet_vec[density_index] <- count_true_edges(pcor_PLNet, gold_standard_network, p_TF)
  VPLN_vec[density_index] <- count_true_edges(pcor_VPLN, gold_standard_network, p_TF)
}

## Report the 10 lowest density levels (0.01 to 0.10)
n_report <- 10
report_index <- seq_len(n_report)
summary_mat <- rbind(PLNet_vec[report_index], VPLN_vec[report_index])
rownames(summary_mat) <- c("PLNet", "VPLN")
colnames(summary_mat) <- paste0("Density_", density_choose_vec[report_index])
summary_mat
##----------------------------------------------------------------------------------