import math
from typing import Union

import numpy as np
import torch
import torch.distributions as D
import torch.nn as nn
from torch.nn import functional as F


def check_rq_spline_input(
        inputs: torch.Tensor,
        knots_x: torch.Tensor,
        knots_y: torch.Tensor,
        knots_logd: torch.Tensor,
        bias: Union[torch.Tensor, float]
):
    """
    Checks if input to rational quadratic spline forward and inverse functions is valid.

    :param inputs: tensor with shape (n_inputs, n_dim).
    :param knots_x: tensor with shape (n_inputs, n_dim, n_knots).
    :param knots_y: tensor with shape (n_inputs, n_dim, n_knots).
    :param knots_logd: tensor with shape (n_inputs, n_dim, n_knots).
    :param bias: tensor with shape (n_inputs, n_dim), or a Python int/float added to every entry.
    :raises AssertionError: if shapes or dtypes are inconsistent.
    :raises ValueError: if bias is neither a tensor nor a Python int/float.
    """
    assert len(knots_x.shape) == 3
    n_inputs, n_dim, _ = knots_x.shape
    assert knots_x.shape == knots_y.shape == knots_logd.shape
    # One row of knots per input row: torch.searchsorted needs matching leading dimensions.
    assert inputs.shape == (n_inputs, n_dim)
    # isinstance (rather than type ==) so tensor subclasses such as nn.Parameter are accepted.
    if isinstance(bias, torch.Tensor):
        assert bias.shape == inputs.shape
    elif not isinstance(bias, (int, float)):
        raise ValueError(f"bias must be a tensor or a Python int/float, got {type(bias).__name__}")
    assert inputs.dtype == knots_x.dtype == knots_y.dtype == knots_logd.dtype


def check_rq_spline_output(
        outputs: torch.Tensor,
        logj: torch.Tensor
):
    """
    Checks if output from rational quadratic spline forward and inverse functions is valid.

    :param outputs: transformed values, expected shape (n_inputs, n_dim).
    :param logj: per-input log absolute Jacobian determinant, expected shape (n_inputs,).
    :raises AssertionError: if the ranks or the number of rows do not match.
    """

    assert len(outputs.shape) == 2
    assert len(logj.shape) == 1
    assert outputs.shape[0] == logj.shape[0]


def rational_quadratic_spline_forward(
        inputs: torch.Tensor,
        knots_x: torch.Tensor,
        knots_y: torch.Tensor,
        knots_logd: torch.Tensor,
        bias: Union[torch.Tensor, float] = 0.0
):
    """
    Compute the forward value of a rational quadratic spline, applied elementwise.

    Entry inputs[i, j] is transformed by the spline whose knots are knots_*[i, j, :] (x coordinates,
    y coordinates, log derivatives). Outside [knots_x[i, j, 0], knots_x[i, j, -1]] the spline is
    extended linearly using the boundary derivative. The bias is added after the spline.

    :param inputs: tensor with shape (n_inputs, n_dim).
    :param knots_x: tensor with shape (n_inputs, n_dim, n_knots), sorted ascending along the last axis.
    :param knots_y: tensor with shape (n_inputs, n_dim, n_knots), sorted ascending along the last axis.
    :param knots_logd: tensor with shape (n_inputs, n_dim, n_knots).
    :param bias: tensor with shape (n_inputs, n_dim), or a scalar.
    :return: (outputs, logj): outputs with shape (n_inputs, n_dim) and logj with shape (n_inputs,),
        the log absolute derivative summed over the n_dim dimensions.
    """
    check_rq_spline_input(inputs, knots_x, knots_y, knots_logd, bias)

    knots_d = torch.exp(knots_logd)
    n_knots = knots_x.shape[-1]

    # Arrays where the output values and per-entry log derivatives are held
    outputs = torch.zeros_like(inputs)
    logj = torch.zeros_like(inputs)

    # index[i, j] = k such that knots_x[i, j, k - 1] < inputs[i, j] <= knots_x[i, j, k]
    index = torch.searchsorted(knots_x, inputs.unsqueeze(2).contiguous()).squeeze(2)

    # Linear extrapolation (left): inputs at or below the first knot.
    # (rows, cols) address both the input row and the dimension, so each entry uses its own knots.
    left = index == 0
    rows, cols = torch.nonzero(left, as_tuple=True)
    outputs[rows, cols] = knots_y[rows, cols, 0] + (
            inputs[rows, cols] - knots_x[rows, cols, 0]
    ) * knots_d[rows, cols, 0]
    logj[rows, cols] = knots_logd[rows, cols, 0]

    # Linear extrapolation (right): inputs above the last knot.
    right = index == n_knots
    rows, cols = torch.nonzero(right, as_tuple=True)
    outputs[rows, cols] = knots_y[rows, cols, -1] + (
            inputs[rows, cols] - knots_x[rows, cols, -1]
    ) * knots_d[rows, cols, -1]
    logj[rows, cols] = knots_logd[rows, cols, -1]

    # Rational quadratic segment between knots k - 1 and k
    rows, cols = torch.nonzero(~(left | right), as_tuple=True)
    k = index[rows, cols]
    x_lo, x_hi = knots_x[rows, cols, k - 1], knots_x[rows, cols, k]
    y_lo, y_hi = knots_y[rows, cols, k - 1], knots_y[rows, cols, k]
    d_lo, d_hi = knots_d[rows, cols, k - 1], knots_d[rows, cols, k]
    width = x_hi - x_lo
    height = y_hi - y_lo

    xi = (inputs[rows, cols] - x_lo) / width
    s = height / width
    xi1_xi = xi * (1 - xi)
    xi2 = xi ** 2
    denominator = s + (d_hi + d_lo - 2 * s) * xi1_xi

    outputs[rows, cols] = y_lo + height * (s * xi2 + d_lo * xi1_xi) / denominator
    logj[rows, cols] = 2 * torch.log(s) + torch.log(
        d_hi * xi2 + 2 * s * xi1_xi + d_lo * (1 - xi) ** 2
    ) - 2 * torch.log(denominator)

    outputs = outputs + bias
    logj = torch.sum(logj, dim=1)

    check_rq_spline_output(outputs, logj)

    return outputs, logj


