from src.models.causal_experiment_model import CausalExperimentModel
import torch
import torch.distributions as dist
import numpy as np
from src.models.graph_priors import ErdosRenyi, PermuteGenerate, ScaleFree
from src.models.utils import InverseGamma, compute_rff


class NonLinGaussANMModel(CausalExperimentModel):
    """Prior and likelihood of a nonlinear additive-noise causal model.

    Mechanisms are parameterised with random Fourier features (RFF); the
    actual structural equations live in ``NonLinearANSEM`` /
    ``NonLinearANSEMShift``, which this class instantiates from a sampled
    parameter dictionary ``theta`` and a ``design`` tensor.

    Args:
        d: Number of nodes; forwarded to the graph prior as ``num_nodes``.
        n_parallel: Number of parallel instances; forwarded to the graph
            prior and used as the leading dimension of the hyper-parameter
            tensors created in :meth:`reset`.
        graph_prior: One of ``"erdos_renyi"``, ``"scale_free"`` or
            ``"permute_generate"``.
        graph_args: Extra keyword arguments forwarded to the graph prior.
            ``None`` (the default) means no extra arguments.
        noise_type: One of ``"gauss"``, ``"laplace"`` or ``"gumbel"``.
        noise_args: Stored on the instance; not read by this class.
        lengthscale: Upper bound of the ``Uniform(0.1, lengthscale)``
            distribution the RFF lengthscales are drawn from.
        intervention_type: ``"shift"`` selects ``NonLinearANSEMShift``; any
            other value selects ``NonLinearANSEM``.
        num_rff: Number of random Fourier features per node.

    Raises:
        KeyError: If ``graph_prior`` or ``noise_type`` is not a known key.
    """

    def __init__(
        self,
        d=2,
        n_parallel=1,
        graph_prior="erdos_renyi",
        graph_args=None,
        noise_type="gauss",
        noise_args=None,
        lengthscale=1.0,
        intervention_type="shift",
        num_rff=100,
    ):
        super().__init__()
        noise_types = {
            "gauss": dist.Normal,
            "laplace": dist.Laplace,
            "gumbel": dist.Gumbel,
        }
        graph_priors = {
            "erdos_renyi": ErdosRenyi,
            "scale_free": ScaleFree,
            "permute_generate": PermuteGenerate,
        }
        self.graph_prior_init = graph_priors[graph_prior]
        self.d = d
        # A design concatenates d mask entries and d intervention values.
        self.var_dim = 2 * d
        self.n_parallel = n_parallel
        # ``None`` defaults avoid sharing one dict object across instances.
        self.graph_args = {} if graph_args is None else graph_args
        self.lengthscale = lengthscale
        self.num_rff = num_rff
        self.noise_args = {} if noise_args is None else noise_args
        self.noise_dist = noise_types[noise_type]
        self.noise_type = noise_type
        if intervention_type == "shift":
            self.intervention_type = NonLinearANSEMShift
        else:
            self.intervention_type = NonLinearANSEM
        # NOTE(review): sample_prior also returns "b", which is not listed
        # here - confirm whether that omission is intentional.
        self.var_names = [
            "graph",
            "omega",
            "w",
            "noise_scales",
            "bias",
        ]
        self.reset(n_parallel)
        self.sanity_check()

    def sample_prior(self, num_theta, n_parallel=None, zero_bias=False):
        """Draw ``num_theta`` parameter sets from the prior.

        Args:
            num_theta: Number of parameter sets to draw.
            n_parallel: If truthy, :meth:`reset` is called with this value
                before sampling.
            zero_bias: ``True`` gives an all-zero bias, ``False`` draws
                every bias from ``Uniform(-1, 1)``, and ``None`` decides at
                random (per leading index) whether the bias is drawn or
                left at zero.

        Returns:
            dict with keys ``"graph"``, ``"omega"``, ``"w"``,
            ``"noise_scales"``, ``"bias"`` and ``"b"``.
        """
        if n_parallel:
            self.reset(n_parallel)
        full_graph = self.graph_prior(num_theta)
        ls = dist.Uniform(low=0.1, high=self.ls).sample((num_theta,))
        omega = (
            dist.Normal(loc=0, scale=1.0 / ls)
            .sample((self.num_rff, self.d))
            .transpose(0, 2)
        )
        rff_shape = (num_theta, self.n_parallel, 1, self.d, self.num_rff)
        w = dist.Normal(loc=0, scale=1.0).sample(rff_shape)
        b = dist.Uniform(low=0.0, high=2 * np.pi).sample(rff_shape)
        # The inverse-gamma draw is square-rooted before being used as a scale.
        noise_scales = torch.sqrt(
            InverseGamma(
                concentration=self.noise_scale_concentration, rate=self.noise_scale_rate
            ).sample((num_theta,))
        )
        flag_shape = noise_scales.shape[:-1]
        # The three cases must be mutually exclusive: ``not None`` is True,
        # so two independent ``if`` statements would always overwrite the
        # random flag with all-ones.
        if zero_bias is None:
            bias_flag = torch.randint(2, size=flag_shape)
        elif zero_bias:
            bias_flag = torch.zeros(flag_shape)
        else:
            bias_flag = torch.ones(flag_shape)
        bias_flag = bias_flag.to(torch.bool)
        bias_uniform = dist.Uniform(low=-1.0, high=1.0).sample(noise_scales.shape)
        bias = torch.zeros_like(noise_scales)
        bias[bias_flag] = bias_uniform[bias_flag]
        return {
            "graph": full_graph,
            "omega": omega,
            "w": w,
            "noise_scales": noise_scales,
            "bias": bias,
            "b": b,
        }

    def _make_sem(self, design, theta):
        """Build the structural equation model for ``design`` and ``theta``."""
        return self.intervention_type(
            theta["graph"],
            theta["w"],
            theta["omega"],
            theta["b"],
            self.noise_dist(loc=theta["bias"], scale=theta["noise_scales"]),
            # First d design entries: a node is intervened on iff its entry > 0.
            (design[..., : self.d] > 0).to(design.dtype),
            # Remaining design entries: the intervention values.
            design[..., self.d :],
            self.num_rff,
        )

    def rsample(self, design, theta, n_samples=1):
        """Draw ``n_samples`` reparameterised samples under ``design``."""
        return self._make_sem(design, theta).rsample((n_samples,))

    def log_prob(self, y, design, theta):
        """Return the log-likelihood of ``y`` under ``design`` and ``theta``."""
        return self._make_sem(design, theta).log_prob(y)

    def reset(self, n_parallel):
        """Rebuild the graph prior and hyper-parameters for ``n_parallel``."""
        self.n_parallel = n_parallel
        self.graph_prior = self.graph_prior_init(
            **{**self.graph_args, "n_parallel": n_parallel, "num_nodes": self.d}
        )
        self.noise_scale_concentration = 10.0 * torch.ones(n_parallel, 1, self.d)
        self.noise_scale_rate = torch.ones(n_parallel, 1, self.d)
        self.ls = self.lengthscale * torch.ones(n_parallel, 1, self.d)


