import math, torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

class GraphConvolution(Module):
    """Graph convolution layer computing ``adj @ (input @ weight) + bias``.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        device: Device on which the parameters are allocated. When ``None``
            (default), CUDA is used if available and the CPU otherwise;
            previously CUDA was hard-coded, which failed on CPU-only hosts.
    """

    def __init__(self, in_features, out_features, device=None):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        if device is None:
            # Keep the historical "allocate on the GPU" default, but fall back
            # to the CPU instead of raising when CUDA is not available.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # float32 matches the dtype produced by the legacy torch.FloatTensor.
        self.weight = Parameter(
            torch.empty(in_features, out_features, dtype=torch.float32, device=device)
        )
        self.bias = Parameter(
            torch.empty(out_features, dtype=torch.float32, device=device)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight and bias from U(-stdv, stdv), stdv = 1/sqrt(out_features)."""
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        # In-place initialization must not be tracked by autograd.
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            self.bias.uniform_(-stdv, stdv)

    def forward(self, input, adj, eye=False):
        """Apply the layer to node features.

        Args:
            input: 2-D node feature matrix of shape (num_nodes, in_features).
            adj: 2-D matrix accepted by ``torch.spmm`` as its first operand
                (presumably the, possibly normalized, adjacency matrix of
                shape (num_nodes, num_nodes) — confirm against callers).
                Ignored when ``eye`` is truthy.
            eye: If truthy, skip the adjacency multiplication, which is
                equivalent to using the identity matrix as ``adj``.

        Returns:
            Tensor of shape (num_nodes, out_features).
        """
        output = torch.mm(input, self.weight)
        if not eye:
            output = torch.spmm(adj, output)
        # NOTE: bias is added after aggregation, so every node receives the
        # same bias. Adding it before torch.spmm would instead scale the bias
        # by each row sum of adj.
        return output + self.bias

    def extra_repr(self):
        """Describe the layer's feature sizes in ``repr(module)``."""
        return f"in_features={self.in_features}, out_features={self.out_features}"