def rational_quadratic_spline_inverse(
        inputs: torch.Tensor,
        knots_x: torch.Tensor,
        knots_y: torch.Tensor,
        knots_logd: torch.Tensor,
        bias: Union[torch.Tensor, float] = 0.0
):
    """
    Compute the inverse value of a rational quadratic spline, applied elementwise.

    The bias is subtracted first, then each entry is mapped back through the spline whose knots are
    knots_*[i, j, :]. Outside [knots_y[i, j, 0], knots_y[i, j, -1]] the inverse is linear, using the
    boundary derivative. The caller's `inputs` tensor is not modified.

    :param inputs: tensor with shape (n_inputs, n_dim).
    :param knots_x: tensor with shape (n_inputs, n_dim, n_knots), sorted ascending along the last axis.
    :param knots_y: tensor with shape (n_inputs, n_dim, n_knots), sorted ascending along the last axis.
    :param knots_logd: tensor with shape (n_inputs, n_dim, n_knots).
    :param bias: tensor with shape (n_inputs, n_dim), or a scalar.
    :return: (outputs, logj): outputs with shape (n_inputs, n_dim) and logj with shape (n_inputs,),
        the log absolute derivative of the inverse map (negated forward log derivative) summed over
        dimensions.
    :raises AssertionError: if the quadratic discriminant is negative for some entry.
    """
    check_rq_spline_input(inputs, knots_x, knots_y, knots_logd, bias)

    knots_d = torch.exp(knots_logd)
    # Out of place: an in-place `inputs -= bias` would silently modify the caller's tensor.
    inputs = inputs - bias
    n_knots = knots_x.shape[-1]

    # Arrays where the output values and per-entry log derivatives are held
    outputs = torch.zeros_like(inputs)
    logj = torch.zeros_like(inputs)

    # index[i, j] = k such that knots_y[i, j, k - 1] < inputs[i, j] <= knots_y[i, j, k]
    index = torch.searchsorted(knots_y, inputs.unsqueeze(2).contiguous()).squeeze(2)

    # Linear extrapolation (left)
    left = index == 0
    rows, cols = torch.nonzero(left, as_tuple=True)
    outputs[rows, cols] = knots_x[rows, cols, 0] + (
            inputs[rows, cols] - knots_y[rows, cols, 0]
    ) / knots_d[rows, cols, 0]
    logj[rows, cols] = knots_logd[rows, cols, 0]

    # Linear extrapolation (right)
    right = index == n_knots
    rows, cols = torch.nonzero(right, as_tuple=True)
    outputs[rows, cols] = knots_x[rows, cols, -1] + (
            inputs[rows, cols] - knots_y[rows, cols, -1]
    ) / knots_d[rows, cols, -1]
    logj[rows, cols] = knots_logd[rows, cols, -1]

    # Rational quadratic segment between knots k - 1 and k
    rows, cols = torch.nonzero(~(left | right), as_tuple=True)
    k = index[rows, cols]
    x_lo, x_hi = knots_x[rows, cols, k - 1], knots_x[rows, cols, k]
    y_lo, y_hi = knots_y[rows, cols, k - 1], knots_y[rows, cols, k]
    d_lo, d_hi = knots_d[rows, cols, k - 1], knots_d[rows, cols, k]
    width = x_hi - x_lo
    delta_y = y_hi - y_lo
    s = delta_y / width
    delta_2s = d_hi + d_lo - 2 * s
    dy = inputs[rows, cols] - y_lo
    delta_y_delta_2s = dy * delta_2s

    # Solve a * xi^2 + b * xi + c = 0 for the relative position xi inside the segment.
    a = delta_y * (s - d_lo) + delta_y_delta_2s
    b = delta_y * d_lo - delta_y_delta_2s
    c = -s * dy
    discriminant = b.pow(2) - 4 * a * c
    # NOTE: rounding can push this slightly below zero for extreme derivatives; it is not clamped.
    assert (discriminant >= 0).all()
    # Equivalent form of the quadratic root that does not divide by a (a can be close to 0).
    xi = -2 * c / (b + torch.sqrt(discriminant))
    xi1_xi = xi * (1 - xi)

    outputs[rows, cols] = xi * width + x_lo
    logj[rows, cols] = 2 * torch.log(s) + torch.log(
        d_hi * xi ** 2 + 2 * s * xi1_xi + d_lo * (1 - xi) ** 2
    ) - 2 * torch.log(s + delta_2s * xi1_xi)

    # logj holds forward log derivatives at the recovered points; negate for the inverse map.
    logj = -torch.sum(logj, dim=1)

    check_rq_spline_output(outputs, logj)

    return outputs, logj


