import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

# =========================================
# Utility functions
# =========================================

def normalize_pos(pos):
    """
    Rescale a networkx position dictionary into the unit square [0, 1]^2.

    Each axis is min-max scaled independently. A tiny epsilon (1e-10) in the
    denominator avoids division by zero when every node shares the same
    coordinate on an axis; that axis then maps to all zeros.

    :param pos: mapping node -> coordinate sequence; only the first two
        components (x, y) are read.
    :return: new dict with the same keys (same order), each mapped to an
        (x, y) tuple of normalized values.
    """
    nodes = list(pos)

    def _rescale(axis):
        # Min-max scale one coordinate axis across all nodes.
        vals = np.array([pos[n][axis] for n in nodes])
        lo = min(vals)
        hi = max(vals)
        return (vals - lo) / (hi - lo + 1e-10)

    xs = _rescale(0)
    ys = _rescale(1)
    return {n: (xs[k], ys[k]) for k, n in enumerate(nodes)}

# =========================================
# Generation functions
# =========================================

def glue_graphs(graph1, graph2):
    """
    Join two graphs into one by adding a single random edge between them.

    Copies of both inputs are annotated with node attributes ``owner``
    ('alice' for graph1, 'bob' for graph2) and a 1-based ``label`` string
    (``node + 1``, so integer node labels are expected). The copies are then
    combined with ``networkx.disjoint_union`` and connected by one edge between
    a uniformly chosen node of each part.

    :param graph1: first (Alice's) graph; not modified.
    :param graph2: second (Bob's) graph; not modified.
    :return: ``(graph, glue_edge)``. ``graph`` is the glued graph, whose nodes
        are renumbered 0..n1-1 for graph1 and n1..n1+n2-1 for graph2.
        ``glue_edge`` is the ``(alice_node, bob_node)`` pair of the added edge
        in that numbering.
    """
    graph1 = graph1.copy()
    graph2 = graph2.copy()

    for node in graph1.nodes:
        graph1.nodes[node]['owner'] = 'alice'
        graph1.nodes[node]['label'] = f'{node+1}'

    for node in graph2.nodes:
        graph2.nodes[node]['owner'] = 'bob'
        graph2.nodes[node]['label'] = f'{node+1}'

    graph = nx.algorithms.disjoint_union(graph1, graph2)

    # disjoint_union renumbers graph1's nodes to 0..n1-1 and graph2's nodes to
    # n1..n1+n2-1 following node iteration order. Pick the glue endpoints by
    # *position* rather than by original label, so the edge is correct even
    # when the integer labels are not exactly 0..n-1 in order.
    n1 = graph1.number_of_nodes()
    n2 = graph2.number_of_nodes()
    glue_edge = np.random.choice(n1), np.random.choice(n2) + n1
    graph.add_edge(*glue_edge)

    return graph, glue_edge
    
    
