from torch.nn import Linear
import numpy as np
import torch
import os
import random
from torch_geometric.utils import degree
import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import BaseTransform
from torch_geometric.data import DataLoader, Data
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import degree, to_undirected
from torch_geometric.utils import dropout_adj
from utils import *
from torch_geometric.utils import to_networkx,dropout_adj,dropout_edge
from grakel import Graph
from torch_geometric.data import Data
from grakel import GraphKernel, graph_from_networkx
import torch.nn as nn
from torch_scatter import scatter

class EarlyStopper:
    """Patience-based stopping criterion driven by a validation score.

    Larger scores are treated as better: ``min_validation_loss`` starts at
    ``-inf`` and is raised whenever a strictly larger score is reported, so
    despite its name it holds the best (largest) score seen so far.

    Attributes set by the constructor:
        patience: number of counted non-improving calls that triggers a stop.
        min_delta: offset added to the best score when deciding whether a
            non-improving call is counted.
        counter: current number of counted non-improving calls.
        epoch_counter: ``epoch_num`` of the most recent improvement.
        test_acc_record: ``test_acc_record`` passed with that improvement.
        file_path: destination handed to ``torch.save``.
        model_saved: when equal to ``True``, the model is saved on improvement.
        model_updated: becomes ``True`` once a model has been saved.
    """

    def __init__(self, patience, min_delta, file_path, saved=False):
        # Checkpointing configuration.
        self.file_path = file_path
        self.model_saved = saved
        self.model_updated = False

        # Stopping configuration.
        self.patience = patience
        self.min_delta = min_delta

        # Running state.
        self.counter = 0
        self.epoch_counter = 0
        self.test_acc_record = 0
        self.min_validation_loss = -np.inf

    def _record_improvement(self, score, epoch_num, test_acc_record, model):
        """Store a new best score and, if enabled, save the model."""
        self.min_validation_loss = score
        self.epoch_counter = epoch_num
        self.counter = 0
        self.test_acc_record = test_acc_record

        # Explicit comparison with True is kept on purpose: truthy values
        # that are not equal to True do not enable saving.
        if self.model_saved == True:
            torch.save(model, self.file_path)
            self.model_updated = True

    def early_stop(self, validation_loss, epoch_num, test_acc_record, model):
        """Report one validation score and tell whether training should stop.

        Args:
            validation_loss: the validation score for this call (larger is
                better; printed as ``validation_acc``).
            epoch_num: stored in ``epoch_counter`` when the score improves.
            test_acc_record: stored in ``test_acc_record`` when the score
                improves.
            model: object passed to ``torch.save`` when saving is enabled.

        Returns:
            ``True`` when ``counter`` has reached ``patience``, else ``False``.
        """
        best_score = self.min_validation_loss
        print("max_validation_acc: %f, validation_acc: %f"%(best_score, validation_loss))

        if validation_loss > best_score:
            self._record_improvement(validation_loss, epoch_num, test_acc_record, model)
        elif validation_loss < best_score + self.min_delta:
            # With min_delta == 0 an exact tie lands in neither branch, so the
            # counter is neither reset nor incremented.
            self.counter += 1

        return bool(self.counter >= self.patience)

'''
Data processing for TUDataset follows
https://github.com/JinheonBaek/GMT/blob/main/utils/data.py:
a one-hot code is assigned to data without features.
'''
class NormalizedDegree(object):
    """Transform that sets node features to the standardized node degree.

    Args:
        mean: value subtracted from every node degree.
        std: value the centered degree is divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        """Overwrite ``data.x`` with a one-column standardized degree feature.

        The degree is counted over ``data.edge_index[0]``. ``num_nodes`` is
        passed explicitly (as ``get_histogram`` in this file does) so that the
        result has one row per node even when the highest-index nodes have no
        edges; without it the length would be inferred from the largest index
        present in ``edge_index``.

        Args:
            data: graph object exposing ``edge_index`` and ``num_nodes``.

        Returns:
            The same ``data`` object, with ``x`` replaced.
        """
        deg = degree(data.edge_index[0], num_nodes=data.num_nodes, dtype=torch.float)
        deg = (deg - self.mean) / self.std
        data.x = deg.view(-1, 1)
        return data


#tool function for PNA
def get_histogram(dataset):
    """Build the in-degree histogram of all graphs in ``dataset``.

    The dataset is traversed twice: first to find the largest in-degree,
    then to accumulate per-graph degree counts into a tensor of that size.

    Args:
        dataset: iterable of graphs exposing ``edge_index`` and ``num_nodes``.

    Returns:
        A ``torch.long`` tensor whose entry ``k`` is the number of nodes with
        in-degree ``k``; it is empty when ``dataset`` yields no graphs.
    """

    def _in_degrees(graph):
        # Degrees are counted over the second row of edge_index.
        return degree(graph.edge_index[1], num_nodes=graph.num_nodes, dtype=torch.long)

    # Pass 1: largest in-degree over the whole dataset.
    largest = -1
    for graph in dataset:
        largest = max(largest, int(_in_degrees(graph).max()))

    # Pass 2: accumulate the per-graph degree counts.
    histogram = torch.zeros(largest + 1, dtype=torch.long)
    for graph in dataset:
        histogram += torch.bincount(_in_degrees(graph), minlength=histogram.numel())
    return histogram
#----------------------------------------------------------------

def get_node_number(dataset, min_num=200):
    """Return the largest node count in ``dataset``, floored at ``min_num``.

    The previous implementation overwrote its running value with each graph's
    ``num_nodes`` instead of taking the maximum, so the result depended only
    on the last graph. The maximum over all graphs is now used.

    Args:
        dataset: iterable of graphs exposing ``num_nodes``.
        min_num: lower bound on the returned value (defaults to 200, the
            value that was previously hard-coded).

    Returns:
        ``max(min_num, largest num_nodes)``; ``min_num`` for an empty dataset
        as long as it is not below -1.
    """
    # -1 mirrors the original starting value for an empty dataset.
    largest = max((data.num_nodes for data in dataset), default=-1)
    return max(largest, min_num)


def seed_everything(seed):
    """Seed Python hashing, ``random``, NumPy and PyTorch (CPU and CUDA).

    Args:
        seed: value written to ``PYTHONHASHSEED`` (as a string) and passed to
            each seeding function.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Apply the same seed to every random number generator, in this order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def set_seed(seed):
    """Seed ``random`` and PyTorch; also seed CUDA when it is available.

    NumPy's generator is not touched by this function.
    """
    # Collect the seeders in the order they are applied.
    seeders = [random.seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed)
    seeders.append(torch.manual_seed)

    for apply_seed in seeders:
        apply_seed(seed)


