"""Additional layers not included in PyTorch."""

import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.modules.conv as conv
from typing import Optional
from typing_extensions import Literal
import math
import torch.nn.init as init


class PositionalEncoding(nn.Module):
    """Prepend two normalized coordinate channels to a batch of feature maps."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with a row-index and a column-index channel prepended.

        The coordinate grid is built from the last two dimensions of ``x`` and
        concatenated in front of ``x`` along dimension 1. Both coordinate
        channels are divided by the same value (the largest index over both
        axes), so for non-square inputs only the longer axis reaches 1.
        """
        n_rows, n_cols = x.shape[-2], x.shape[-1]
        row_idx = torch.arange(n_rows, dtype=torch.float, device=x.device)
        col_idx = torch.arange(n_cols, dtype=torch.float, device=x.device)

        # Channel 0 varies along the rows, channel 1 along the columns.
        grid = torch.stack(torch.meshgrid(row_idx, col_idx), 0)
        # The small constant avoids 0/0 when both spatial sizes are 1.
        grid = grid / (torch.max(grid) + 1e-12)

        # One copy of the grid per batch element.
        batched_grid = torch.repeat_interleave(grid.unsqueeze(0), len(x), 0)
        return torch.cat((batched_grid, x), 1)


class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be used like a module."""

    def __init__(self, f):
        """Store the callable ``f`` that ``forward`` delegates to."""
        super().__init__()
        self.f = f

    def forward(self, *args, **kwargs):
        """Call the wrapped callable with all given arguments and return its result."""
        wrapped = self.f
        return wrapped(*args, **kwargs)


class Flatten(Lambda):
    """Flatten the input data after the batch dimension.

    The output has shape ``(len(x), -1)``: the first dimension is kept and all
    remaining dimensions are merged into one.
    """

    def __init__(self):
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. after ``permute``/``transpose``), whereas ``reshape``
        # returns a view when possible and copies only when it has to.
        super().__init__(lambda x: x.reshape(len(x), -1))


class RescaleLayer(nn.Module):
    """Normalize the data to a hypersphere with fixed/variable radius.

    Args:
        init_r: Initial value of the radius.
        fixed_r: If true, the radius is a constant tensor; otherwise it is a
            learnable ``nn.Parameter``.
        mode: ``"eq"`` rescales every vector (along the last dimension) to
            have norm exactly ``r``; ``"leq"`` only shrinks vectors whose norm
            exceeds ``r`` and leaves the others untouched.
    """

    def __init__(
        self, init_r=1.0, fixed_r=False, mode: Optional[Literal["eq", "leq"]] = "eq"
    ):
        super().__init__()
        self.fixed_r = fixed_r
        assert mode in ("leq", "eq")
        self.mode = mode
        if fixed_r:
            # Plain tensor (not a parameter or buffer), so it is moved to the
            # input's device explicitly in ``forward``.
            self.r = torch.ones(1, requires_grad=False) * init_r
        else:
            self.r = nn.Parameter(torch.ones(1, requires_grad=True) * init_r)

    def forward(self, x):
        """Return a rescaled copy of ``x``; the input tensor is not modified."""
        # Both modes need the radius on the input's device; a fixed radius is
        # a plain tensor that does not follow ``module.to(device)``.
        r = self.r.to(x.device)
        norm = torch.norm(x, dim=-1, keepdim=True)

        if self.mode == "eq":
            # Zero-norm vectors yield NaN here (0 / 0), as before.
            return x / norm * r

        # mode == "leq": scale by r / norm only where norm > r, i.e. where the
        # ratio is below one. Clamping the norm away from zero keeps the ratio
        # (and its gradient w.r.t. r) finite for all-zero vectors.
        safe_norm = norm.clamp(min=torch.finfo(norm.dtype).eps)
        scale = torch.clamp(r / safe_norm, max=1.0)
        return x * scale


class SoftclipLayer(nn.Module):
    """Squash the data with a sigmoid and scale it by a fixed/learnable bound.

    Args:
        n: Size passed to ``torch.ones`` to create the bound tensor.
        init_abs_bound: Initial value of every entry of the bound.
        fixed_abs_bound: If true, the bound is a constant tensor; otherwise it
            is a learnable ``nn.Parameter``.
    """

    def __init__(self, n, init_abs_bound=1.0, fixed_abs_bound=True):
        super().__init__()
        self.fixed_abs_bound = fixed_abs_bound

        initial = torch.ones(n, requires_grad=not fixed_abs_bound) * init_abs_bound
        # A fixed bound stays a plain tensor; a learnable one is registered
        # as a parameter of the module.
        self.max_abs_bound = initial if fixed_abs_bound else nn.Parameter(initial)

    def forward(self, x):
        """Return ``sigmoid(x)`` multiplied by the bound (broadcast over dim 0)."""
        # The bound is moved to the input's device on every call because the
        # fixed variant is not tracked by the module.
        bound = self.max_abs_bound.to(x.device).unsqueeze(0)
        return torch.sigmoid(x) * bound


class ElementwiseLinear(nn.Module):
    """Apply a separate linear map to each of ``dim`` elements of the input.

    The weight has shape ``(1, dim, in_features, out_features)`` and the
    optional bias has shape ``(1, dim, out_features)``.

    Args:
        dim: Number of elements, each with its own weights.
        in_features: Size of the feature vector of each input element.
        out_features: Size of the feature vector of each output element.
        shared_weights: If true, a single set of weights is broadcast over
            all elements (``dim`` is replaced by 1).
        bias: Whether to add a learnable bias.
    """

    def __init__(self, dim, in_features, out_features, shared_weights=False, bias=True):
        super(ElementwiseLinear, self).__init__()
        if shared_weights:
            dim = 1  # same for all dimensions

        self.weight = nn.Parameter(torch.empty(1, dim, in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(1, dim, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize weight and bias with uniform(-1/sqrt(in), 1/sqrt(in)).

        This mirrors the default ``nn.Linear`` initialization. The bound is
        computed from ``in_features`` directly: on the 4-D weight,
        ``kaiming_uniform_`` / ``_calculate_fan_in_and_fan_out`` would use
        ``dim * in_features * out_features`` as fan-in, which is not the
        number of inputs summed for each output value.
        """
        in_features = self.weight.shape[2]
        bound = 1 / math.sqrt(in_features) if in_features > 0 else 0
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Map ``x`` of shape (B, D, in_features) to (B, D, out_features)."""
        # (1, D, I, O) * (B, D, I, 1) -> (B, D, I, O); sum over I -> (B, D, O)
        out = (self.weight * x.unsqueeze(-1)).sum(2)
        if self.bias is not None:
            out = out + self.bias
        return out


class ElementwiseMLP(nn.Module):
    """Small residual MLP applied to every input element, summed over elements.

    Each of the ``dim`` scalar inputs is passed through its own three-layer
    network built from ``ElementwiseLinear`` layers; the per-element scalar
    outputs are then added up.

    Args:
        dim: Number of input elements (D).
        num_features: Width of the hidden layers (F).
    """

    def __init__(self, dim, num_features=10):
        super().__init__()
        self.dim, self.num_features = dim, num_features
        self.act_fun = nn.GELU()
        # Layers are created in this order so that random initialization
        # consumes the generator in a fixed sequence.
        self.layer1 = ElementwiseLinear(dim, 1, num_features)
        self.layer2 = ElementwiseLinear(dim, num_features, num_features)
        self.layer3 = ElementwiseLinear(dim, num_features, 1, bias=False)

    def forward(self, x):
        """Map ``x`` of shape (B, D) to an output of shape (B, 1)."""
        scalars = x.unsqueeze(-1)  # (B, D, 1)
        hidden = self.act_fun(self.layer1(scalars))  # (B, D, F)
        # Residual connection around the second layer.
        hidden = self.act_fun(hidden + self.layer2(hidden))  # (B, D, F)
        per_element = self.layer3(hidden)  # (B, D, 1)
        return per_element.sum(dim=1)  # (B, 1)


class MLP(nn.Module):
    """Three-layer perceptron with a residual middle layer and scalar output.

    Args:
        dim: Number of input features (D).
        num_features: Width of the hidden layers (F).
    """

    def __init__(self, dim, num_features=10):
        super().__init__()
        self.dim, self.num_features = dim, num_features
        self.act_fun = nn.GELU()
        self.layer1 = nn.Linear(in_features=dim, out_features=num_features)
        self.layer2 = nn.Linear(in_features=num_features, out_features=num_features)
        # The output layer has no bias.
        self.layer3 = nn.Linear(in_features=num_features, out_features=1, bias=False)

    def forward(self, x):
        """Map ``x`` of shape (B, D) to an output of shape (B, 1)."""
        hidden = self.act_fun(self.layer1(x))  # (B, F)
        # Residual connection around the second layer.
        residual = hidden + self.layer2(hidden)  # (B, F)
        return self.layer3(self.act_fun(residual))  # (B, 1)


class DeltaLayer(nn.Module):
    r"""Compute $\delta(x,y)$ for the positive and negative pairs of a batch.

    ``forward(x, y)`` takes two batches of shape (B, n), where ``x[i]`` and
    ``y[i]`` form a positive pair, and returns the delta values of the
    positive pairs and of the negative pairs (see ``forward`` for shapes).

    Args:
        n: Number of latent dimensions (n).
        space_type: Either R^n (euclid) or S^{n-1} (sphere)
        p: Exponent(s) in the delta term (-1 means they are learnable parameters).
        tau: Rescaling parameter of the delta term.
        f1: function that depends only on the first component.
        f2: function that depends only on the second component.
        bias: Whether to include bias as learnable parameter
        margin_mode: The combination of marginals in the loss
    """

    # Diagnostic statistics stored on the module by ``record_before`` (suffix
    # 1) and ``record_after`` (suffix 2) and reported by ``get_param``.
    # TODO: delete together with the recording methods.
    _STAT_NAMES = (
        'pos1', 'pos1_min', 'pos1_max',
        'neg1', 'neg1_min', 'neg1_max',
        'pos2', 'pos2_min', 'pos2_max',
        'neg2', 'neg2_min', 'neg2_max',
    )

    def __init__(self, n, space_type='euclid', p=-1, tau=0.1, f1='zero', f2='zero', bias=False, margin_mode='second'):
        super().__init__()
        assert n >= 2
        assert space_type in ['sphere', 'euclid']
        assert (p == -1) or (p > 0)
        assert tau > 0
        assert f1 in ['zero', 'mlp', 'mlp_el']
        assert f2 in ['zero', 'mlp', 'mlp_el']
        assert margin_mode in ['first', 'second', 'both']

        self.f1_type = f1
        self.f2_type = f2
        self.bias = bias
        self.space_type = space_type
        # tau is stored as a frozen parameter.
        self.tau = nn.Parameter(tau * torch.ones(1), requires_grad=False)
        self.eps = 1e-8

        if p == -1:  # learnable exponents, initialized to 2
            self.p = nn.Parameter(2.0 * torch.ones(n), requires_grad=True)
        else:
            self.p = nn.Parameter(p * torch.ones(n), requires_grad=False)

        # The offset c starts at zero and is only trained when bias is enabled.
        # TODO: remove if unnecessary
        self.c = nn.Parameter(torch.zeros(1), requires_grad=bool(self.bias))

        # Scales of the two marginal functions (used through their absolute value).
        self.a1 = nn.Parameter(0.01 * torch.randn(1), requires_grad=True)
        self.a2 = nn.Parameter(0.01 * torch.randn(1), requires_grad=True)

        num_features = 20
        if f1 == 'mlp':
            self.net1 = MLP(n, num_features)
        elif f1 == 'mlp_el':
            self.net1 = ElementwiseMLP(n, num_features)
        if f2 == 'mlp':
            self.net2 = MLP(n, num_features)
        elif f2 == 'mlp_el':
            self.net2 = ElementwiseMLP(n, num_features)

        self.margin_mode = margin_mode

        # TODO: delete
        for name in self._STAT_NAMES:
            setattr(self, name, torch.zeros(1))

    @torch.no_grad()
    def get_param(self):
        """Return a dict of CPU copies of the parameters and recorded statistics.

        ``a1``, ``a2`` and ``p`` are reported as absolute values.
        """
        param = {
            'a1': torch.abs(self.a1.detach().cpu()),
            'a2': torch.abs(self.a2.detach().cpu()),
            'c': self.c.detach().cpu(),
            'p': torch.abs(self.p.detach().cpu()),
            't': self.tau.detach().cpu(),
        }
        for name in self._STAT_NAMES:
            param[name] = getattr(self, name).detach().cpu()
        return param

    @torch.no_grad()
    def _store_stats(self, tag, pos, neg):
        """Store mean/min/max of ``pos`` and ``neg`` as ``pos<tag>*`` / ``neg<tag>*``."""
        for prefix, values in (('pos', pos), ('neg', neg)):
            values = values.detach()
            setattr(self, f'{prefix}{tag}', torch.mean(values))
            setattr(self, f'{prefix}{tag}_min', torch.min(values))
            setattr(self, f'{prefix}{tag}_max', torch.max(values))

    @torch.no_grad()
    def record_before(self, pos, neg=None):
        """Record statistics of the delta values before the marginals are added.

        If ``neg`` is None, ``pos`` is taken to be the full (B, B) delta
        matrix: its diagonal holds the positives, the rest the negatives.
        """
        if neg is None:
            B = pos.shape[0]
            idx_diag = torch.eye(B, dtype=torch.bool, device=pos.device)
            neg = pos.masked_select(~idx_diag).view(B, B-1)
            pos = pos.masked_select(idx_diag).view(B)
        self._store_stats('1', pos, neg)

    @torch.no_grad()
    def record_after(self, pos, neg):
        """Record statistics of the final positive and negative values."""
        self._store_stats('2', pos, neg)

    def _marginal(self, x, kind, net_name, scale):
        """Evaluate one marginal function on ``x`` and return a (B,) tensor.

        For ``kind == 'zero'`` the result is all zeros; otherwise it is the
        network output, centered over the batch and scaled by ``|scale|``.
        """
        if kind == 'zero':
            return torch.zeros(x.shape[0], device=x.device)
        out = getattr(self, net_name)(x).squeeze(-1)
        return torch.abs(scale) * (out - out.mean())

    def f1(self, x):
        """Marginal function of the first component."""
        return self._marginal(x, self.f1_type, 'net1', self.a1)

    def f2(self, x):
        """Marginal function of the second component."""
        return self._marginal(x, self.f2_type, 'net2', self.a2)

    def sum_pow(self, delta):
        """Return ``sum_i (delta_i / tau) ** p_i`` over the last dimension."""
        # Keep the base away from zero when any exponent is below one.
        if torch.any(self.p < 1):
            delta = delta.clamp(min=self.eps)
        return (delta / self.tau).pow(self.p).sum(dim=-1)

    def margin_first(self, x, y, pos, neg):
        """Add the marginals to separate ``pos`` (B,) and ``neg`` (B, B) in place.

        The negatives here are pairs (x_i, x_j), so f2 is evaluated on ``x``
        for them and on ``y`` for the positives.
        """
        # f1
        f1x = self.f1(x)
        pos += f1x
        neg += f1x.unsqueeze(1)
        # f2
        pos += self.f2(y)
        neg += self.f2(x).unsqueeze(0)
        return pos, neg

    def margin_second(self, x, y, delta):
        """Add the marginals to the (B, B) matrix ``delta`` in place and split it.

        Returns the diagonal as ``pos`` (B,) and the off-diagonal entries as
        ``neg`` (B, B-1).
        """
        # f1
        delta += self.f1(x).unsqueeze(1)
        # f2
        delta += self.f2(y).unsqueeze(0)
        # split pos and neg
        B = delta.shape[0]
        idx_diag = torch.eye(B, dtype=torch.bool, device=delta.device)
        pos = delta.masked_select(idx_diag).view(B)
        neg = delta.masked_select(~idx_diag).view(B, B-1)
        return pos, neg

    def margin_both(self, x, y, delta):
        """Add the marginals to the (B, 2B) matrix ``delta`` in place and split it.

        The columns of ``delta`` correspond to ``cat([x, y])``. The positives
        are the diagonal of the right half; every other entry is a negative,
        giving ``neg`` of shape (B, 2B-1).
        """
        # f1
        delta += self.f1(x).unsqueeze(1)
        # f2
        delta += self.f2(torch.cat([x, y], dim=0)).unsqueeze(0)
        # split pos and neg
        B = delta.shape[0]
        idx_diag = torch.cat([
            torch.zeros((B, B), dtype=torch.bool, device=delta.device),
            torch.eye(B, dtype=torch.bool, device=delta.device)
        ], dim=-1)
        pos = delta.masked_select(idx_diag).view(B)
        neg = delta.masked_select(~idx_diag).view(B, 2*B-1)
        return pos, neg

    def _paired_delta(self, x, y):
        """Delta between matching rows ``x[i]`` and ``y[i]``; shape (B,)."""
        if self.space_type == 'sphere':
            return -torch.einsum("ij,ij -> i", x, y) / self.tau
        return self.sum_pow(torch.abs(x - y))

    def _pairwise_delta(self, x, other):
        """Delta between every row of ``x`` and every row of ``other``; shape (B, K)."""
        if self.space_type == 'sphere':
            return -torch.einsum("ij,kj -> ik", x, other) / self.tau
        return self.sum_pow(torch.abs(x.unsqueeze(1) - other.unsqueeze(0)))

    def forward(self, x, y):
        """Return ``(pos, neg)`` for the batches ``x`` and ``y`` of shape (B, n).

        ``pos`` has shape (B, 1). ``neg`` has shape (B, B) for margin mode
        'first', (B, B-1) for 'second' and (B, 2B-1) for 'both'.
        """
        if self.margin_mode == 'first':
            pos = self._paired_delta(x, y)
            neg = self._pairwise_delta(x, x)

            # TODO: delete
            self.record_before(pos, neg)

            pos, neg = self.margin_first(x, y, pos, neg)

        elif self.margin_mode == 'second':
            delta = self._pairwise_delta(x, y)

            # TODO: delete
            self.record_before(delta)

            pos, neg = self.margin_second(x, y, delta)

        else:  # use average marginal (higher memory consumption)
            delta = self._pairwise_delta(x, torch.cat([x, y], dim=0))
            pos, neg = self.margin_both(x, y, delta)

        if self.bias:
            pos += self.c
            neg += self.c

        # TODO: delete
        self.record_after(pos, neg)

        return pos.unsqueeze(1), neg
