import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F


class ToyDataset(torch.utils.data.Dataset):
    """Synthetic 1-D regression dataset where ``x = sin(y) + noise``.

    Targets ``y`` are drawn uniformly from [-5, 5); inputs ``x`` are ``sin(y)``
    corrupted by zero-mean Gaussian noise with standard deviation ``sigma``.
    Indexing returns an ``(x, y)`` pair.
    """

    def __init__(self, N, sigma):
        # Targets are drawn before the noise, so the global RNG is consumed
        # in a fixed order (reproducible under torch.manual_seed).
        targets = 10 * torch.rand(N) - 5
        jitter = sigma * torch.randn(N)
        self.y = targets
        self.x = jitter + torch.sin(targets)

    def plot(self):
        """Scatter-plot the samples (``x`` horizontal, ``y`` vertical) and show it."""
        plt.scatter(self.x, self.y)
        plt.show()

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        sample, target = self.x[idx], self.y[idx]
        return sample, target


class HyperDiffusion(torch.nn.Module):
    """Conditional 1-D DDPM whose noise predictor's weights come from a hyper-network.

    ``self.hypernetwork`` maps a single embedding vector to one flat parameter
    vector. ``network`` unpacks that vector (following ``self.layer_shapes``)
    into the weights and biases of a small MLP that takes ``(x_t, t, condition)``
    as three input features and predicts the noise that was added to ``x_0``.

    Args:
        T: Number of diffusion steps; schedules are indexed by ``t`` in ``[0, T)``.
        embed_dim: Size of the embedding vector fed to the hyper-network.
        device: Device for the hyper-network and the schedule tensors.
        deactivate: Depth reduction for the generated MLP. ``0`` (or less) uses
            all six layers; ``1`` skips w2 and w3; ``2`` or more also skips w1
            and w4, leaving only w0 and w5.
    """

    def __init__(self, T, embed_dim, device, deactivate=0):
        super().__init__()
        self.device = device
        self.T = T
        self.deactivate = deactivate
        self.embed_dim = embed_dim

        # (weight_shape, bias_shape) for each layer of the generated MLP, in the
        # order their values are packed into the flat parameter vector.
        self.layer_shapes = [
            [(16, 3), (16,)],  # (w0, b0)
            [(32, 16), (32,)],  # (w1, b1)
            [(64, 32), (64,)],  # (w2, b2)
            [(32, 64), (32,)],  # (w3, b3)
            [(16, 32), (16,)],  # (w4, b4)
            [(1, 16), (1,)],  # (w5, b5)
        ]
        # Total parameter count of the generated MLP (5345 for the shapes
        # above). Derived so the hyper-network output size and the unpacking
        # offsets cannot drift out of sync with layer_shapes.
        self.num_params = sum(
            torch.Size(shape).numel() for layer in self.layer_shapes for shape in layer
        )

        self.hypernetwork = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, self.num_params),
        ).to(device)

        # Linear beta schedule from 0.001 to 0.2; alpha_t = 1 - beta_t.
        self.alphas = 1.0 - torch.linspace(0.001, 0.2, T).to(self.device)
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0).to(self.device)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(self.device)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(
            1 - self.sqrt_alphas_cumprod**2
        ).to(self.device)
        # abar_{t-1}, with abar_{-1} := 1 so the t=0 posterior variance is zero.
        self.alphas_cumprod_prev = torch.nn.functional.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.0
        ).to(self.device)
        # beta_t * (1 - abar_{t-1}) / (1 - abar_t)
        self.posterior_variance = (
            (1 - self.alphas)
            * (1.0 - self.alphas_cumprod_prev)
            / (1.0 - self.alphas_cumprod).to(self.device)
        )

    def _unpack_weights(self, weight):
        """Slice a flat 1-D parameter vector into per-layer ``(W, b)`` tensors.

        Values are consumed front-to-back following ``self.layer_shapes``;
        entries beyond ``self.num_params`` are ignored.
        """
        params = []
        offset = 0
        for shapes in self.layer_shapes:
            pieces = []
            for shape in shapes:
                size = torch.Size(shape).numel()
                pieces.append(weight[offset : offset + size].reshape(shape))
                offset += size
            params.append(tuple(pieces))
        return params

    def network(self, x, t, condition, weight):
        """Run the generated MLP on ``(x, t, condition)`` using parameters ``weight``.

        Args:
            x: 1-D batch of noisy samples.
            t: Diffusion step per sample (used directly as an input feature).
            condition: Conditioning value per sample.
            weight: 1-D flat parameter vector with at least ``self.num_params``
                entries, e.g. the hyper-network output for one embedding.
                A batched 2-D tensor is not supported.

        Returns:
            Predicted noise with shape ``(1, B, 1)``.
        """
        (
            (w0, b0),
            (w1, b1),
            (w2, b2),
            (w3, b3),
            (w4, b4),
            (w5, b5),
        ) = self._unpack_weights(weight)
        # Build a (B, 3) feature matrix from the three per-sample inputs.
        out = torch.stack([x, t, condition]).transpose(1, 0)
        out = F.relu(F.linear(out, w0, b0))
        if self.deactivate < 2:
            out = F.relu(F.linear(out, w1, b1))
        if self.deactivate < 1:
            out = F.relu(F.linear(out, w2, b2))
            out = F.relu(F.linear(out, w3, b3))
        if self.deactivate < 2:
            out = F.relu(F.linear(out, w4, b4))
        out = F.linear(out, w5, b5)
        # Leading singleton dimension; callers flatten()/squeeze() it away.
        return out.unsqueeze(0)

    def q_sample(self, x_0, t, noise):
        """
        Sample x at time t given the value of x at t=0 and the noise
        """
        return self.sqrt_alphas_cumprod[t] * x_0.to(
            self.device
        ) + self.sqrt_one_minus_alphas_cumprod[t] * noise.to(self.device)

    def p_loss(self, x, t, cond, embedding):
        """Noise-prediction MSE loss for one batch.

        Args:
            x: 1-D batch of clean samples ``x_0``.
            t: Integer diffusion steps (one per sample) used to index the schedule.
            cond: Conditioning values, one per sample.
            embedding: A single 1-D embedding vector shared by the whole batch.

        Returns:
            Scalar MSE between the sampled noise and the predicted noise.
        """
        # Generate a noise
        noise = torch.randn(x.shape).to(self.device)
        # Compute x at time t with this value of the noise - forward process
        noisy_x = self.q_sample(x, t, noise)
        # Use our trained model to predict the value of the noise, given x(t) and t
        weights = self.hypernetwork(embedding.to(self.device))
        noise_computed = self.network(
            noisy_x, t.to(self.device), cond.to(self.device), weights
        ).flatten()
        # Compare predicted value of the noise with the actual value
        return torch.nn.functional.mse_loss(noise, noise_computed)

    def p_sample(self, x, t, cond, embedding):
        """
        One step of the reverse process (x_t -> x_{t-1}), as in DDPM Algorithm 2.
        """
        alpha_t = self.alphas.gather(-1, t)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod.gather(
            -1, t
        )

        with torch.no_grad():
            weights = self.hypernetwork(embedding.to(self.device))
            model_out = self.network(x, t, cond, weights).squeeze()

        model_mean = torch.sqrt(1.0 / alpha_t).to(self.device) * (
            x - (1 - alpha_t) * model_out / sqrt_one_minus_alphas_cumprod_t
        )
        model_var = self.posterior_variance.gather(-1, t)

        # Samples from a normal distribution with given mean and variance
        return model_mean + torch.sqrt(model_var).to(self.device) * torch.randn_like(
            model_var
        ).to(self.device)

    def p_sample_loop(self, num_samples, cond, embedding):
        """
        Generate ``num_samples`` values by running the reverse chain from
        standard-normal noise, with the generated MLP conditioned on ``cond``.

        NOTE(review): steps run from T-1 down to 1, so the t=0 step is never
        applied. If training also samples t=0, consider range(self.T - 1, -1, -1).
        """
        x_gen = torch.randn(num_samples).to(self.device)

        for t in range(self.T - 1, 0, -1):
            t = torch.tensor(t).repeat(num_samples).long().to(self.device)
            x_gen = self.p_sample(x_gen, t, cond.to(self.device), embedding)
        return x_gen