class NonLinearANSEM:
    """
    Nonlinear additive-noise structural equation model built on RFF features.

    Each mechanism is evaluated through ``compute_rff`` and exogenous noise is
    added on top. Nodes selected by ``intervention_mask`` are overwritten with
    ``intervention_values`` instead of following their mechanism.
    """

    has_rsample = True

    def __init__(
        self,
        graph,
        w,
        omega,
        b,
        exogenous_noise_dist,
        intervention_mask,
        intervention_values,
        num_rff=100,
    ):
        # A 5-dimensional omega signals that the RFF parameters need one
        # extra leading axis; w and b are expanded together with it.
        if omega.ndim == 5:
            omega, w, b = (param.unsqueeze(0) for param in (omega, w, b))
        self.graph = graph
        self.omega = omega
        self.w = w
        self.b = b
        self.exogenous_noise_dist = exogenous_noise_dist
        self.intervention_mask = intervention_mask
        self.intervention_values = intervention_values
        self.num_rff = num_rff
        self.num_nodes = graph.shape[-1]

    def rsample(self, sample_shape=torch.Size([])):
        """Draw reparameterised samples by iterating the mechanisms."""
        noise = self.exogenous_noise_dist.rsample(sample_shape)
        noise = noise.squeeze(-2).transpose(0, 1)
        clamped = self.intervention_mask * self.intervention_values
        free = 1 - self.intervention_mask
        sample = (
            torch.ones_like(noise) * self.intervention_mask * self.intervention_values
        )
        # One pass per node; intervened nodes are re-clamped after every pass.
        for _ in range(self.num_nodes):
            mechanism = compute_rff(
                graph=self.graph, omega=self.omega, y=sample, w=self.w, b=self.b
            )
            sample = clamped + free * (mechanism + noise)
        return sample

    def log_prob(self, y):
        """Log-density of ``y``, summed over the non-intervened nodes."""
        free = 1 - self.intervention_mask
        mechanism = compute_rff(
            graph=self.graph, omega=self.omega, y=y, w=self.w, b=self.b
        )
        predicted = self.intervention_mask * self.intervention_values + free * mechanism
        residual = y - predicted
        node_log_prob = self.exogenous_noise_dist.log_prob(residual)
        return (node_log_prob * free).sum(-1)


