"""
Utility functions.
"""

from collections import OrderedDict

from torch import nn

from actfuns import actfun_name2factory


def count_parameters(model, only_trainable=True):
    r"""
    Tally the scalar parameters held by a model (including its submodules).

    Arguments:
        model (torch.nn.Module): the model whose parameters are counted.
        only_trainable (bool, optional): when ``True``, only parameters with
            ``requires_grad`` set contribute to the total; when ``False``,
            every parameter is counted. Default is ``True``.

    Returns:
        int: the summed element count of the selected parameters.
    """
    # A parameter is included if it is trainable, or unconditionally when
    # the caller asked for every parameter.
    return sum(
        tensor.numel()
        for tensor in model.parameters()
        if tensor.requires_grad or not only_trainable
    )


def freeze(model):
    """
    Disable gradient tracking for every parameter of a model, in place.

    Arguments:
        model (torch.nn.Module): the model whose parameters are frozen.

    Returns:
        torch.nn.Module: the same ``model`` object, to allow call chaining.
    """
    for tensor in model.parameters():
        tensor.requires_grad_(False)
    return model


def unfreeze(model):
    """
    Enable gradient tracking for every parameter of a model, in place.

    Arguments:
        model (torch.nn.Module): the model whose parameters are made trainable.

    Returns:
        torch.nn.Module: the same ``model`` object, to allow call chaining.
    """
    for tensor in model.parameters():
        tensor.requires_grad_(True)
    return model


def add_new_head(model_base, num_classes, **head_params):
    r"""
    Swap the final module of a model for an identity and attach a new MLP head.

    The last registered child module of ``model_base`` is inspected to find
    the width of the embedding which feeds it, then replaced with
    ``nn.Identity`` so that ``model_base`` emits that embedding. A new
    :class:`MLP` head mapping the embedding to ``num_classes`` outputs is
    created and chained after the base.

    Note that ``model_base`` is modified in place.

    Arguments:
        model_base (torch.nn.Module): the model whose final module is removed.
            Its last child must either expose an ``in_features`` attribute,
            or be an ``nn.Sequential`` with at least one direct child that
            does (the first such child is used).
        num_classes (int): number of outputs of the new head.
        **head_params: additional keyword arguments forwarded to :class:`MLP`.

    Returns:
        tuple: ``(model, model_base, model_head)`` where ``model`` is an
        ``nn.Sequential`` with entries named ``"base"`` and ``"head"``,
        ``model_base`` is the (now headless) input model and ``model_head``
        is the new :class:`MLP`.

    Raises:
        ValueError: if ``model_base`` has no child modules, if the number of
            input features of its last module can not be determined, or if
            the last module is named ``head_dist`` but there is no ``head``
            attribute on the model.
    """
    # Get the number of output features in the last embedding and cut off
    # the final layer. An explicit default stops a bare StopIteration from
    # escaping when the model has no children at all.
    last_module_name = next(reversed(model_base._modules), None)
    if last_module_name is None:
        raise ValueError(
            "model_base has no child modules, so there is no final module to"
            " replace. Model:\n{}".format(model_base)
        )
    last_module = getattr(model_base, last_module_name)
    if hasattr(last_module, "in_features"):
        n_features = last_module.in_features
    elif isinstance(last_module, nn.Sequential):
        # Only direct children of the Sequential are searched; the first one
        # advertising in_features determines the embedding width.
        for m in last_module:
            if hasattr(m, "in_features"):
                n_features = m.in_features
                break
        else:
            raise ValueError(
                "No modules in {} of model_base have an in_features attribute."
                " Module:\n{}".format(last_module_name, last_module)
            )
    else:
        raise ValueError(
            "The last module, {}, of model_base is not an nn.Sequential and"
            " does not have an in_features attribute. Last module:\n{}".format(
                last_module_name, last_module
            )
        )

    setattr(model_base, last_module_name, nn.Identity())
    # NOTE(review): presumably handles distilled models which carry a second
    # classifier named "head" alongside "head_dist" - confirm against the
    # model definitions this is used with. Both are neutralised together.
    if last_module_name == "head_dist":
        if hasattr(model_base, "head"):
            model_base.head = nn.Identity()
        else:
            raise ValueError("Model has head_dist but not head")

    model_head = MLP(in_channels=n_features, out_channels=num_classes, **head_params)
    model = nn.Sequential(OrderedDict([("base", model_base), ("head", model_head)]))

    return model, model_base, model_head


class MLP(nn.Module):
    r"""
    Multi-layer perceptron with a final linear classifier.

    Each hidden layer is an ``nn.Sequential`` of a linear map, an optional
    ``nn.LayerNorm`` and an activation function; the hidden layers are stored
    in ``self.layers`` and followed by ``self.classifier``.

    Arguments:
        in_channels (int): number of input features.
        out_channels (int): number of outputs of the classifier.
        n_layer (int, optional): number of hidden layers. With ``0`` the model
            is only the linear classifier. Default is ``1``.
        hidden_width (int, optional): requested width of each hidden linear
            layer. If ``None``, ``in_channels`` is used. The value is rounded
            to the nearest multiple of the activation's ``k`` attribute
            (taken as ``1`` if absent). Default is ``None``.
        actfun (str, optional): activation function name, resolved with
            ``actfun_name2factory``. Default is ``"relu"``.
        layer_norm (bool, optional): whether to apply ``nn.LayerNorm`` after
            each hidden linear layer. When enabled, the hidden linear layers
            are created without a bias term. Default is ``True``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_layer=1,
        hidden_width=None,
        actfun="relu",
        layer_norm=True,
    ):
        super().__init__()

        make_actfun = actfun_name2factory(actfun)
        # Throwaway instance, used only to read the activation's metadata.
        probe = make_actfun()

        # Activations without these attributes are treated as acting
        # elementwise: no width constraint and no change in feature count.
        divisor = getattr(probe, "k", 1)
        feature_factor = getattr(probe, "feature_factor", 1)

        if hidden_width is None:
            hidden_width = in_channels

        # Snap the hidden width to the nearest multiple of the divisor.
        n_groups = int(round(hidden_width / divisor))
        hidden_width = int(n_groups * divisor)

        stages = []
        width_in = in_channels
        for _ in range(n_layer):
            # Bias is dropped when a LayerNorm follows the linear map.
            stage = [nn.Linear(width_in, hidden_width, bias=not layer_norm)]
            if layer_norm:
                stage.append(nn.LayerNorm(hidden_width))
            stage.append(make_actfun())
            stages.append(nn.Sequential(*stage))
            # The activation may change the number of features it emits.
            width_in = int(round(hidden_width * feature_factor))

        self.layers = nn.Sequential(*stages)
        self.classifier = nn.Linear(width_in, out_channels)

    def forward(self, x):
        """Pass ``x`` through the hidden layers and then the classifier."""
        return self.classifier(self.layers(x))