class RQSpline(nn.Module):
    """
    Elementwise rational quadratic spline layer with learnable knots shared by all inputs.

    Each of the n_dim dimensions has n_knots width logits, n_knots height logits and n_knots - 2
    interior log-derivatives; the two boundary log-derivatives are pinned to 0 by zero padding.
    Knot x coordinates are x0 plus the cumulative softmax-normalised widths scaled by `width`
    (so x0 itself is not a knot and the last knot lies at x0 + width); y coordinates are built the
    same way from y0 and `height`. The bias of this layer is always zero.
    """

    def __init__(self, n_dim: int, n_knots=3):
        super().__init__()

        self.n_knots = n_knots
        self.n_dim = n_dim

        span = 25.0
        half_span = span / 2
        log_span = math.log(span)

        # Registration order below determines the state_dict / parameters() order.
        self.x0 = nn.Parameter(torch.tensor([-half_span]))
        self.log_width = nn.Parameter(torch.tensor([log_span]))  # TODO: initialize from the data range
        self.y0 = nn.Parameter(torch.tensor([-half_span]))
        self.log_height = nn.Parameter(torch.tensor([log_span]))  # TODO: initialize from the data range

        self.base_logw = nn.Parameter(torch.zeros(n_dim, n_knots))
        self.base_logh = nn.Parameter(torch.zeros(n_dim, n_knots))
        self.base_logd = nn.Parameter(torch.zeros(n_dim, n_knots - 2))

    def regularization(self):
        """Penalty: log of the knot x-range plus the sum of squared interior log-derivatives."""
        logd_penalty = self.base_logd.square().sum()
        return self.log_width + logd_penalty

    @property
    def width(self):
        """Total span covered by the knot x coordinates."""
        return self.log_width.exp()

    @property
    def height(self):
        """Total span covered by the knot y coordinates."""
        return self.log_height.exp()

    def get_parameters(self, *args):
        """
        Build the spline parameters for a batch.

        :param args: a single 2-D tensor `inputs` with n_inputs rows.
        :return: (knots_x, knots_y, knots_logd, bias); the knots have shape (n_inputs, n_dim, n_knots)
            and are identical for every row, bias is zeros shaped like inputs.
        """
        (inputs,) = args
        n_rows, _ = inputs.shape
        tiling = (n_rows, 1, 1)

        # Zero-pad the last axis on both sides: boundary log-derivatives are fixed at 0.
        knot_logd = F.pad(self.base_logd, (1, 1, 0, 0)).repeat(*tiling)

        knot_x = self.x0 + F.softmax(self.base_logw.repeat(*tiling), dim=-1).mul(self.width).cumsum(dim=-1)
        knot_y = self.y0 + F.softmax(self.base_logh.repeat(*tiling), dim=-1).mul(self.height).cumsum(dim=-1)

        return knot_x, knot_y, knot_logd, torch.zeros_like(inputs)

    def forward(self, *args):
        """Apply the spline to args[0] (all args go to get_parameters); returns (outputs, logj)."""
        knots_x, knots_y, knots_logd, bias = self.get_parameters(*args)
        return rational_quadratic_spline_forward(args[0], knots_x, knots_y, knots_logd, bias)

    def inverse(self, *args):
        """Invert the spline on args[0] (all args go to get_parameters); returns (outputs, logj)."""
        knots_x, knots_y, knots_logd, bias = self.get_parameters(*args)
        return rational_quadratic_spline_inverse(args[0], knots_x, knots_y, knots_logd, bias)


