from spaghettini import quick_register
from abc import ABC

from torch import nn


@quick_register
class OneLayerFullyConnectedCell(nn.Module, ABC):
    """Cell computing ``fc(activation(z)) + x`` with one square linear layer.

    The activation is applied to ``z`` first, the result is passed through a
    single ``nn.Linear`` mapping ``z_dim`` features to ``z_dim`` features, and
    ``x`` is added to the linear layer's output.
    """

    def __init__(self, z_dim, activation):
        """Create the linear layer and store the activation.

        Args:
            z_dim: Number of input and output features of the linear layer.
            activation: Callable applied to ``z`` before the linear layer.
                It is stored as given.
        """
        super().__init__()
        self.z_dim = z_dim

        # Keep `fc` assigned before `act`: if the activation is itself an
        # nn.Module, this preserves the child registration order (fc, act).
        self.fc = nn.Linear(z_dim, z_dim)
        self.act = activation

    def forward(self, z, x, *args):
        """Return ``fc(act(z)) + x``.

        Args:
            z: Input passed through the activation and then the linear
                layer, so ``act(z)`` must have ``z_dim`` as its last
                dimension.
            x: Value added to the linear layer's output.
            *args: Extra positional arguments; accepted and ignored.

        Returns:
            The sum of the linear layer's output and ``x``.
        """
        activated = self.act(z)
        return self.fc(activated) + x