class NonLinearANSEMShift(NonLinearANSEM):
    """
    Shift-intervention variant of the RFF-based additive-noise SEM.

    Rather than overwriting intervened nodes, the masked intervention values
    are added to the exogenous noise, so every node still follows its
    mechanism.
    """

    has_rsample = True

    def __init__(
        self,
        graph,
        w,
        omega,
        b,
        exogenous_noise_dist,
        intervention_mask,
        intervention_values,
        num_rff=100,
    ):
        super().__init__(
            graph=graph,
            w=w,
            omega=omega,
            b=b,
            exogenous_noise_dist=exogenous_noise_dist,
            intervention_mask=intervention_mask,
            intervention_values=intervention_values,
            num_rff=num_rff,
        )

    def rsample(self, sample_shape=torch.Size([])):
        """Draw reparameterised samples with the shift folded into the noise."""
        noise = self.exogenous_noise_dist.rsample(sample_shape)
        noise = noise.squeeze(-2).transpose(0, 1)
        shifted_noise = noise + self.intervention_mask * self.intervention_values
        sample = (
            torch.ones_like(shifted_noise)
            * self.intervention_mask
            * self.intervention_values
        )
        # One pass per node; no clamping, the shift acts only through the noise.
        for _ in range(self.num_nodes):
            mechanism = compute_rff(
                graph=self.graph, omega=self.omega, y=sample, w=self.w, b=self.b
            )
            sample = mechanism + shifted_noise
        return sample

    def log_prob(self, y):
        """Log-density of ``y``, summed over all nodes."""
        mechanism = compute_rff(
            graph=self.graph, omega=self.omega, y=y, w=self.w, b=self.b
        )
        residual = y - mechanism
        # Remove the shift so the residual is scored as plain exogenous noise.
        residual = residual - self.intervention_mask * self.intervention_values
        return self.exogenous_noise_dist.log_prob(residual).sum(-1)