class CustomDropout(nn.Module):
    """
    Dropout with a persistent mask: the mask drawn on the first call is reused by later calls
    until `reset_mask` is called, so repeated passes drop the same units.

    Unlike nn.Dropout, kept values are not rescaled by 1 / (1 - p), and the mask is applied
    regardless of train/eval mode.
    """

    def __init__(self, drop_prob: float = 0.25):
        super().__init__()
        self.mask = None
        assert 0 <= drop_prob <= 1
        self.p = drop_prob

    def forward(self, x):
        # Draw a new mask when none exists or the stored one no longer fits x (different batch
        # shape or device); reusing a stale mask would raise or silently broadcast.
        if self.mask is None or self.mask.shape != x.shape or self.mask.device != x.device:
            self.mask = (torch.rand_like(x) > self.p).type(x.dtype)
        return x * self.mask

    def reset_mask(self):
        """Discard the stored mask so the next forward call draws a fresh one."""
        self.mask = None


class ConditionalRQSpline(RQSpline):
    """
    RQSpline whose bias is predicted from conditional inputs by `conditional_network`.
    The knots are built exactly as in RQSpline and do not depend on the conditional inputs.
    """

    def __init__(self, n_dim, n_dim_cond, n_knots=3, n_hidden=100, n_layers=2):
        """
        :param n_dim: number of transformed dimensions.
        :param n_dim_cond: number of conditional input features.
        :param n_knots: number of knots per dimension.
        :param n_hidden: currently unused (the conditional network is a single linear layer).
        :param n_layers: currently unused apart from the check below.
        """
        super().__init__(n_dim=n_dim, n_knots=n_knots)
        assert n_layers >= 2

        # The conditional network predicts one bias value per dimension.
        n_parameters = n_dim
        self.conditional_network = nn.Linear(n_dim_cond, n_parameters)

    def get_parameters(self, *args):
        """
        :param args: (inputs, conditional_inputs); inputs is 2-D with one row per conditional input.
        :return: (knots_x, knots_y, knots_logd, bias) with bias = conditional_network(conditional_inputs).
        """
        inputs, conditional_inputs = args[0], args[1]
        # Reuse the unconditional knot construction and replace its zero bias.
        knots_x, knots_y, knots_logd, _ = super().get_parameters(inputs)
        bias = self.conditional_network(conditional_inputs)
        return knots_x, knots_y, knots_logd, bias

    def reset_dropout_masks(self):
        """Reset every CustomDropout inside the conditional network (none with the current nn.Linear)."""
        for module in self.conditional_network.modules():
            if isinstance(module, CustomDropout):
                module.reset_mask()


