import pdb

import model.biggan.layers as biggan_layers
import torch
from torch import nn


class LinearFinder(nn.Module):
    """Single affine map over the latent space: ``z -> W z + b``.

    Args:
        z_dim: Dimensionality of the latent vector; the layer maps
            ``z_dim -> z_dim``.
        sn: If True, build the layer with ``biggan_layers.SNLinear``;
            otherwise use a plain ``nn.Linear``.
        init: Weight-initialization scheme applied by :meth:`init_weights`.
            One of ``"ortho"``, ``"N02"`` (normal, mean 0, std 0.02),
            ``"glorot"`` or ``"xavier"``. Any other value leaves the weights
            as constructed and prints a notice.
    """

    def __init__(self, z_dim=128, sn=True, init="ortho"):
        super().__init__()
        self.dim_z = z_dim
        self.init = init
        self.which_linear = biggan_layers.SNLinear if sn else nn.Linear
        self.linear_z = self.which_linear(self.dim_z, self.dim_z)
        self.init_weights()

    def init_weights(self):
        """Re-initialize the weights of every Conv2d/Linear/Embedding submodule.

        Only ``module.weight`` is re-initialized; biases keep their
        construction-time values. The total parameter count of the matched
        modules is stored in ``self.param_count`` and printed.
        """
        # NOTE(review): this assumes biggan_layers.SNLinear subclasses
        # nn.Linear; if it does not, SN layers are silently skipped here.
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == "ortho":
                nn.init.orthogonal_(module.weight)
            elif self.init == "N02":
                nn.init.normal_(module.weight, 0, 0.02)
            elif self.init in ("glorot", "xavier"):
                nn.init.xavier_uniform_(module.weight)
            else:
                print("Init style not recognized...")
            # Parameters are counted even when the init style was unrecognized.
            self.param_count += sum(p.numel() for p in module.parameters())
        print(f"Param count for F's initialized parameters: {self.param_count}")

    def forward(self, z, y=None):
        """Return ``linear_z(z)``; ``y`` is accepted for interface compatibility and ignored."""
        return self.linear_z(z)


class ResidualFinder(LinearFinder):
    """Residual variant of the linear finder: ``z + tanh(linear_z(z))``.

    The tanh keeps every coordinate of the learned offset inside (-1, 1),
    so the output stays within that distance of the input per coordinate.
    """

    def __init__(self, z_dim=128, sn=True, init="ortho"):
        super().__init__(z_dim, sn, init)

    def forward(self, z, y=None):
        # `y` is part of the shared finder interface and is ignored here.
        offset = torch.tanh(self.linear_z(z))
        return z + offset


class SimpleFinder(nn.Module):
    """Two-layer MLP over the latent space.

    Architecture: ``Linear(z_dim, hidden_dim) -> LeakyReLU(0.1) ->
    Linear(hidden_dim, z_dim)``.

    Args:
        z_dim: Dimensionality of the latent vector (input and output size).
        sn: If True, build the linear layers with ``biggan_layers.SNLinear``;
            otherwise use plain ``nn.Linear``.
        init: Weight-initialization scheme applied by :meth:`init_weights`.
            One of ``"ortho"``, ``"N02"`` (normal, mean 0, std 0.02),
            ``"glorot"`` or ``"xavier"``. Any other value leaves the weights
            as constructed and prints a notice.
        hidden_dim: Width of the hidden layer. Defaults to 256, the value
            that was previously hard-coded.
    """

    def __init__(self, z_dim=128, sn=True, init="ortho", hidden_dim=256):
        super().__init__()
        self.dim_z = z_dim
        self.init = init
        self.hidden_dim = hidden_dim
        self.which_linear = biggan_layers.SNLinear if sn else nn.Linear

        self.linear_z = nn.Sequential(
            self.which_linear(self.dim_z, hidden_dim),
            nn.LeakyReLU(0.1),
            self.which_linear(hidden_dim, self.dim_z),
        )
        self.init_weights()

    def init_weights(self):
        """Re-initialize the weights of every Conv2d/Linear/Embedding submodule.

        Only ``module.weight`` is re-initialized; biases keep their
        construction-time values. The total parameter count of the matched
        modules is stored in ``self.param_count`` and printed.
        """
        # NOTE(review): this assumes biggan_layers.SNLinear subclasses
        # nn.Linear; if it does not, SN layers are silently skipped here.
        self.param_count = 0
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear, nn.Embedding)):
                continue
            if self.init == "ortho":
                nn.init.orthogonal_(module.weight)
            elif self.init == "N02":
                nn.init.normal_(module.weight, 0, 0.02)
            elif self.init in ("glorot", "xavier"):
                nn.init.xavier_uniform_(module.weight)
            else:
                print("Init style not recognized...")
            # Parameters are counted even when the init style was unrecognized.
            self.param_count += sum(p.numel() for p in module.parameters())
        print(f"Param count for F's initialized parameters: {self.param_count}")

    def forward(self, z, y=None):
        """Return ``linear_z(z)``; ``y`` is accepted for interface compatibility and ignored."""
        return self.linear_z(z)


class ResidualMLPFinder(SimpleFinder):
    """Residual variant of the MLP finder: ``z + tanh(linear_z(z))``.

    The MLP's output is squashed by tanh into (-1, 1) per coordinate and
    added back onto the input latent.
    """

    def __init__(self, z_dim=128, sn=True, init="ortho"):
        super().__init__(z_dim, sn, init)

    def forward(self, z, y=None):
        # `y` is part of the shared finder interface and is ignored here.
        delta = torch.tanh(self.linear_z(z))
        return z + delta


class IdentityFinder(SimpleFinder):
    """Finder whose forward pass returns the latent ``z`` unchanged.

    The parent constructor still builds (and initializes) the MLP in
    ``linear_z``, but ``forward`` never calls it.
    """

    # NOTE(review): the unused MLP is presumably kept so that callers which
    # build an optimizer over `parameters()` still receive a non-empty list —
    # confirm before removing it.

    def __init__(self, z_dim=128, sn=True, init="ortho"):
        super().__init__(z_dim, sn, init)

    def forward(self, z, y=None):
        # Pass-through; `y` is accepted for interface compatibility only.
        identity = z
        return identity
