import torch
import torch.nn as nn


class GeneralMLP(nn.Module):
    """MLP whose input is itself a trainable matrix (``weight0``).

    ``forward`` takes no data: it feeds ``weight0.T`` (one row per sample,
    ``num_classes * num_per_class`` rows of width ``input_dims[0]``) through
    ``num_layers`` linear layers. After every layer except the last, optional
    ``BatchNorm1d`` and then optional ReLU are applied.

    Gradient hooks on every weight matrix record, per weight index (0 is
    ``weight0``, ``i >= 1`` is ``weight<i>``), the norm of the incoming (fit)
    gradient, of the weight-decay term ``wd * W``, of their sum, and the cosine
    between the two components. Recording is toggled by ``self.do_hook``.

    Args:
        bias: Whether the linear layers carry a bias term.
        relu: Apply ReLU after each hidden layer.
        batch_norm: Apply ``BatchNorm1d`` after each hidden layer (before ReLU).
        num_layers: Number of linear layers.
        input_dims: Layer widths; needs ``num_layers + 1`` entries.
        num_classes: Number of classes.
        num_per_class: Number of samples per class.
        **kwargs: Must contain ``weight_decays`` (one entry per weight matrix,
            ``weight0`` first), ``scaling`` and ``dist`` (``'gaussian'`` or
            ``'kaiming'``), consumed by ``_initialize_weights``.
    """

    def __init__(self, bias, relu, batch_norm, num_layers, input_dims, num_classes, num_per_class, **kwargs):
        super(GeneralMLP, self).__init__()

        self.num_samples = num_classes * num_per_class
        self.num_per_class = num_per_class
        self.num_layers = num_layers
        self.relu = relu
        self.batch_norm = batch_norm
        self.input_dims = input_dims
        self.bias = bias
        self.wds = kwargs['weight_decays']

        # Learnable inputs, shape (input_dims[0], num_samples); forward() uses weight0.T.
        self.weight0 = nn.Parameter(nn.Linear(in_features=self.num_samples,
                                              out_features=input_dims[0], bias=False).weight.data)

        # Attribute names of the sub-modules. Only hidden layers (all but the
        # last) get a ReLU / BatchNorm; the name lists stay None when disabled.
        last = num_layers - 1
        self.weight_names = ['weight' + str(i + 1) for i in range(num_layers)]
        self.relu_names = ['relu' + str(i + 1) for i in range(last)] if relu else None
        self.batch_norm_names = ['bn' + str(i + 1) for i in range(last)] if batch_norm else None

        # Register modules in the same order as before (weight, relu, bn per layer).
        for layer_num in range(num_layers):
            setattr(self, self.weight_names[layer_num],
                    nn.Linear(in_features=input_dims[layer_num],
                              out_features=input_dims[layer_num + 1], bias=bias))
            if relu and layer_num != last:
                setattr(self, self.relu_names[layer_num], nn.ReLU())
            if batch_norm and layer_num != last:
                # BatchNorm normalises this linear layer's output, so its width
                # is input_dims[layer_num + 1] (previously the layer index + 1).
                setattr(self, self.batch_norm_names[layer_num],
                        nn.BatchNorm1d(num_features=input_dims[layer_num + 1]))

        self._initialize_weights(**kwargs)

        # Gradient-statistics buffers, indexed like _weight_at().
        self.do_hook = True
        self.current_grad_norms = torch.zeros(num_layers + 1)
        self.current_reg_grad_norms = torch.zeros(num_layers + 1)
        self.current_fit_grad_norms = torch.zeros(num_layers + 1)
        self.current_grad_angle = torch.zeros(num_layers + 1)
        for layer_idx in range(num_layers + 1):
            self._weight_at(layer_idx).register_hook(self.register_backward_hookie(layer_idx))

    def _weight_at(self, layer_idx):
        """Return weight matrix ``layer_idx``: 0 is ``weight0``, ``i`` is layer ``i``'s weight."""
        if layer_idx == 0:
            return self.weight0
        return getattr(self, self.weight_names[layer_idx - 1]).weight

    def _apply_layer(self, features, layer_num, activate=True):
        """Apply linear layer ``layer_num``; on hidden layers also BN then ReLU if ``activate``."""
        features = getattr(self, self.weight_names[layer_num])(features)
        if activate and layer_num != self.num_layers - 1:
            if self.batch_norm:
                features = getattr(self, self.batch_norm_names[layer_num])(features)
            if self.relu:
                features = getattr(self, self.relu_names[layer_num])(features)
        return features

    def forward(self):
        """Propagate ``weight0.T`` through all layers and return the network output."""
        return self.forward_until_layer(self.num_layers)

    def forward_until_layer(self, until_layer_n):
        """Return the features after the first ``until_layer_n`` layers.

        Includes BN/ReLU after every applied hidden layer.
        """
        features = self.weight0.T
        for layer_num in range(until_layer_n):
            features = self._apply_layer(features, layer_num)
        return features

    def forward_until_layer_clean(self, until_layer_n):
        """Like ``forward_until_layer`` but return layer ``until_layer_n``'s raw linear output.

        BN/ReLU are skipped after the last applied layer. ``until_layer_n == 0``
        returns ``weight0.T``.
        """
        features = self.weight0.T
        for layer_num in range(until_layer_n - 1):
            features = self._apply_layer(features, layer_num)
        if until_layer_n >= 1:
            features = self._apply_layer(features, until_layer_n - 1, activate=False)
        return features

    def _initialize_weights(self, **kwargs):
        """(Re)initialise every weight matrix in place.

        ``dist == 'gaussian'``: ``W ~ scaling * N(0, 1) / sqrt(wd)``; for
        ``weight0`` the divisor is ``sqrt(wd_0 * num_samples)``.
        ``dist == 'kaiming'``: Kaiming-uniform (ReLU gain) times ``scaling``.
        Biases are left untouched.

        Raises:
            NotImplementedError: for any other ``dist``.
        """
        weight_decays = kwargs['weight_decays']
        scaling = kwargs['scaling']
        dist = kwargs['dist']
        if dist not in ('gaussian', 'kaiming'):
            raise NotImplementedError
        with torch.no_grad():
            # weight0 first, then layers in order: keeps the RNG stream unchanged.
            for layer_idx in range(self.num_layers + 1):
                weight = self._weight_at(layer_idx)
                if dist == 'gaussian':
                    if layer_idx == 0:
                        divisor = weight_decays[0] * self.num_samples
                    else:
                        divisor = weight_decays[layer_idx]
                    weight.data = scaling * torch.randn_like(weight.data) / divisor ** (1 / 2)
                else:
                    nn.init.kaiming_uniform_(weight.data, nonlinearity='relu')
                    weight.data *= scaling

    def perturb_weights(self, scaling):
        """Add Gaussian noise to every weight matrix, relative to its RMS entry size.

        Each ``W`` becomes ``W + scaling * N(0, 1) * ||W||_F / sqrt(W.numel())``,
        with the norm taken before the update.
        """
        with torch.no_grad():
            for layer_idx in range(self.num_layers + 1):
                weight = self._weight_at(layer_idx)
                # A Python-float divisor (rather than a 1-element CPU tensor)
                # keeps this working when the model lives on a GPU.
                weight.data += scaling * torch.randn_like(weight.data) / weight.numel() ** 0.5 * \
                    torch.linalg.norm(weight.data)

    def register_backward_hookie(self, layer_idx):
        """Build a gradient hook that records statistics for weight ``layer_idx``.

        While ``self.do_hook`` is true, the hook stores at index ``layer_idx``:
        the norm of ``grad + wd * W`` (total), of ``wd * W`` (regularisation),
        of ``grad`` (fit), and the cosine between ``grad`` and ``wd * W``
        (NaN if either is zero, e.g. ``wd == 0``). The gradient is returned
        unchanged.
        """
        def hookie(grad):
            if self.do_hook:
                reg_grad = self.wds[layer_idx] * self._weight_at(layer_idx).data
                fit_norm = torch.linalg.norm(grad)
                reg_norm = torch.linalg.norm(reg_grad)
                self.current_grad_norms[layer_idx] = torch.linalg.norm(reg_grad + grad)
                self.current_reg_grad_norms[layer_idx] = reg_norm
                self.current_fit_grad_norms[layer_idx] = fit_norm
                self.current_grad_angle[layer_idx] = torch.sum(reg_grad * grad) / fit_norm / reg_norm
            return grad
        return hookie