class SplineFlow(nn.Module):
    """Normalizing flow: a stack of RQSpline layers with a standard normal base distribution."""

    def __init__(self, n_dim, n_splines=5, n_knots=3):
        super().__init__()
        self.n_dim = n_dim
        self.splines = nn.ModuleList([RQSpline(n_dim=n_dim, n_knots=n_knots) for _ in range(n_splines)])

    def _base_tensor_options(self):
        """Device and dtype for base samples: those of the parameters, else CPU and the default dtype."""
        param = next(self.parameters(), None)
        if param is None:
            return torch.device('cpu'), torch.get_default_dtype()
        return param.device, param.dtype

    def regularization(self):
        """Mean of the per-spline regularization terms."""
        total = 0.0
        for spline in self.splines:
            total += spline.regularization()
        return total / len(self.splines)

    def forward(self, x):
        """
        Map x through every spline in order.

        :param x: tensor with shape (n, n_dim).
        :return: (z, logj) where logj (shape (n,)) is the summed forward log-Jacobian.
        """
        # Accumulate in x's dtype so float64 flows do not lose precision.
        logj = torch.zeros(len(x), device=x.device, dtype=x.dtype)
        for spline in self.splines:
            x, logj0 = spline.forward(x)
            logj += logj0
        return x, logj

    def inverse(self, z):
        """
        Map z back through the splines in reverse order.

        :return: (x, logj). NOTE: each RQSpline.inverse returns the log-Jacobian of the inverse map;
            their sum is negated here, so logj is the forward log-Jacobian evaluated at the returned x.
        """
        logj = torch.zeros(len(z), device=z.device, dtype=z.dtype)
        for spline in self.splines[::-1]:
            z, logj0 = spline.inverse(z)
            logj += logj0
        return z, -logj

    def log_prob(self, x):
        """Log density of x: standard normal log density of forward(x) plus the forward log-Jacobian."""
        z, logj = self.forward(x)
        logp = D.Normal(0, 1).log_prob(z).sum(dim=1)
        log_prob = logj + logp
        return log_prob

    def sample(self, n):
        """Draw n samples by pushing standard normal noise through the inverse flow."""
        device, dtype = self._base_tensor_options()
        # Noise on the model's device/dtype (for a CPU float32 model this equals D.Normal(0.0, 1.0)).
        base = D.Normal(
            torch.zeros((), device=device, dtype=dtype),
            torch.ones((), device=device, dtype=dtype),
        )
        z = base.sample((n, self.n_dim))
        x = self.inverse(z)[0]
        return x


class ConditionalSplineFlow(nn.Module):
    """
    Conditional normalizing flow: a stack of ConditionalRQSpline layers whose biases depend on
    `conditional_inputs`, with a standard normal base distribution.
    """

    def __init__(self, n_dim, n_dim_cond, n_splines=5, n_knots=3):
        super().__init__()
        self.n_dim = n_dim
        self.n_dim_cond = n_dim_cond
        self.splines = nn.ModuleList([
            ConditionalRQSpline(
                n_dim=n_dim, n_dim_cond=n_dim_cond, n_knots=n_knots
            ) for _ in range(n_splines)
        ])

    def _base_tensor_options(self, fallback_device):
        """Device and dtype for base samples: those of the parameters, else fallback_device and default dtype."""
        param = next(self.parameters(), None)
        if param is None:
            return fallback_device, torch.get_default_dtype()
        return param.device, param.dtype

    def reset_dropout_masks(self):
        """Reset the CustomDropout masks of every conditional spline."""
        for spline in self.splines:
            if isinstance(spline, ConditionalRQSpline):
                spline.reset_dropout_masks()

    def regularization(self):
        """Sum (not mean) of the per-spline regularization terms."""
        total = 0.0
        for spline in self.splines:
            total += spline.regularization()
        return total

    def forward(self, x, conditional_inputs):
        """
        Map x through every spline in order.

        :param x: tensor with shape (n, n_dim).
        :param conditional_inputs: conditioning tensor with n rows.
        :return: (z, logj) where logj (shape (n,)) is the summed forward log-Jacobian.
        """
        # Accumulate in x's dtype so float64 flows do not lose precision.
        logj = torch.zeros(len(x), device=x.device, dtype=x.dtype)
        for spline in self.splines:
            x, logj0 = spline.forward(x, conditional_inputs)
            logj += logj0
        return x, logj

    def inverse(self, z, conditional_inputs):
        """
        Map z back through the splines in reverse order.

        :return: (x, logj). NOTE: each spline's inverse returns the log-Jacobian of the inverse map;
            their sum is negated here, so logj is the forward log-Jacobian evaluated at the returned x.
        """
        logj = torch.zeros(len(z), device=z.device, dtype=z.dtype)
        for spline in self.splines[::-1]:
            z, logj0 = spline.inverse(z, conditional_inputs)
            logj += logj0
        return z, -logj

    def log_prob(self, x, conditional_inputs):
        """Conditional log density of x: base log density of forward(x) plus the forward log-Jacobian."""
        z, logj = self.forward(x, conditional_inputs)
        logp = D.Normal(0, 1).log_prob(z).sum(dim=1)
        log_prob = logj + logp
        return log_prob

    def sample(self, conditional_inputs):
        """Draw one sample per row of conditional_inputs through the inverse flow."""
        n = len(conditional_inputs)
        device, dtype = self._base_tensor_options(conditional_inputs.device)
        # Noise on the model's device/dtype (for a CPU float32 model this equals D.Normal(0.0, 1.0)).
        base = D.Normal(
            torch.zeros((), device=device, dtype=dtype),
            torch.ones((), device=device, dtype=dtype),
        )
        z = base.sample((n, self.n_dim))
        x = self.inverse(z, conditional_inputs)[0]
        return x


