import torch.nn as nn
from dgl.nn.pytorch import GraphConv


class FusionModel(nn.Module):
    """Thin wrapper around a single DGL ``GraphConv`` layer.

    The wrapped layer is constructed with ``in_features`` as both its input
    and output size, with ``activation=None`` and
    ``allow_zero_in_degree=True``.

    Args:
        in_features: Size handed to ``GraphConv`` for both the input and the
            output feature dimension. Defaults to 128.
    """

    def __init__(self, in_features: int = 128):
        super().__init__()
        self.in_features = in_features
        # Input and output sizes are the same value, so the layer is built
        # to keep the feature width unchanged.
        conv = GraphConv(
            in_features,
            in_features,
            activation=None,
            allow_zero_in_degree=True,
        )
        # Attribute name is kept as-is: it determines the state_dict keys.
        self.metricLayer = conv

    def forward(self, graph, x):
        """Run ``x`` through the wrapped ``GraphConv`` layer on ``graph``.

        Args:
            graph: Graph object forwarded unchanged to the layer
                (presumably a DGL graph -- confirm against callers).
            x: Features forwarded unchanged to the layer (presumably node
                features whose last dimension is ``in_features`` -- confirm).

        Returns:
            Whatever the wrapped layer returns for ``(graph, x)``.
        """
        return self.metricLayer(graph, x)