class SingleFCReLULayer(nn.Module):
    """A single fully connected layer followed by a ReLU or LeakyReLU.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        bias: Whether the linear layer has a bias term.
        leaky: If true use ``nn.LeakyReLU()`` (default slope), else ``nn.ReLU()``.
    """

    def __init__(self, in_dim, out_dim, bias, leaky):
        super().__init__()

        self.in_dim, self.out_dim = in_dim, out_dim
        self.leaky = leaky
        self.bias = bias

        self.fc1 = nn.Linear(in_dim, out_dim, bias=bias)
        self.relu = nn.LeakyReLU() if leaky else nn.ReLU()

    def forward(self, x):
        """Return ``relu(fc1(x))``."""
        pre_activation = self.fc1(x)
        return self.relu(pre_activation)


class SingleConvReLULayer(nn.Module):
    """A single 2-D convolution followed by a ReLU or LeakyReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size (passed straight to ``nn.Conv2d``).
        bias: Whether the convolution has a bias term.
        leaky: If true use ``nn.LeakyReLU()`` (default slope), else ``nn.ReLU()``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias, leaky):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.leaky = leaky
        self.bias = bias

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
        )
        self.relu = nn.LeakyReLU() if leaky else nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv1(x))``."""
        feature_map = self.conv1(x)
        return self.relu(feature_map)