class HyperNetwork(torch.nn.Module):
    """Per-sample hyper-network regressor: embedding -> MLP weights -> prediction.

    ``self.model`` maps each embedding row to a flat vector holding the
    parameters of a 1 -> 10 -> 10 -> 1 MLP (shapes in ``self.layer_shapes``).
    That MLP is then applied to the matching input value.

    Args:
        embed_dim: Size of each embedding row.
        device: Device on which ``self.model`` lives; embeddings are moved there.
    """

    def __init__(self, embed_dim, device):
        super().__init__()
        self.device = device
        # (weight_shape, bias_shape) for each generated-MLP layer, in packing order.
        self.layer_shapes = [
            [(10, 1), (10,)],  # (w0, b0)
            [(10, 10), (10,)],  # (w1, b1)
            [(1, 10), (1,)],  # (w2, b2)
        ]
        # 141 for the shapes above. Derived so the output size and unpacking
        # offsets always agree with layer_shapes.
        self.num_params = sum(
            torch.Size(shape).numel() for layer in self.layer_shapes for shape in layer
        )
        self.model = torch.nn.Sequential(
            torch.nn.Linear(embed_dim, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, self.num_params),
        ).to(device)

    def forward(self, y, embedding):
        """Predict one output per sample using weights generated from its embedding.

        Args:
            y: Scalar input per sample, presumably shape ``(B,)``. Each entry of
                ``y.unsqueeze(-1)`` is fed to its own generated MLP.
            embedding: Per-sample embeddings. Row ``i`` generates the weights
                applied to ``y[i]``, so the first dimension must match ``y``.

        Returns:
            The stacked per-sample outputs with the trailing singleton dim removed.
        """
        weights = self.model(embedding.to(self.device))
        return self.network(y.unsqueeze(-1), weights).squeeze(-1)

    def _unpack_weights(self, weight):
        """Slice a flat 1-D parameter vector into per-layer ``(W, b)`` tensors."""
        params = []
        offset = 0
        for shapes in self.layer_shapes:
            pieces = []
            for shape in shapes:
                size = torch.Size(shape).numel()
                pieces.append(weight[offset : offset + size].reshape(shape))
                offset += size
            params.append(tuple(pieces))
        return params

    def network(self, x, weights):
        """Apply a separately parameterised MLP to each ``x[i]``.

        Args:
            x: Inputs indexed along the first dimension; last dimension is 1
                (the input width of w0).
            weights: 2-D tensor whose row ``i`` is the flat parameter vector
                (at least ``self.num_params`` entries) used for ``x[i]``.

        Returns:
            ``torch.stack`` of the per-sample outputs.
        """
        outputs = []
        for i, weight in enumerate(weights):
            (w0, b0), (w1, b1), (w2, b2) = self._unpack_weights(weight)
            out = F.relu(F.linear(x[i], w0, b0))
            out = F.relu(F.linear(out, w1, b1))
            outputs.append(F.linear(out, w2, b2))
        return torch.stack(outputs)