def visualize_graph(G, edge_list, file_name):
    """
    Draw a NetworkX graph with selected edges highlighted and save the figure.

    A directed input is converted to an undirected graph before drawing.
    After saving, the current matplotlib figure is cleared.

    Parameters:
    G (networkx.Graph): A NetworkX graph.
    edge_list (list of tuples): Edges to redraw in red on top of the graph.
    file_name: Destination passed to ``plt.savefig``.
    """
    graph = G.to_undirected() if G.is_directed() else G

    # One spring layout is shared by both drawing calls so positions match.
    layout = nx.spring_layout(graph)

    # Base drawing: every node and edge in neutral colors.
    nx.draw(graph, layout, with_labels=True, node_color='lightblue', edge_color='gray')

    # Overlay: the requested edges, thicker and in red.
    nx.draw_networkx_edges(graph, layout, edgelist=edge_list, edge_color='red', width=2)

    plt.savefig(file_name)
    plt.clf()


def distance_compute(tensor1, tensor2):
    """Return the pairwise L2 distance between ``tensor1`` and ``tensor2``.

    Uses the functional form of ``torch.nn.PairwiseDistance`` with the same
    settings as that module's defaults (``eps=1e-6``, ``keepdim=False``).
    """
    return torch.nn.functional.pairwise_distance(
        tensor1, tensor2, p=2, eps=1e-6, keepdim=False
    )

def pyg_to_grakel_optimized(pyg_data_list):
    """Convert PyG ``Data`` objects to GraKeL graphs built from edge sets.

    For each graph:
      * edges become a set of ``(int, int)`` tuples taken from ``edge_index``;
      * each node label is the string form of its feature row, or ``'0'``
        when the graph has no ``x``;
      * every edge gets the label ``'1'``.

    Args:
        pyg_data_list: iterable of graphs exposing ``edge_index``, ``x`` and
            ``num_nodes``.

    Returns:
        A list with one GraKeL ``Graph`` per input graph, in input order.
    """
    converted = []

    for graph in pyg_data_list:
        # Collapse duplicate edges by collecting them into a set.
        edges = set()
        for pair in graph.edge_index.t().tolist():
            edges.add((int(pair[0]), int(pair[1])))

        node_ids = range(graph.num_nodes)
        if graph.x is None:
            labels_for_nodes = {node: '0' for node in node_ids}
        else:
            labels_for_nodes = {node: str(graph.x[node].tolist()) for node in node_ids}

        labels_for_edges = dict.fromkeys(edges, '1')

        converted.append(
            Graph(edges, node_labels=labels_for_nodes, edge_labels=labels_for_edges)
        )

    return converted

def pyg_to_grakel(pyg_data_list):
    """Convert PyG ``Data`` objects to GraKeL graphs via dense adjacency.

    Args:
        pyg_data_list: iterable of graphs exposing ``edge_index`` and
            ``num_nodes`` (and optionally ``x``).

    Returns:
        A list with one GraKeL ``Graph`` per input graph, in input order.
    """

    def _to_grakel(graph):
        # Dense float adjacency with a 1 at every (source, target) pair.
        size = graph.num_nodes
        endpoints = graph.edge_index
        dense_adj = torch.zeros((size, size), dtype=torch.float)
        dense_adj[endpoints[0], endpoints[1]] = 1

        # Node features are forwarded as a NumPy array when the graph has 'x'.
        # NOTE(review): confirm GraKeL accepts an array for node_labels.
        if 'x' in graph:
            features = graph.x.numpy()
        else:
            features = None

        return Graph(dense_adj.numpy(), node_labels=features)

    return [_to_grakel(graph) for graph in pyg_data_list]


def dist_compute(origin_data, data1, data2, model, device=None):
    """Embed three batches with ``model`` and compare two against the first.

    Args:
        origin_data: reference batch exposing ``x``, ``edge_index``, ``batch``.
        data1: first batch to compare with the reference.
        data2: second batch to compare with the reference.
        model: callable invoked as ``model(x, edge_index, batch)``.
        device: accepted but not used by this function.

    Returns:
        A tuple ``(dist1, dist2)`` of the distances from the reference
        embedding to the embeddings of ``data1`` and ``data2``, as computed
        by ``distance_compute``.
    """

    def _embed(batch):
        return model(batch.x, batch.edge_index, batch.batch)

    # All three embeddings are computed before any distance, reference first.
    reference = _embed(origin_data)
    candidate_a = _embed(data1)
    candidate_b = _embed(data2)

    return (
        distance_compute(reference, candidate_a),
        distance_compute(reference, candidate_b),
    )