class NonLinGaussANMModelHeteroskedastic(NonLinGaussANMModel):
    """Heteroskedastic variant of :class:`NonLinGaussANMModel`.

    In addition to the mechanism parameters (``omega``, ``w``, ``b``) the
    prior draws a second set of RFF parameters (``omega_noise``, ``w_noise``,
    ``b_noise``) that ``NonLinearANSEMHk`` / ``NonLinearANSEMShiftHk`` use to
    scale the exogenous noise. The exogenous noise itself has unit scale.

    The constructor arguments are those of :class:`NonLinGaussANMModel`;
    ``intervention_type="shift"`` selects ``NonLinearANSEMShiftHk`` and any
    other value selects ``NonLinearANSEMHk``.
    """

    def __init__(
        self,
        d=2,
        n_parallel=1,
        graph_prior="erdos_renyi",
        graph_args=None,
        noise_type="gauss",
        noise_args=None,
        lengthscale=1.0,
        intervention_type="shift",
        num_rff=100,
    ):
        # Normalise here (not only in the parent) so the parent always
        # receives real dicts and no dict object is shared across instances.
        graph_args = {} if graph_args is None else graph_args
        noise_args = {} if noise_args is None else noise_args
        super().__init__(
            d=d,
            n_parallel=n_parallel,
            graph_prior=graph_prior,
            graph_args=graph_args,
            noise_type=noise_type,
            noise_args=noise_args,
            lengthscale=lengthscale,
            intervention_type=intervention_type,
            num_rff=num_rff,
        )

        if intervention_type == "shift":
            self.intervention_type = NonLinearANSEMShiftHk
        else:
            self.intervention_type = NonLinearANSEMHk
        # Names match the keys returned by sample_prior.
        self.var_names = [
            "graph",
            "omega",
            "w",
            "b",
            "bias",
            "omega_noise",
            "w_noise",
            "b_noise",
        ]
        self.reset(n_parallel)
        self.sanity_check()

    def sample_prior(self, num_theta, n_parallel=None, zero_bias=False):
        """Draw ``num_theta`` parameter sets from the prior.

        Args:
            num_theta: Number of parameter sets to draw.
            n_parallel: If truthy, :meth:`reset` is called with this value
                before sampling.
            zero_bias: ``True`` gives an all-zero bias, ``False`` draws
                every bias from ``Uniform(-1, 1)``, and ``None`` decides at
                random (per leading index) whether the bias is drawn or
                left at zero.

        Returns:
            dict with keys ``"graph"``, ``"omega"``, ``"w"``, ``"bias"``,
            ``"b"``, ``"omega_noise"``, ``"w_noise"`` and ``"b_noise"``.
        """
        if n_parallel:
            self.reset(n_parallel)
        full_graph = self.graph_prior(num_theta)
        ls = dist.Uniform(low=0.1, high=self.ls).sample((num_theta,))
        omega = (
            dist.Normal(loc=0, scale=1.0 / ls)
            .sample((self.num_rff, self.d))
            .transpose(0, 2)
        )
        # Parameters of the noise-scale function are N(0, 0.1) draws with
        # the same shapes as their mechanism counterparts.
        omega_noise = dist.Normal(loc=0, scale=0.1).sample(omega.shape)
        rff_shape = (num_theta, self.n_parallel, 1, self.d, self.num_rff)
        w = dist.Normal(loc=0, scale=1.0).sample(rff_shape)
        w_noise = dist.Normal(loc=0, scale=0.1).sample(w.shape)
        b = dist.Uniform(low=0.0, high=2 * np.pi).sample(rff_shape)
        b_noise = dist.Normal(loc=0, scale=0.1).sample(b.shape)
        bias_shape = (num_theta, self.n_parallel, 1, self.d)
        flag_shape = bias_shape[:-1]
        # The three cases must be mutually exclusive: ``not None`` is True,
        # so two independent ``if`` statements would always overwrite the
        # random flag with all-ones.
        if zero_bias is None:
            bias_flag = torch.randint(2, size=flag_shape)
        elif zero_bias:
            bias_flag = torch.zeros(flag_shape)
        else:
            bias_flag = torch.ones(flag_shape)
        bias_flag = bias_flag.to(torch.bool)
        bias_uniform = dist.Uniform(low=-1.0, high=1.0).sample(bias_shape)
        bias = torch.zeros(bias_shape)
        bias[bias_flag] = bias_uniform[bias_flag]
        return {
            "graph": full_graph,
            "omega": omega,
            "w": w,
            "bias": bias,
            "b": b,
            "omega_noise": omega_noise,
            "w_noise": w_noise,
            "b_noise": b_noise,
        }

    def rsample(self, design, theta, n_samples=1):
        """Draw ``n_samples`` reparameterised samples under ``design``."""
        return self.intervention_type(
            theta["graph"],
            theta["w"],
            theta["omega"],
            theta["b"],
            theta["w_noise"],
            theta["omega_noise"],
            theta["b_noise"],
            # Unit-scale noise; the SEM rescales it per sample.
            self.noise_dist(loc=theta["bias"], scale=1.0),
            # First d design entries: a node is intervened on iff its entry > 0.
            (design[..., : self.d] > 0).to(design.dtype),
            # Remaining design entries: the intervention values.
            design[..., self.d :],
            self.num_rff,
        ).rsample((n_samples,))

    def log_prob(self, y, design, theta):
        """Not implemented for the heteroskedastic model; returns ``None``."""
        return None

    def reset(self, n_parallel):
        """Rebuild the graph prior and lengthscale bound for ``n_parallel``."""
        self.n_parallel = n_parallel
        self.graph_prior = self.graph_prior_init(
            **{**self.graph_args, "n_parallel": n_parallel, "num_nodes": self.d}
        )
        self.ls = self.lengthscale * torch.ones(n_parallel, 1, self.d)