class MonteCarloDropout(torch.nn.Module):
    """MLP whose dropout is applied on every forward call (Monte Carlo dropout).

    Masks are drawn explicitly with ``torch.bernoulli`` rather than through
    ``torch.nn.Dropout``, so they are applied regardless of ``train()``/``eval()``
    mode. An optional ``torch.Generator`` makes the masks reproducible.

    Args:
        channels: Width of the input's last dimension.
    """

    def __init__(self, channels=3):
        super().__init__()
        # ModuleList is the container for sub-modules (ParameterList is meant
        # for Parameters and only tolerates modules on newer PyTorch). Layers
        # are registered under "0".."5", giving state_dict keys such as
        # "layers.0.weight".
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(channels, 16),
                torch.nn.Linear(16, 32),
                torch.nn.Linear(32, 64),
                torch.nn.Linear(64, 32),
                torch.nn.Linear(32, 16),
                torch.nn.Linear(16, 1),
            ]
        )

    def forward(self, x, p=0.3, gen=None):
        """Run the MLP with inverted dropout after every hidden ReLU.

        Args:
            x: Input whose last dimension is ``channels``.
            p: Drop probability for hidden units.
            gen: Optional ``torch.Generator`` for the Bernoulli masks; it should
                live on the same device as ``x``.

        Returns:
            Output of the final linear layer (last dimension 1), with no
            activation or dropout applied to it.
        """
        keep = 1 - p
        # Inverted dropout: rescale surviving units so the expected activation
        # matches the no-dropout value.
        scale = 1.0 / keep
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i != last:
                x = F.relu(x)
                mask = torch.bernoulli(torch.ones_like(x) * keep, generator=gen)
                x = x * mask * scale
        return x


