import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy import stats
from scipy.stats import spearmanr
from tqdm import tqdm
from train import train_splits, validate_splits
from train_dist import RankNetLoss



def remove_diagonal_elements_torch(matrix):
    """Return *matrix* with its main-diagonal entries dropped.

    The off-diagonal entries are selected with a boolean mask built from
    ``matrix.size(0)`` and reshaped so that an ``n x n`` input becomes an
    ``n x (n - 1)`` output, each row keeping its original element order.
    """
    size = matrix.size(0)
    # True everywhere except on the main diagonal.
    off_diagonal = torch.eye(size, dtype=bool).logical_not()
    return matrix[off_diagonal].reshape(size, size - 1)

def compute_mean_spearmanr_correlation(A, B):
    """Return the mean row-wise Spearman correlation of two square matrices.

    Row ``i`` of ``A`` is correlated with row ``i`` of ``B``. Rows in which
    either side is constant are skipped, because the Spearman coefficient is
    undefined for them.

    Args:
        A: Square 2-D ``torch.Tensor`` or array-like.
        B: ``torch.Tensor`` or array-like with the same shape as ``A``.

    Returns:
        The mean of the per-row Spearman coefficients, or NaN when every row
        was skipped.

    Raises:
        ValueError: If the inputs differ in shape or are not square 2-D
            matrices.
    """
    def _as_numpy(matrix):
        # Detach and move to CPU first: Tensor.numpy() raises for tensors
        # that require grad or that live on another device.
        if isinstance(matrix, torch.Tensor):
            return matrix.detach().cpu().numpy()
        return np.asarray(matrix)

    # Convert before validating so NumPy inputs are accepted too
    # (ndarray.size is an int attribute, not a method like Tensor.size()).
    A = _as_numpy(A)
    B = _as_numpy(B)

    if A.shape != B.shape or A.ndim != 2 or A.shape[0] != A.shape[1]:
        raise ValueError("Both A and B must be square matrices of the same size.")

    correlations = []
    for a, b in zip(A, B):
        # A constant row has no rank variation, so its correlation is undefined.
        if np.std(a) == 0 or np.std(b) == 0:
            continue

        rs, _ = spearmanr(a, b)
        correlations.append(rs)

    if not correlations:
        # Same value np.mean([]) would yield, without the RuntimeWarning.
        return float("nan")
    return np.mean(correlations)

def load_dataset(dataset_name,split_num):
    """Load and return the object stored for one split of a dataset.

    The file read is ``./datasplit/<dataset_name>/split<split_num>.pt``,
    resolved relative to the current working directory, and it is
    deserialized with ``torch.load``.
    """
    return torch.load(f"./datasplit/{dataset_name}/split{str(split_num)}.pt")


# Datasets (table columns) and model names (table rows) to evaluate.
dataset_list=["NCI109", 'NCI1', 'IMDB-BINARY', 'DD', 'PROTEINS']
model_list=["GCN", "GIN", "GraphSAGE", "GTransformer", "GMT"]

# Number of leading entries of each saved similarity list that are evaluated
# (presumably one per data split -- confirm against the script that wrote them).
NUM_SPLITS = 5

# One "mean±std" cell per (model, dataset) pair, filled in by the loop below.
frame_rank = [[''] * len(dataset_list) for _ in range(len(model_list))]
frame_norank = [[''] * len(dataset_list) for _ in range(len(model_list))]

for row, model_name in enumerate(model_list):
    for col, data_name in enumerate(dataset_list):
        # Saved similarity matrices, indexed below as [i][j]; each entry must
        # be a square matrix for compute_mean_spearmanr_correlation.
        listrank = torch.load(f"./similarity/{data_name}/{model_name}.pt")
        listnorank = torch.load(f"./norank_similarity/{data_name}/{model_name}.pt")

        rank = []
        norank = []
        # The length of the first entry sets the inner range for every i.
        num_layers = len(listrank[0])
        for i in tqdm(range(NUM_SPLITS)):
            # Correlate each matrix with its successor (j versus j + 1).
            for j in range(num_layers - 1):
                rank.append(
                    compute_mean_spearmanr_correlation(listrank[i][j], listrank[i][j + 1])
                )
                norank.append(
                    compute_mean_spearmanr_correlation(listnorank[i][j], listnorank[i][j + 1])
                )

        mean_norank = round(np.mean(norank), 3)
        mean_rank = round(np.mean(rank), 3)
        std_norank = round(np.std(norank), 3)
        std_rank = round(np.std(rank), 3)

        print(model_name,data_name)
        print("Rank:",rank)
        print("Norank:",norank)

        frame_rank[row][col] = str(mean_rank) + "±" + str(std_rank)
        frame_norank[row][col] = str(mean_norank) + "±" + str(std_norank)


# Summary tables: rows are models, columns are datasets.
data_df = pd.DataFrame(frame_rank, index=model_list, columns=dataset_list)

data_frame_norank = pd.DataFrame(frame_norank, index=model_list, columns=dataset_list)

print("all layer")
print("with rankloss")
print(data_df)
print("-----------------------------")
print("without rankloss")
print(data_frame_norank)