from typing import Optional, Dict

import torch 
from torch import nn 
from gpytorch.means import Mean
from gpytorch.kernels import Kernel


class Squareplus(nn.Module):
    """Squareplus activation: ``0.5 * (x + sqrt(x**2 + 4))``.

    The textbook expression subtracts two nearly equal numbers when ``x`` is
    a large negative value, so in floating point it collapses to exactly 0
    (or to ``inf`` once ``x**2`` overflows). This implementation evaluates
    the algebraically identical conjugate form on the negative side so the
    result stays accurate there.
    """

    def forward(self, x):
        """Apply squareplus element-wise.

        Args:
            x: Input tensor.

        Returns:
            Tensor with the same shape as ``x``.
        """
        root = torch.sqrt(x**2 + 4)
        # x >= 0: both terms are non-negative, so the direct form is safe.
        direct = 0.5 * (x + root)
        # x < 0: 0.5 * (x + root) == 2 / (root - x) == 2 / (root + |x|).
        # |x| is used instead of -x so the denominator is >= 2 for every
        # input; this keeps the unselected branch of torch.where finite and
        # therefore avoids NaN gradients for positive x.
        conjugate = 2.0 / (root + x.abs())
        return torch.where(x < 0, conjugate, direct)


class IdentityWrapper(nn.Module):
    """No-op wrapper: hands its input back untouched."""

    def forward(self, x):
        """Return ``x`` itself (the same object, not a copy)."""
        return x


class MLPWrapper(nn.Module):
    """Stack of ``Linear -> Tanh`` pairs stored as ``self.mlp``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Iterable of layer widths; one ``Linear``/``Tanh``
            pair is created per entry, in order. An empty iterable yields an
            empty ``nn.Sequential``.
    """

    def __init__(self, in_features, hidden_features):
        super().__init__()
        layers = []
        fan_in = in_features
        for fan_out in hidden_features:
            # Each width becomes the input size of the following layer.
            layers += [nn.Linear(fan_in, fan_out), nn.Tanh()]
            fan_in = fan_out
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.mlp(x)


class KumarWrapper(nn.Module):
    """Learnable warping ``1 - (1 - x**a) ** b`` (Kumaraswamy CDF form).

    ``a`` and ``b`` are obtained by passing the raw parameters ``alpha`` and
    ``beta`` (both initialised to zero) through ``self.transform``.
    """

    def __init__(self):
        super().__init__()
        # Softplus maps the unconstrained raw parameters to positive
        # exponents.
        self.transform = nn.Softplus()
        self.alpha = nn.Parameter(torch.zeros(1))
        self.beta = nn.Parameter(torch.zeros(1))
        # Inputs are clamped to [eps, 1 - eps] before warping.
        self.eps = 1e-6

    def forward(self, x):
        """Clamp ``x`` to ``[eps, 1 - eps]`` and apply the warping.

        Args:
            x: Input tensor; values outside ``[eps, 1 - eps]`` are clamped.

        Returns:
            Warped tensor with the same shape as ``x``.
        """
        lower, upper = self.eps, 1 - self.eps
        clamped = torch.clamp(x, lower, upper)
        a = self.transform(self.alpha)
        b = self.transform(self.beta)

        complement = 1 - clamped ** a
        return 1 - complement ** b


def create_wrapper(wrapper, config: Optional[Dict] = None):
    """Build a wrapper module from its string name.

    Args:
        wrapper: One of ``'identity'``, ``'kumar'`` or ``'mlp'``.
        config: Settings for the wrapper. Only read for ``'mlp'``, where it
            must provide the keys ``'in_features'`` and ``'hidden_features'``.

    Returns:
        A new ``IdentityWrapper``, ``KumarWrapper`` or ``MLPWrapper``.

    Raises:
        TypeError: If ``wrapper == 'mlp'`` and ``config`` is ``None``.
        KeyError: If ``config`` lacks a key required by ``'mlp'``.
        NotImplementedError: If ``wrapper`` is not a recognised name.
    """
    if wrapper == 'identity':
        return IdentityWrapper()
    if wrapper == 'kumar':
        return KumarWrapper()
    if wrapper == 'mlp':
        if config is None:
            # Same exception type that subscripting None would raise, but
            # with a message that says what is actually missing.
            raise TypeError(
                "create_wrapper('mlp', config) requires a config dict with "
                "'in_features' and 'hidden_features'; got None"
            )
        return MLPWrapper(config['in_features'], config['hidden_features'])
    raise NotImplementedError(
        f"Unknown wrapper {wrapper!r}; expected 'identity', 'kumar' or 'mlp'"
    )


class WrapperMean(Mean):
    """Mean function computed as ``final_layer(wrapper(x))``.

    Args:
        wrapper: Module applied to the raw input first.
        final_layer: Module applied to the wrapper's output.
    """

    def __init__(self, wrapper: nn.Module, final_layer: nn.Module):
        super().__init__()
        self.wrapper = wrapper
        self.final_layer = final_layer

    def forward(self, x):
        """Evaluate the mean at ``x``.

        Returns:
            ``final_layer(wrapper(x))`` with ``squeeze(-1)`` applied.
        """
        warped = self.wrapper(x)
        projected = self.final_layer(warped)
        # squeeze(-1) drops the trailing axis only when it has size 1.
        return projected.squeeze(-1)


class WrapperKernel(Kernel):
    """Kernel that transforms both inputs before calling ``base_kernel``.

    Each input is mapped through ``final_layer(wrapper(.))`` and the pair of
    transformed inputs is handed to ``base_kernel``.

    Args:
        base_kernel: Kernel evaluated on the transformed inputs.
        wrapper: Module applied to each raw input first.
        final_layer: Module applied to the wrapper's output.
        **kwargs: Forwarded unchanged to the ``Kernel`` constructor.
    """

    def __init__(
        self,
        base_kernel: Kernel,
        wrapper: nn.Module,
        final_layer: nn.Module,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_kernel = base_kernel
        self.wrapper = wrapper
        self.final_layer = final_layer

    def forward(self, x1, x2, **params):
        """Transform ``x1`` and ``x2``, then evaluate ``base_kernel`` on them.

        ``params`` is passed through to ``base_kernel`` untouched; the return
        value is whatever that call returns.
        """
        # x1 is transformed before x2, one after the other.
        z1, z2 = (self.final_layer(self.wrapper(t)) for t in (x1, x2))
        return self.base_kernel(z1, z2, **params)