try:
    from collections.abc import Sequence
except ImportError:  # very old Pythons only expose the ABCs in `collections`
    from collections import Sequence

import torch
import torch.nn.functional as F
from torch import nn

from torchdrug import core, data, utils, layers
from torchdrug.core import Registry as R

from . import layer


@R.register("model.NBFNet")
class NeuralBellmanFordNetwork(nn.Module, core.Configurable):
    """Neural Bellman-Ford network whose relation features are supplied at call time.

    A stack of ``layer.GeneralizedRelationalConv`` layers is run over ``graph``
    starting from the boundary condition ``input``. The last (or all, with
    ``concat_hidden``) layer outputs are concatenated with ``graph.query`` and
    scored by an MLP with a single output unit.

    Parameters:
        input_dim (int): dimension of the boundary / query features
        hidden_dims (int or list of int): output dimension of each conv layer
        num_relation (int): number of relation types, forwarded to every conv layer
        message_func (str): message function forwarded to every conv layer
        aggregate_func (str): aggregation function forwarded to every conv layer
        short_cut (bool): add a residual connection when a layer preserves the feature shape
        layer_norm (bool): forwarded to every conv layer
        activation (str): forwarded to every conv layer
        concat_hidden (bool): feed all layer outputs (instead of only the last) to the MLP
        dependent, project, pre_activation, rel_norm: forwarded to every conv layer
        remove_one_hop (bool, optional): controls :meth:`remove_easy_edges`; when true,
            edges between a (head, tail) pair are removed in both directions and for
            every relation
    """

    def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult", aggregate_func="pna",
                 short_cut=False, layer_norm=False, activation="relu", concat_hidden=False, dependent=False,
                 project=False, pre_activation=False, rel_norm=False, remove_one_hop=False):
        super(NeuralBellmanFordNetwork, self).__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        num_relation = int(num_relation)
        self.input_dim = input_dim
        # Width of the MLP input: the concatenated hidden states (all layers or only
        # the last) plus the query. Summing is exact even when hidden dims differ.
        self.output_dim = (sum(hidden_dims) if concat_hidden else hidden_dims[-1]) + input_dim
        self.dims = [input_dim] + list(hidden_dims)
        self.num_relation = num_relation
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        # Read by remove_easy_edges(); without it that method raised AttributeError.
        self.remove_one_hop = remove_one_hop

        self.layers = nn.ModuleList()
        for in_dim, out_dim in zip(self.dims[:-1], self.dims[1:]):
            self.layers.append(layer.GeneralizedRelationalConv(in_dim, out_dim, num_relation,
                                                               self.dims[0], message_func, aggregate_func, layer_norm,
                                                               activation, dependent, pre_activation, project=project,
                                                               rel_norm=rel_norm))

        self.mlp = layers.MLP(self.output_dim, [self.output_dim] + [1])

    def remove_easy_edges(self, graph, h_index, t_index, r_index=None):
        """Return ``graph`` restricted to the edges that do not match the given query triples.

        Parameters:
            graph: graph to filter (uses ``graph.match`` / ``graph.edge_mask``)
            h_index (Tensor): head node indices
            t_index (Tensor): tail node indices, same shape as ``h_index``
            r_index (Tensor, optional): relation indices, same shape as ``h_index``

        With ``self.remove_one_hop`` the pattern covers both directions (h, t) and
        (t, h). If relations are given, the relation column is set to ``-1``, which
        ``graph.match`` presumably treats as a wildcard (TODO confirm). Otherwise
        only the exact (h, t[, r]) patterns are matched.
        """
        if self.remove_one_hop:
            h_index_ext = torch.cat([h_index, t_index], dim=-1)
            t_index_ext = torch.cat([t_index, h_index], dim=-1)
            if r_index is not None:
                wildcard = -torch.ones_like(h_index_ext)
                pattern = torch.stack([h_index_ext, t_index_ext, wildcard], dim=-1)
            else:
                pattern = torch.stack([h_index_ext, t_index_ext], dim=-1)
        else:
            if r_index is not None:
                pattern = torch.stack([h_index, t_index, r_index], dim=-1)
            else:
                pattern = torch.stack([h_index, t_index], dim=-1)
        pattern = pattern.flatten(0, -2)
        edge_index = graph.match(pattern)[0]
        # keep every edge that was NOT matched by any pattern
        edge_mask = ~layers.functional.as_mask(edge_index, graph.num_edge)
        return graph.edge_mask(edge_mask)

    def forward(self, graph, input, rel_representations, all_loss=None, metric=None):
        """Propagate ``input`` through all conv layers and score each node.

        Parameters:
            graph: graph carrying a ``query`` attribute; ``boundary`` is written onto it as a node attribute
            input (Tensor): boundary condition fed to the first layer
            rel_representations (Tensor): relation features assigned to every layer's
                ``relation`` attribute before propagation
            all_loss, metric: unused

        Returns:
            dict with ``"node_feature"``: MLP output with the trailing size-1 axis squeezed.
            ``graph.query`` is expanded to ``(num_node, -1, -1)``, so it presumably has
            shape ``(batch_size, input_dim)`` — TODO confirm.
        """
        with graph.node():
            graph.boundary = input

        # every conv layer reads its relation features from its ``relation`` attribute
        for conv in self.layers:
            conv.relation = rel_representations

        hiddens = []
        layer_input = input
        for conv in self.layers:
            hidden = conv(graph, layer_input)
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            hiddens.append(hidden)
            layer_input = hidden

        node_query = graph.query.expand(graph.num_node, -1, -1)
        if self.concat_hidden:
            node_feature = torch.cat(hiddens + [node_query], dim=-1)
        else:
            node_feature = torch.cat([hiddens[-1], node_query], dim=-1)

        node_feature = self.mlp(node_feature).squeeze(-1)
        return {
            "node_feature": node_feature,
        }