# Diffusion model with a Monte Carlo dropout noise predictor (no hyper-network)
class DropoutDDPM(torch.nn.Module):
    """Conditional 1-D DDPM whose noise predictor uses always-on dropout.

    The predictor receives ``(x_t, t, cond)`` stacked as three input features.
    Dropout masks are drawn from the generator that is passed in (``self.rng``
    during training; the seeded ``self.rng`` in ``p_sample_loop``). The initial
    sample and the per-step Gaussian noise come from the global RNG.

    Args:
        T: Number of diffusion steps; schedules are indexed by ``t`` in ``[0, T)``.
        p: Dropout probability forwarded to the predictor.
        device: Device for the predictor, the generator and the schedule tensors.
    """

    def __init__(self, T, p, device):
        super().__init__()
        self.device = device
        self.T = T
        self.p = p
        self.rng = torch.Generator(device=self.device)

        self.model = MonteCarloDropout(channels=3).to(device)

        # Linear beta schedule from 0.001 to 0.2; alpha_t = 1 - beta_t.
        betas = torch.linspace(0.001, 0.2, T).to(self.device)
        alphas = 1.0 - betas
        abar = torch.cumprod(alphas, dim=0).to(self.device)
        sqrt_abar = torch.sqrt(abar).to(self.device)
        # abar_{t-1}, with abar_{-1} := 1 so the t=0 posterior variance is zero.
        abar_prev = F.pad(abar[:-1], (1, 0), value=1.0).to(self.device)

        self.alphas = alphas
        self.alphas_cumprod = abar
        self.sqrt_alphas_cumprod = sqrt_abar
        # Mathematically sqrt(1 - abar_t).
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - sqrt_abar**2).to(
            self.device
        )
        self.alphas_cumprod_prev = abar_prev
        # beta_t * (1 - abar_{t-1}) / (1 - abar_t)
        self.posterior_variance = (
            (1 - alphas) * (1.0 - abar_prev) / (1.0 - abar).to(self.device)
        )

    def forward(self, x, t, cond, gen):
        """Predict the noise in ``x`` at steps ``t``; ``gen`` drives the dropout masks."""
        features = torch.stack((x, t, cond)).T
        return self.model(features, self.p, gen)

    def q_sample(self, x_0, t, noise):
        """Draw x_t from q(x_t | x_0) in closed form, using the supplied ``noise``."""
        signal = self.sqrt_alphas_cumprod[t] * x_0.to(self.device)
        return signal + self.sqrt_one_minus_alphas_cumprod[t] * noise.to(self.device)

    def p_loss(self, x, t, cond):
        """Noise-prediction MSE for a batch of clean samples ``x`` at steps ``t``.

        Dropout masks for this forward pass are drawn from ``self.rng``.
        """
        eps = torch.randn(x.shape).to(self.device)
        x_t = self.q_sample(x, t, eps)
        eps_hat = self.forward(
            x_t, t.to(self.device), cond.to(self.device), self.rng
        ).flatten()
        return F.mse_loss(eps, eps_hat)

    def p_sample(self, x, t, cond, gen):
        """One reverse (denoising) step x_t -> x_{t-1}, as in DDPM Algorithm 2."""
        alpha_t = self.alphas.gather(-1, t)
        sqrt_one_minus_abar_t = self.sqrt_one_minus_alphas_cumprod.gather(-1, t)

        with torch.no_grad():
            eps_hat = self.forward(x, t, cond, gen).squeeze()

        mean = torch.sqrt(1.0 / alpha_t).to(self.device) * (
            x - (1 - alpha_t) * eps_hat / sqrt_one_minus_abar_t
        )
        variance = self.posterior_variance.gather(-1, t)

        # Reparameterised draw from N(mean, variance).
        z = torch.randn_like(variance)
        return mean + torch.sqrt(variance).to(self.device) * z.to(self.device)

    def p_sample_loop(self, num_samples, cond, seed):
        """Generate ``num_samples`` values by running the reverse chain from noise.

        ``seed`` re-seeds ``self.rng`` so the sequence of dropout masks is
        reproducible.

        NOTE(review): steps run from T-1 down to 1, so the t=0 step is never
        applied. If training also samples t=0, consider including it.
        """
        sample = torch.randn(num_samples).to(self.device)

        generator = self.rng.manual_seed(seed)
        cond = cond.to(self.device)
        for step in reversed(range(1, self.T)):
            t = torch.full(
                (num_samples,), step, dtype=torch.long, device=self.device
            )
            sample = self.p_sample(sample, t, cond, generator)
        return sample


