import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_

from actfuns import *


class FCNN(nn.Module):
    """Fully connected network with two hidden layers and a configurable activation.

    Architecture: ``fc1 -> act1 -> fc2 -> act2 -> out``. The activation is
    built from its name via ``actfun_name2factory``. If the activation module
    exposes a ``divisor`` attribute, the hidden width is rounded to a multiple
    of it. If it exposes ``feature_factor``, the width fed into the following
    linear layer is the hidden width scaled by that factor. Both attributes
    default to 1 when absent.
    """

    def __init__(self, nc_in: int, nc_hidden: int, actfun: str, nc_out: int = 2):
        """Build the layers and initialise their weights.

        Args:
            nc_in: Number of input features (size of the input's last dimension).
            nc_hidden: Requested number of hidden units per layer, before
                rounding to a multiple of the activation's ``divisor``.
            actfun: Activation name, resolved with ``actfun_name2factory``.
            nc_out: Number of output features. Defaults to 2, the width that
                was previously hard-coded.
        """
        super().__init__()

        # Two separate calls, so each hidden layer gets its own activation
        # instance.
        self.act1 = actfun_name2factory(actfun)()
        self.act2 = actfun_name2factory(actfun)()
        divisor = getattr(self.act1, "divisor", 1)
        feature_factor = getattr(self.act1, "feature_factor", 1)

        nc_preact, nc_postact = self._layer_widths(nc_hidden, divisor, feature_factor)
        if nc_preact != nc_hidden:
            print(
                "Warning: Number of hidden units rounded from {} to {}".format(
                    nc_hidden, nc_preact
                )
            )

        self.fc1 = nn.Linear(nc_in, nc_preact)
        self.fc2 = nn.Linear(nc_postact, nc_preact)
        self.out = nn.Linear(nc_postact, nc_out)
        self.init_weights()

    @staticmethod
    def _layer_widths(nc_hidden, divisor=1, feature_factor=1):
        """Return ``(nc_preact, nc_postact)`` for a requested hidden width.

        ``nc_preact`` is ``nc_hidden`` rounded to the nearest multiple of
        ``divisor``, and never less than one multiple. A zero count would
        build zero-width layers.

        ``nc_postact`` is the width the activation is expected to emit when
        fed ``nc_preact`` features, i.e. ``nc_preact * feature_factor``. It
        must be derived from the *rounded* width: that is what ``fc1``/``fc2``
        actually produce, and so what the next layer receives.
        """
        n_groups = max(1, int(round(nc_hidden / divisor)))
        nc_preact = int(n_groups * divisor)
        nc_postact = int(round(nc_preact * feature_factor))
        return nc_preact, nc_postact

    def init_weights(self):
        """Apply Xavier-uniform weights and zero biases to every linear layer."""
        for layer in (self.fc1, self.fc2, self.out):
            xavier_uniform_(layer.weight)
            constant_(layer.bias, 0.0)

    def forward(self, x):
        """Run ``fc1 -> act1 -> fc2 -> act2 -> out``.

        ``nn.Linear`` acts on the last dimension, so ``x`` must have
        ``nc_in`` features there. The result has ``nc_out`` features in that
        position.
        """
        x = self.act1(self.fc1(x))
        x = self.act2(self.fc2(x))
        return self.out(x)