@R.register("model.CompGCN")
class CompositionalGraphConvolutionalNetwork(nn.Module, core.Configurable):
    """Stack of ``layer.CompositionalGraphConv`` layers with a learnable relation table.

    Parameters:
        input_dim (int): dimension of the input node features and of the relation embeddings
        hidden_dims (int or list of int): output dimension of each conv layer
        num_relation (int): number of relation types (rows of the relation embedding)
        message_func (str): composition function forwarded to every conv layer
        short_cut (bool): add a residual connection when a layer preserves the feature shape
        layer_norm (bool): forwarded to every conv layer
        activation (str): forwarded to every conv layer
        concat_hidden (bool): output all layer states (instead of only the last) together with the query
    """

    def __init__(self, input_dim, hidden_dims, num_relation, message_func="mult", short_cut=False, layer_norm=False,
                 activation="relu", concat_hidden=False):
        super(CompositionalGraphConvolutionalNetwork, self).__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        num_relation = int(num_relation)
        self.input_dim = input_dim
        # Width of the returned node features: hidden states (all layers or only the
        # last) plus the query. Summing is exact even when hidden dims differ.
        self.output_dim = (sum(hidden_dims) if concat_hidden else hidden_dims[-1]) + input_dim
        self.dims = [input_dim] + list(hidden_dims)
        self.num_relation = num_relation
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden

        self.layers = nn.ModuleList()
        for in_dim, out_dim in zip(self.dims[:-1], self.dims[1:]):
            self.layers.append(layer.CompositionalGraphConv(in_dim, out_dim, num_relation,
                                                            message_func, layer_norm, activation))
        self.relation = nn.Embedding(num_relation, input_dim)

    def forward(self, graph, input, all_loss=None, metric=None):
        """Propagate ``input`` through all conv layers.

        The relation embedding table is attached to ``graph`` as ``relation_input``
        before propagation.

        Parameters:
            graph: graph carrying a ``query`` attribute
            input (Tensor): features fed to the first layer
            all_loss, metric: unused

        Returns:
            dict with ``"node_feature"``: hidden state(s) concatenated with the query
            expanded over nodes along the last axis.
        """
        graph.relation_input = self.relation.weight
        hiddens = []
        layer_input = input

        for conv in self.layers:
            hidden = conv(graph, layer_input)
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            hiddens.append(hidden)
            layer_input = hidden

        node_query = graph.query.expand(graph.num_node, -1, -1)
        if self.concat_hidden:
            node_feature = torch.cat(hiddens + [node_query], dim=-1)
        else:
            node_feature = torch.cat([hiddens[-1], node_query], dim=-1)

        return {
            "node_feature": node_feature,
        }
    
