from spaghettini import quick_register
from functools import partial

from torch import nn
import torch.nn.functional as F

from torch.nn.utils import weight_norm


@quick_register
class ResNetLayer(nn.Module):
    """Two-convolution residual block taking a state ``z`` and an injected input ``x``.

    The forward pass computes::

        y   = relu(norm1(conv1(z)))
        out = norm3(relu(z + norm2(x + conv2(y))))

    ``x`` is added inside the residual branch, before ``norm2``. (Presumably this
    is the cell ``f(z, x)`` of an implicit / deep-equilibrium model; confirm
    against callers.)

    Args:
        n_channels: Channel count of ``z`` and of the output; ``conv2`` maps back to it.
        n_inner_channels: Channel count produced by ``conv1``.
        kernel_size: Square kernel size of both convolutions. Padding is
            ``kernel_size // 2``, so spatial size is preserved only for odd sizes.
        num_groups: Number of groups for every ``nn.GroupNorm``. Both ``n_channels``
            and ``n_inner_channels`` must be divisible by it.
        init_std: Standard deviation of the zero-mean normal initialization applied
            to both convolution weights. The init applies whether or not weight
            normalization is used.
        use_weight_norm: If True, wrap each convolution with
            ``weight_norm(name="weight", dim=0)``.
    """

    def __init__(self, n_channels, n_inner_channels, kernel_size=3, num_groups=8, init_std=0.01, use_weight_norm=True):
        super().__init__()
        padding = kernel_size // 2

        # Build the raw (unwrapped) convolutions first.
        conv1 = nn.Conv2d(n_channels, n_inner_channels, (kernel_size, kernel_size),
                          padding=padding, bias=False)
        conv2 = nn.Conv2d(n_inner_channels, n_channels, (kernel_size, kernel_size),
                          padding=padding, bias=False)

        # Initialize BEFORE applying weight normalization. weight_norm replaces the
        # `weight` Parameter with `weight_g` / `weight_v` (captured from the current
        # weight) and recomputes `weight` from them in a forward pre-hook. Writing
        # into `weight.data` after wrapping would therefore be overwritten on the
        # first forward pass, and `init_std` would be silently ignored.
        # Both convs are created before either is initialized, so the random-number
        # consumption order matches the unwrapped case.
        nn.init.normal_(conv1.weight, mean=0.0, std=init_std)
        nn.init.normal_(conv2.weight, mean=0.0, std=init_std)

        # Optionally wrap each convolution with weight normalization. Since g = ||v||
        # at wrap time, the effective initial weight equals the one set above.
        wrap = partial(weight_norm, name="weight", dim=0) if use_weight_norm else (lambda module: module)

        # Registration order (conv1, conv2, norm1, norm2, norm3) is kept so
        # state_dict keys are unchanged.
        self.conv1 = wrap(conv1)
        self.conv2 = wrap(conv2)

        self.norm1 = nn.GroupNorm(num_groups, n_inner_channels)
        self.norm2 = nn.GroupNorm(num_groups, n_channels)
        self.norm3 = nn.GroupNorm(num_groups, n_channels)

    def forward(self, z, x):
        """Apply the block to state ``z`` with input injection ``x``.

        Args:
            z: Tensor passed to ``conv1`` (``n_channels`` channels). It is also the
                residual skip term.
            x: Tensor added to ``conv2``'s output before ``norm2``. It must be
                broadcast-compatible with that output (``n_channels`` channels).

        Returns:
            The output of ``norm3``, with ``n_channels`` channels.
        """
        y = F.relu(self.norm1(self.conv1(z)))
        return self.norm3(F.relu(z + self.norm2(x + self.conv2(y))))