def sample_glued_graph(graph_universe):
    """
    Sample two graphs from a universe, glue them together and shuffle node ids.

    Two indices are drawn uniformly with replacement (so both parts may come
    from the same graph). The graphs are joined with ``glue_graphs``, and the
    nodes inside each part are randomly permuted, so node ids no longer reveal
    the original structure.

    :param graph_universe: indexable sequence of networkx graphs.
    :return: ``(graph_relabeled, isomorphism_class, graph1_idx, graph2_idx)``.
        - ``graph_relabeled``: the permuted glued graph. Nodes 0..n1-1 belong
          to Alice (graph1) and n1..n1+n2-1 belong to Bob (graph2). Every node
          carries ``owner`` and a boolean ``is_root`` marking the glue-edge
          endpoint of its part.
        - ``isomorphism_class``: order-independent string key of the sampled
          index pair.
        - ``graph1_idx`` / ``graph2_idx``: the sampled indices.
    """
    n_universe = len(graph_universe)

    # sample two graphs (with replacement)
    graph1_idx, graph2_idx = np.random.choice(n_universe, 2)
    graph1 = graph_universe[graph1_idx]
    graph2 = graph_universe[graph2_idx]

    # glue them
    graph, glue_edge = glue_graphs(graph1, graph2)

    # Randomly permute nodes within each part. The glued graph numbers graph1's
    # nodes 0..n1-1 and graph2's nodes n1..n1+n2-1. Each part needs a
    # permutation of its own size, because the two graphs need not have the
    # same number of nodes.
    n1 = graph1.number_of_nodes()
    n2 = graph2.number_of_nodes()
    p = np.append(np.random.permutation(n1), n1 + np.random.permutation(n2))

    # New node i corresponds to old node p[i]: A'[i, j] = A[p[i], p[j]].
    # (from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the
    # supported equivalent, and .toarray() works for both sparse matrix and
    # sparse array return types of adjacency_matrix.)
    A = nx.adjacency_matrix(graph)[p, :][:, p].toarray()
    graph_relabeled = nx.from_numpy_array(A)
    for node in graph_relabeled.nodes:
        old = p[node]
        owner = graph.nodes[old]['owner']
        graph_relabeled.nodes[node]['owner'] = owner
        root = glue_edge[0] if owner == 'alice' else glue_edge[1]
        graph_relabeled.nodes[node]['is_root'] = old == root

    # Canonical, order-independent key for the (unordered) index pair.
    # Indices are zero-padded to a common width so keys stay unambiguous for
    # universes of 10+ graphs. Plain concatenation would map both the single
    # index 12 and the pair (1, 2) to '12'. For universes of at most 10
    # graphs the width is 1 and keys equal plain concatenation.
    width = len(str(max(n_universe - 1, 0)))
    idx_lo, idx_hi = sorted((int(graph1_idx), int(graph2_idx)))
    if idx_lo == idx_hi:
        isomorphism_class = str(idx_lo).zfill(width)
    else:
        isomorphism_class = str(idx_lo).zfill(width) + str(idx_hi).zfill(width)

    return graph_relabeled, isomorphism_class, graph1_idx, graph2_idx


# =========================================
# Pytorch functions
# =========================================


def glued_dataset_to_torch(dataset, max_degree=None, unique_ids=False):
    """
    Convert glued-graph samples into torch_geometric ``Data`` objects.

    Node feature layout (one row per node):
      * column 0: owner flag (0 for 'alice', 1 for anyone else);
      * column 1: 1 for Alice's root node, else 0 (Bob's root is not encoded);
      * columns 2 .. max_degree+1: one-hot encoding of the node degree;
      * only if ``unique_ids``: the following ``n_nodes`` columns one-hot
        encode the node index.

    :param dataset: sequence of dicts with keys 'graph' and 'label'. The graph
        nodes index rows of the feature matrix, so they are expected to be
        0..n-1, and each node must carry 'owner' and 'is_root' attributes.
        'label' is presumably a one-hot vector; the index of its first truthy
        entry becomes the target. The dataset is iterated twice when
        ``max_degree`` is None, so it must be re-iterable (e.g. a list).
    :param max_degree: width of the degree one-hot. If None, it is set to
        (largest node degree in the dataset) + 1.
    :param unique_ids: also append a one-hot node-index block to the features.
    :return: list of ``torch_geometric.data.Data`` with ``x`` (float),
        ``edge_index`` (long, both directions of every edge) and ``y`` (long,
        shape (1,)).
    :raises IndexError: if a node's degree does not fit in ``max_degree``.
    """
    import torch
    from torch_geometric.data import Data

    if max_degree is None:
        degrees = []
        for datum in dataset:
            degrees.extend([deg for _, deg in nx.degree(datum['graph'])])
        # np.int was removed in NumPy 1.24; the builtin int is the equivalent.
        max_degree = int(max(degrees)) + 1
        print(f'setting max degree to: {max_degree}')

    dataset_torch = []
    for datum in dataset:

        graph, label = datum['graph'], datum['label']

        # Emit every edge in both directions: (u, v) then (v, u).
        directed_pairs = [pair for u, v, _ in nx.to_edgelist(graph) for pair in ((u, v), (v, u))]
        edge_index = np.reshape(np.array(directed_pairs), (-1, 2))
        edge_index = torch.tensor(edge_index.T, dtype=torch.long)

        n_nodes = graph.number_of_nodes()
        n_features = max_degree + 2 + (n_nodes if unique_ids else 0)
        # np.float was removed in NumPy 1.24; it aliased float64.
        x = np.zeros((n_nodes, n_features), dtype=np.float64)

        for node in graph.nodes:
            attrs = graph.nodes[node]

            # reveal the owner of each node
            x[node, 0] = 0 if attrs['owner'] == 'alice' else 1

            # reveal the root (only Alice's root is marked)
            if attrs['is_root'] and attrs['owner'] == 'alice':
                x[node, 1] = 1

            # reveal the degree (one-hot). Set the single entry directly
            # instead of materialising an identity matrix per node. The
            # explicit bound check matters: with unique_ids an out-of-range
            # degree would otherwise land silently in the id columns.
            deg = nx.degree(graph, node)
            if deg >= max_degree:
                raise IndexError(
                    f'node {node} has degree {deg}, which does not fit in max_degree={max_degree}'
                )
            x[node, 2 + deg] = 1

            # node ids (one-hot)
            if unique_ids:
                x[node, max_degree + 2 + node] = 1

        x = torch.tensor(x, dtype=torch.float)

        # graph label: index of the first truthy entry
        y = torch.tensor([np.where(label)[0][0]], dtype=torch.long)

        dataset_torch.append(Data(x=x, edge_index=edge_index, edge_attr=None, y=y))

    return dataset_torch