if __name__ == '__main__':
    # Reconstruction accuracy test: map a grid of points through the spline and back for a range
    # of scalings of the knot log-derivatives, then plot the reconstruction errors.
    import matplotlib.pyplot as plt

    torch.manual_seed(0)

    logd_offset = 0

    n_train = 1000
    x_tensor = torch.linspace(-8, 8, n_train).view(n_train, 1)
    knots_x_tensor = torch.arange(-5, 6).float().repeat(n_train, 1).unsqueeze(1)
    knots_y_tensor = torch.arange(-5, 6).float().repeat(n_train, 1).unsqueeze(1)
    knots_logd_tensor = torch.rand(knots_x_tensor.shape[2]).float().repeat(n_train, 1).unsqueeze(1) + logd_offset
    # Pin the boundary log-derivatives to 0 (slope 1).
    knots_logd_tensor[..., 0] = 0
    knots_logd_tensor[..., -1] = 0

    def run_round_trip(scaling):
        """Forward spline on x_tensor, then its inverse; returns (z, logj_forward, x_reconstructed, logj_inverse)."""
        scaled_logd = knots_logd_tensor * scaling
        z, logj_fwd = rational_quadratic_spline_forward(x_tensor, knots_x_tensor, knots_y_tensor, scaled_logd)
        x_rec, logj_inv = rational_quadratic_spline_inverse(z, knots_x_tensor, knots_y_tensor, scaled_logd)
        return z, logj_fwd, x_rec, logj_inv

    @torch.no_grad()
    def compute_reconstruction_error(scaling: float = 1.0):
        """Return (MAE, max abs error, mean forward logj, mean inverse logj) as floats for one scaling."""
        _, logj_forward, x_reconstructed, logj_inverse = run_round_trip(scaling)
        abs_error = (x_tensor - x_reconstructed).abs()
        return (
            float(torch.nanmean(abs_error)),
            float(torch.max(abs_error)),
            float(torch.nanmean(logj_forward)),
            float(torch.nanmean(logj_inverse)),
        )

    scales = []
    mae_list = []
    max_norm_list = []
    logj_forward_list = []
    logj_inverse_list = []
    for scale in np.arange(-30, 30, 3):
        try:
            mae, max_norm, logj_forward, logj_inverse = compute_reconstruction_error(scale)
        except (ValueError, AssertionError) as e:
            # Extreme scalings can break the inverse (e.g. a negative discriminant); report and skip.
            print(f'Skipping logd scale {scale}: {type(e).__name__} {e}')
            continue
        # Keep only scales with finite errors (excludes both inf and NaN).
        if np.isfinite(mae) and np.isfinite(max_norm):
            mae_list.append(mae)
            max_norm_list.append(max_norm)
            logj_forward_list.append(logj_forward)
            logj_inverse_list.append(logj_inverse)
            scales.append(scale)

    plt.plot(scales, mae_list, '-', label='MAE')
    plt.plot(scales, max_norm_list, '-', label='Max norm')
    # Uncomment to also show the forward log-Jacobian:
    # plt.plot(scales, logj_forward_list, '-', label='logj forward')
    plt.plot(scales, logj_inverse_list, '-', label='logj inverse')
    plt.xlabel('logd scale')
    plt.ylabel('Error')
    plt.yscale('log')
    plt.legend()
    plt.show()

    c0 = 'tab:blue'
    c1 = 'tab:orange'
    for scaling in scales:
        z_tensor, _, x_reconstructed, _ = run_round_trip(scaling)

        errors = torch.abs(x_tensor - x_reconstructed).flatten().detach()
        # Marker size grows with the reconstruction error at each point.
        sizes = errors / (errors.min() + 1e-2)
        sizes *= 100
        print(float(knots_logd_tensor.min() * scaling), float(knots_logd_tensor.max() * scaling))
        plt.plot(x_tensor.flatten(), z_tensor.flatten(), label='Forward spline', color=c0)
        plt.scatter(x_tensor.flatten(), z_tensor.flatten(), color=c0, s=sizes.numpy())
        plt.plot(x_reconstructed.flatten(), z_tensor.flatten(), label='Forward spline (reconstruction)', color=c1)
        plt.legend()
        plt.show()