class CustomNBFNetFull(nn.Module, core.Configurable):
    """NBFNet variant that returns the last-layer node features for every query head.

    Only part of the constructor signature drives the computation.
    ``symmetric``, ``num_mlp_layer``, ``num_beam``, ``path_topk``,
    ``remove_one_hop``, ``separate_remove_one_hop``, ``concat_hidden`` and
    ``**kwargs`` are accepted (some are stored as attributes) but are not used by
    any method of this class.

    Parameters:
        input_dim (int): dimension of the query / boundary vectors
        hidden_dims (int or list of int): output dimension of each conv layer
        num_relation (int, optional): number of relation types; ``None`` means a single relation
        message_func, aggregate_func, layer_norm, activation, dependent, pre_activation:
            forwarded to every ``layer.GeneralizedRelationalConv``
        short_cut (bool): add a residual connection when a layer preserves the feature shape
        rand_label (bool): label head nodes with a fixed random vector instead of all-ones
        learn_query (bool): label head nodes with a learned vector (takes precedence over ``rand_label``)
    """

    def __init__(self, input_dim, hidden_dims, num_relation=None, symmetric=False,
                 message_func="distmult", aggregate_func="sum", short_cut=False, layer_norm=False, activation="relu",
                 concat_hidden=False, num_mlp_layer=2, dependent=False, remove_one_hop=False,
                 num_beam=10, path_topk=10,
                 separate_remove_one_hop=False, rand_label=False, pre_activation=False,
                 learn_query=False, **kwargs):
        super(CustomNBFNetFull, self).__init__()

        if not isinstance(hidden_dims, Sequence):
            hidden_dims = [hidden_dims]
        # without explicit relation types the graph is treated as single-relational
        num_relation = 1 if num_relation is None else int(num_relation)
        self.dims = [input_dim] + list(hidden_dims)
        self.num_relation = num_relation
        self.symmetric = symmetric
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden
        self.remove_one_hop = remove_one_hop
        self.num_beam = num_beam
        self.path_topk = path_topk
        self.rand_label = rand_label

        self.layers = nn.ModuleList()
        for in_dim, out_dim in zip(self.dims[:-1], self.dims[1:]):
            self.layers.append(layer.GeneralizedRelationalConv(in_dim, out_dim, num_relation,
                                                               self.dims[0], message_func, aggregate_func, layer_norm,
                                                               activation, dependent, pre_activation))

        self.learn_query = learn_query
        if learn_query:
            self.learnable_q = nn.Embedding(1, self.dims[0])
        if rand_label:
            # NOTE(review): this is a plain tensor attribute, not a registered buffer.
            # Under standard nn.Module semantics it is neither moved by .to() (hence
            # the explicit .to() in bellmanford) nor saved in state_dict. A reloaded
            # model therefore draws a fresh random label. Registering it as a buffer
            # would change state_dict keys, so it is left as-is here.
            self.rand_query = torch.nn.init.normal_(torch.empty(1, self.dims[0], dtype=torch.float))
            self.rand_query.requires_grad = False

    def as_relational_graph(self, graph, self_loop=True):
        """Convert a graph without relation types into one with a single relation (id 0).

        Adapted from the method of the same name in torchdrug's NBFNet model.
        ``forward`` relies on it for graphs whose ``num_relation`` is falsy.

        Parameters:
            graph: graph whose ``edge_list`` holds (node_in, node_out) pairs
            self_loop (bool, optional): also add one self-loop per node with weight 1

        Returns:
            a new graph of the same type with relation 0 appended to every edge.
            NOTE(review): edge attributes in ``meta_dict``/``data_dict`` are passed
            through unchanged and would not cover the added self-loops.
        """
        edge_list = graph.edge_list
        edge_weight = graph.edge_weight
        if self_loop:
            node_in = node_out = torch.arange(graph.num_node, device=graph.device)
            loop = torch.stack([node_in, node_out], dim=-1)
            edge_list = torch.cat([edge_list, loop])
            edge_weight = torch.cat([edge_weight, torch.ones(graph.num_node, device=graph.device)])
        relation = torch.zeros(len(edge_list), 1, dtype=torch.long, device=graph.device)
        edge_list = torch.cat([edge_list, relation], dim=-1)
        return type(graph)(edge_list, edge_weight=edge_weight, num_node=graph.num_node,
                           num_relation=1, meta_dict=graph.meta_dict, **graph.data_dict)

    @utils.cached
    def bellmanford(self, graph, h_index, separate_grad=False):
        """Propagate from a boundary condition that labels each head in ``h_index``.

        Each query ``b`` gets its own channel: ``boundary[h_index[b], b]`` holds the
        query vector and all other entries are zero. Distinct queries are therefore
        labelled independently within one batch.

        Parameters:
            graph: graph to propagate over; ``query`` (graph attribute) and
                ``boundary`` (node attribute) are written onto it
            h_index (Tensor): 1-D tensor of head node indices, one per query
            separate_grad (bool): unused

        Returns:
            dict with ``"node_feature"``: last-layer output with its first two axes
            swapped (presumably ``(batch_size, num_node, hidden_dim)``).
        """
        batch_size = h_index.shape[0]
        if self.learn_query:
            query = self.learnable_q.weight.expand(batch_size, self.dims[0])
        elif self.rand_label:
            query = self.rand_query.repeat(batch_size, 1).to(h_index.device)
        else:
            query = torch.ones(batch_size, self.dims[0], device=h_index.device, dtype=torch.float)
        index = h_index.unsqueeze(-1).expand_as(query)
        # boundary: (num_node, batch_size, dim); allocate it alongside the query so
        # scatter_add_ never sees a device/dtype mismatch
        boundary = torch.zeros(graph.num_node, *query.shape, device=query.device, dtype=query.dtype)
        boundary.scatter_add_(0, index.unsqueeze(0), query.unsqueeze(0))

        with graph.graph():
            graph.query = query
        with graph.node():
            graph.boundary = boundary

        hiddens = []
        layer_input = boundary

        for conv in self.layers:
            hidden = conv(graph, layer_input)
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden = hidden + layer_input
            hiddens.append(hidden)
            layer_input = hidden

        return {
            "node_feature": hiddens[-1].transpose(1, 0),
        }

    def forward(self, graph, h_index, t_index=None, r_index=None, all_loss=None, metric=None):
        """Return last-layer node features for each head in ``h_index``.

        Graphs that already carry relation types are used as-is; they are deliberately
        not made undirected. Graphs without relation types are first converted with
        :meth:`as_relational_graph`. ``t_index``, ``r_index``, ``all_loss`` and
        ``metric`` are unused.
        """
        if not graph.num_relation:
            graph = self.as_relational_graph(graph)

        output = self.bellmanford(graph, h_index)
        return output["node_feature"]