# Standard diffusion model (no hyper-network)
class DDPM(torch.nn.Module):
    """Conditional 1-D DDPM with a plain MLP noise predictor.

    The predictor is a 3 -> 16 -> 32 -> 64 -> 32 -> 16 -> 1 MLP with ReLU and
    ``torch.nn.Dropout(p=0.1)`` after every hidden layer. That dropout is only
    active in training mode. Inputs are ``(x_t, t, cond)`` stacked as three
    features.

    Args:
        T: Number of diffusion steps; schedules are indexed by ``t`` in ``[0, T)``.
        device: Device for the predictor and the schedule tensors.
    """

    def __init__(self, T, device):
        super().__init__()
        self.device = device
        self.T = T

        # Every hidden Linear is followed by ReLU + Dropout(0.1); the last
        # Linear maps to a single output with no activation.
        widths = (3, 16, 32, 64, 32, 16)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend(
                [
                    torch.nn.Linear(fan_in, fan_out),
                    torch.nn.ReLU(),
                    torch.nn.Dropout(p=0.1),
                ]
            )
        layers.append(torch.nn.Linear(widths[-1], 1))
        self.model = torch.nn.Sequential(*layers).to(device)

        # Linear beta schedule from 0.001 to 0.2; alpha_t = 1 - beta_t.
        betas = torch.linspace(0.001, 0.2, T).to(self.device)
        alphas = 1.0 - betas
        abar = torch.cumprod(alphas, dim=0).to(self.device)
        sqrt_abar = torch.sqrt(abar).to(self.device)
        # abar_{t-1}, with abar_{-1} := 1 so the t=0 posterior variance is zero.
        abar_prev = F.pad(abar[:-1], (1, 0), value=1.0).to(self.device)

        self.alphas = alphas
        self.alphas_cumprod = abar
        self.sqrt_alphas_cumprod = sqrt_abar
        # Mathematically sqrt(1 - abar_t).
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - sqrt_abar**2).to(
            self.device
        )
        self.alphas_cumprod_prev = abar_prev
        # beta_t * (1 - abar_{t-1}) / (1 - abar_t)
        self.posterior_variance = (
            (1 - alphas) * (1.0 - abar_prev) / (1.0 - abar).to(self.device)
        )

    def forward(self, x, t, cond):
        """Predict the noise in ``x`` at steps ``t`` given ``cond``."""
        features = torch.stack((x, t, cond)).T
        return self.model(features)

    def q_sample(self, x_0, t, noise):
        """Draw x_t from q(x_t | x_0) in closed form, using the supplied ``noise``."""
        signal = self.sqrt_alphas_cumprod[t] * x_0.to(self.device)
        return signal + self.sqrt_one_minus_alphas_cumprod[t] * noise.to(self.device)

    def p_loss(self, x, t, cond):
        """Noise-prediction MSE for a batch of clean samples ``x`` at steps ``t``."""
        eps = torch.randn(x.shape).to(self.device)
        x_t = self.q_sample(x, t, eps)
        eps_hat = self.forward(x_t, t.to(self.device), cond.to(self.device)).flatten()
        return F.mse_loss(eps, eps_hat)

    def p_sample(self, x, t, cond):
        """One reverse (denoising) step x_t -> x_{t-1}, as in DDPM Algorithm 2."""
        alpha_t = self.alphas.gather(-1, t)
        sqrt_one_minus_abar_t = self.sqrt_one_minus_alphas_cumprod.gather(-1, t)

        with torch.no_grad():
            eps_hat = self.forward(x, t, cond).squeeze()

        mean = torch.sqrt(1.0 / alpha_t).to(self.device) * (
            x - (1 - alpha_t) * eps_hat / sqrt_one_minus_abar_t
        )
        variance = self.posterior_variance.gather(-1, t)

        # Reparameterised draw from N(mean, variance).
        z = torch.randn_like(variance)
        return mean + torch.sqrt(variance).to(self.device) * z.to(self.device)

    def p_sample_loop(self, num_samples, cond):
        """Generate ``num_samples`` values by running the reverse chain from noise.

        NOTE(review): steps run from T-1 down to 1, so the t=0 step is never
        applied. If training also samples t=0, consider including it.
        """
        sample = torch.randn(num_samples).to(self.device)

        cond = cond.to(self.device)
        for step in reversed(range(1, self.T)):
            t = torch.full(
                (num_samples,), step, dtype=torch.long, device=self.device
            )
            sample = self.p_sample(sample, t, cond)
        return sample