class NonLinearANSEMHk(object):
    """
    Heteroskedastic nonlinear structural equation model based on RFF features.

    The mechanism is evaluated through ``compute_rff`` with
    ``(omega, w, b)``. The exogenous noise is multiplied by a positive,
    input-dependent scale ``sqrt(softplus(g))``, where ``g`` is a second
    ``compute_rff`` evaluation with ``(omega_noise, w_noise, b_noise)``.
    Nodes selected by ``intervention_mask`` are overwritten with
    ``intervention_values``.
    """

    has_rsample = True

    def __init__(
        self,
        graph,
        w,
        omega,
        b,
        w_noise,
        omega_noise,
        b_noise,
        exogenous_noise_dist,
        intervention_mask,
        intervention_values,
        num_rff=100,
    ):
        self.exogenous_noise_dist = exogenous_noise_dist
        self.graph = graph
        self.w = w
        self.omega = omega
        self.b = b
        self.w_noise = w_noise
        self.omega_noise = omega_noise
        self.b_noise = b_noise
        # A 5-dimensional omega signals that all RFF parameters need one
        # extra leading axis.
        if len(self.omega.shape) == 5:
            self.omega = self.omega.unsqueeze(0)
            self.w = self.w.unsqueeze(0)
            self.b = self.b.unsqueeze(0)
            self.omega_noise = self.omega_noise.unsqueeze(0)
            self.w_noise = self.w_noise.unsqueeze(0)
            self.b_noise = self.b_noise.unsqueeze(0)
        self.num_nodes = self.graph.shape[-1]
        self.intervention_mask = intervention_mask
        self.intervention_values = intervention_values
        self.num_rff = num_rff

    @staticmethod
    def _noise_scale(f):
        """Return ``sqrt(softplus(f))``, the multiplier applied to the noise.

        ``softplus`` is used instead of a literal ``log(1 + exp(f))`` because
        the latter overflows to ``inf`` for large ``f`` and rounds to exactly
        zero for very negative ``f``, where ``sqrt`` has an infinite gradient.
        """
        return torch.sqrt(torch.nn.functional.softplus(f))

    def rsample(self, sample_shape=torch.Size([])):
        """Draw reparameterised samples by iterating the mechanisms."""
        z = self.exogenous_noise_dist.rsample(sample_shape).squeeze(-2).transpose(0, 1)
        sample = torch.ones_like(z) * self.intervention_mask * self.intervention_values
        # One pass per node; intervened nodes are re-clamped after every pass.
        for _ in range(self.num_nodes):
            f = self._noise_scale(
                compute_rff(
                    omega=self.omega_noise,
                    y=sample,
                    graph=self.graph,
                    b=self.b_noise,
                    w=self.w_noise,
                    c=2.0,
                )
            )
            sample = (
                compute_rff(
                    graph=self.graph, omega=self.omega, y=sample, w=self.w, b=self.b
                )
                + f * z
            )
            sample = (
                self.intervention_mask * self.intervention_values
                + (1 - self.intervention_mask) * sample
            )
        return sample

    def log_prob(self, y):
        """Not implemented for the heteroskedastic SEM; returns ``None``."""
        return None


class NonLinearANSEMShiftHk(NonLinearANSEMHk):
    """
    Shift-intervention variant of the heteroskedastic RFF-based SEM.

    The masked intervention values are added to the exogenous noise instead
    of overwriting the intervened nodes. The shifted noise is then multiplied
    by the input-dependent scale ``sqrt(softplus(g))``, where ``g`` is a
    ``compute_rff`` evaluation with ``(omega_noise, w_noise, b_noise)``.
    """

    has_rsample = True

    def __init__(
        self,
        graph,
        w,
        omega,
        b,
        w_noise,
        omega_noise,
        b_noise,
        exogenous_noise_dist,
        intervention_mask,
        intervention_values,
        num_rff=100,
    ):
        super().__init__(
            graph,
            w,
            omega,
            b,
            w_noise,
            omega_noise,
            b_noise,
            exogenous_noise_dist,
            intervention_mask,
            intervention_values,
            num_rff,
        )

    def rsample(self, sample_shape=torch.Size([])):
        """Draw reparameterised samples with the shift folded into the noise."""
        z = self.exogenous_noise_dist.rsample(sample_shape).squeeze(-2).transpose(0, 1)
        # NOTE(review): the shift is added before the noise is rescaled, so
        # the shift is multiplied by f as well - confirm this is intended.
        z = z + self.intervention_mask * self.intervention_values
        sample = torch.ones_like(z) * self.intervention_mask * self.intervention_values
        for _ in range(self.num_nodes):
            f = compute_rff(
                omega=self.omega_noise,
                y=sample,
                graph=self.graph,
                b=self.b_noise,
                w=self.w_noise,
                c=2.0,
            )
            # softplus instead of a literal log(1 + exp(f)): the latter
            # overflows to inf for large f and rounds to exactly zero for
            # very negative f, where sqrt has an infinite gradient.
            f = torch.sqrt(torch.nn.functional.softplus(f))
            sample = (
                compute_rff(
                    graph=self.graph, omega=self.omega, y=sample, w=self.w, b=self.b
                )
                + z * f
            )
        return sample

    def log_prob(self, y):
        """Not implemented for the heteroskedastic SEM; returns ``None``."""
        return None