# =========================================
# Drawing functions
# =========================================

def draw_glued_graph_paper(graph, figsize=(6, 5)):
    """
    Draw a glued graph in a clean, label-free style for paper illustrations.

    Alice's nodes and all other nodes are colored with two shades of the
    RdYlGn colormap. Positions come from the 'pos' node attribute when every
    node has one; otherwise a spring layout is computed.

    :param graph: networkx graph whose nodes carry an 'owner' attribute.
    :param figsize: matplotlib figure size in inches.
    :return: the created matplotlib Figure (not shown).
    """
    alice_color = plt.cm.RdYlGn(0.35)
    bob_color = plt.cm.RdYlGn(0.9)
    colors = [alice_color if graph.nodes[n]['owner'] == 'alice' else bob_color for n in graph.nodes]

    # Use stored positions only if every node has one. Keeping them as a
    # node-keyed dict lets networkx look positions up for any node labelling
    # (a plain list only worked for nodes numbered 0..n-1). Also avoids
    # assuming a node named 0 exists.
    stored_pos = nx.get_node_attributes(graph, 'pos')
    if graph.number_of_nodes() > 0 and len(stored_pos) == graph.number_of_nodes():
        pos = stored_pos
    else:
        pos = nx.layout.spring_layout(graph)

    fig = plt.figure(figsize=figsize, facecolor=[1, 1, 1])
    ax = fig.add_subplot(1, 1, 1)
    nx.draw_networkx_edges(graph, pos, alpha=0.85, width=1.5, ax=ax)
    nx.draw_networkx_nodes(graph, pos, node_size=200, alpha=1, node_color=colors, linewidths=1.25, edgecolors=[0.12, 0.12, 0.12], ax=ax)
    ax.axis('equal')
    ax.axis('off')
    return fig

def draw_glued_graph(graph, figsize=(16, 10)):
    """
    Draw a glued graph with node labels and display it.

    If every node has a 'pos' attribute, the stored 'pos', 'color' and 'label'
    node attributes are used. Otherwise Alice's nodes are drawn blue, all
    other nodes green, nodes are labelled by their ids, and a spring layout
    is computed.

    :param graph: networkx graph; nodes carry 'owner' (fallback path) or
        'pos', 'color' and 'label' (stored-layout path).
    :param figsize: matplotlib figure size in inches.
    """
    # Positions are kept as a node-keyed dict so drawing works for any node
    # labelling. Checking all nodes (not just node 0) avoids a KeyError on
    # empty or non-integer-labelled graphs.
    stored_pos = nx.get_node_attributes(graph, 'pos')
    if graph.number_of_nodes() > 0 and len(stored_pos) == graph.number_of_nodes():
        colors = [graph.nodes[n]['color'] for n in graph.nodes]
        labels = {n: graph.nodes[n]['label'] for n in graph.nodes}
        pos = stored_pos
    else:
        colors = ['b' if graph.nodes[n]['owner'] == 'alice' else 'g' for n in graph.nodes]
        labels = {n: n for n in graph.nodes}
        pos = nx.layout.spring_layout(graph)

    plt.figure(figsize=figsize)
    nx.draw_networkx_edges(graph, pos, alpha=0.8, width=1.5)
    nx.draw_networkx_nodes(graph, pos, node_size=700, alpha=0.95, node_color=colors, edgecolors=[0, 0, 0])
    nx.draw_networkx_labels(graph, pos, alpha=1, labels=labels, font_color=[1, 1, 1], font_size=12)
    plt.axis('equal')
    plt.axis('off')
    plt.show()