@R.register("model.RelNBFNet")
class RelNBFNet(nn.Module, core.Configurable):
    """Thin wrapper that runs a ``CustomNBFNetFull`` with a fixed configuration.

    The wrapped model uses sum aggregation, layer norm and short-cut connections.

    Parameters:
        input_dim (int): dimension of the query / boundary vectors
        hidden (int): hidden dimension of every layer
        num_layers (int, optional): number of message-passing layers
        rand_label (bool, optional): forwarded to ``CustomNBFNetFull``
        pre_activation (bool, optional): forwarded to ``CustomNBFNetFull``
        **kwargs: only ``ablation_etypes`` and ``learn_query`` are read; other keys are ignored
    """

    def __init__(self, input_dim, hidden,
                 num_layers=6,
                 rand_label=False,
                 pre_activation=False,
                 **kwargs):
        # kwargs are options of this wrapper only. Forwarding them to nn.Module.__init__
        # raised TypeError as soon as ablation_etypes/learn_query was passed.
        super(RelNBFNet, self).__init__()

        self.input_dim = input_dim
        self.ablation_etypes = kwargs.get('ablation_etypes', False)

        self.model = CustomNBFNetFull(
            input_dim=input_dim,
            hidden_dims=[hidden] * num_layers,
            # 4 edge types normally; the ablation treats the graph as single-relational
            num_relation=4 if not self.ablation_etypes else None,
            aggregate_func="sum",
            layer_norm=True,
            short_cut=True,
            learn_query=kwargs.get('learn_query', False),
            rand_label=rand_label,
            pre_activation=pre_activation
        )
        self.hidden_dim = hidden

        if self.hidden_dim != self.input_dim:
            # NOTE(review): created (and therefore part of state_dict) but never used in forward
            self.input_transform_linear = nn.Linear(self.input_dim, self.hidden_dim)

    def forward(self, graph, r_idx, all_loss=None, metric=None):
        """Run the wrapped network with ``r_idx`` as the query heads.

        Parameters:
            graph: graph passed to ``CustomNBFNetFull.forward``
            r_idx (Tensor): head indices, passed as ``h_index``
            all_loss, metric: unused

        Returns:
            the tensor returned by ``CustomNBFNetFull.forward`` (last-layer features,
            presumably ``(batch_size, num_node, hidden)`` — TODO confirm).
        """
        # TODO: consider adding RWSE features to the all-ones boundary
        return self.model(graph, h_index=r_idx)