import sys

import math
from collections import namedtuple
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from denoising_diffusion_pytorch.attend import Attend
from einops import rearrange, reduce, repeat
from torch.cuda.amp import autocast
from tqdm import tqdm

# Result pair produced by HyperDiffusion.model_predictions:
# the predicted noise and the predicted clean sample (x_0).
ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start"])


def identity(t, *_ignored_args, **_ignored_kwargs):
    """Return ``t`` unchanged; any extra arguments are accepted and discarded."""
    return t


def normalize_to_neg_one_to_one(img):
    """Apply the affine map ``2 * img - 1`` (sends 0 to -1 and 1 to 1)."""
    doubled = img * 2
    return doubled - 1


def unnormalize_to_zero_to_one(t):
    """Apply the affine map ``(t + 1) / 2`` (sends -1 to 0 and 1 to 1)."""
    shifted = t + 1
    return shifted * 0.5


def exists(x):
    """Tell whether ``x`` is anything other than ``None``."""
    return not (x is None)


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When the fallback ``d`` is callable it is invoked (with no arguments) and
    its result is returned, so expensive defaults are only built on demand.
    """
    if val is None:
        if callable(d):
            return d()
        return d
    return val


def upsample(x, params, dim, dim_out):
    """Double the spatial size (nearest neighbour), then apply a 3x3 conv.

    ``params`` is a flat tensor. Its first ``dim * dim_out * 9`` entries are
    reshaped to the conv kernel ``(dim_out, dim, 3, 3)``; every entry after
    that is used as the conv bias.
    """
    kernel_len = dim * dim_out * 3 * 3
    kernel = params[:kernel_len].reshape(dim_out, dim, 3, 3)
    bias = params[kernel_len:]
    enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
    return F.conv2d(enlarged, kernel, bias, padding=1)


def downsample(x, params, dim, dim_out=None):
    """Halve the spatial size by folding 2x2 patches into channels, then 1x1 conv.

    ``params`` is a flat tensor laid out as the 1x1 conv kernel
    (``out * dim * 4`` entries, reshaped to ``(out, dim * 4, 1, 1)``) followed
    by the bias (``out`` entries), where ``out`` is ``dim_out`` or, when that
    is ``None``, ``dim``.
    """
    out_channels = default(dim_out, dim)
    folded = rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)

    kernel_len = dim * 4 * out_channels * 1 * 1
    kernel = params[:kernel_len].reshape(out_channels, dim * 4, 1, 1)
    bias = params[kernel_len : kernel_len + out_channels]
    return F.conv2d(folded, kernel, bias=bias)


def rms_norm(x, dim, weights):
    """L2-normalise ``x`` over axis 1, then scale by ``weights`` and sqrt(channels).

    ``weights`` is reshaped to ``(1, dim, 1, 1)`` so it broadcasts per channel.
    """
    gain = weights.reshape(1, dim, 1, 1)
    unit = F.normalize(x, dim=1)
    return unit * gain * (x.shape[1] ** 0.5)


def sin_pos_embed(x, dim, device, theta=10000):
    """Build sinusoidal embeddings for a 1-D tensor of positions/timesteps.

    Args:
        x: 1-D tensor; each entry is embedded independently.
        dim: embedding width. ``dim // 2`` frequencies are used, so the result
            has ``2 * (dim // 2)`` columns (sines first, then cosines).
        device: device on which the frequency table is created; it has to be
            compatible with ``x`` for the product below to succeed.
        theta: base of the geometric frequency progression. Frequencies run
            from ``1`` down to ``1 / theta``. Defaults to 10000, the value
            that was previously hard-coded.

    Returns:
        Tensor of shape ``(len(x), 2 * (dim // 2))``.

    Raises:
        ValueError: if ``dim < 4``; with a single frequency the spacing
            ``log(theta) / (half_dim - 1)`` would divide by zero.
    """
    half_dim = dim // 2
    if half_dim < 2:
        raise ValueError(f"sin_pos_embed requires dim >= 4, got {dim}")

    log_step = math.log(theta) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
    angles = x[:, None] * freqs[None, :]
    return torch.cat((angles.sin(), angles.cos()), dim=-1)


def block(x, params, dim, dim_out, groups=8, scale_shift=None):
    """3x3 conv -> group norm -> optional scale/shift modulation -> SiLU.

    ``params`` is a flat tensor laid out as: conv kernel (everything except
    the last ``3 * dim_out`` entries, reshaped to ``(dim_out, dim, 3, 3)``),
    conv bias, group-norm weight, group-norm bias (``dim_out`` entries each,
    addressed from the end of the tensor).

    When ``scale_shift`` is given it is unpacked into ``(scale, shift)`` and
    applied as ``x * (scale + 1) + shift`` before the activation.
    """
    tail_start = -3 * dim_out
    kernel = params[:tail_start].reshape(dim_out, dim, 3, 3)
    bias = params[tail_start : -2 * dim_out]
    gn_weight = params[-2 * dim_out : -dim_out]
    gn_bias = params[-dim_out:]

    # arbitrary in/out channels, kernel size 3
    out = F.conv2d(x, kernel, bias, padding=1)
    out = F.group_norm(out, groups, gn_weight, gn_bias)

    if scale_shift is None:
        return F.silu(out)

    scale, shift = scale_shift
    return F.silu(out * (scale + 1) + shift)


def resnetblock(x, params, time_emb, dim, dim_out, time_emb_dim, groups=8):
    """Time-conditioned residual block driven by a flat parameter tensor.

    ``params`` is consumed front to back in this order:

    1. time linear layer: weight ``(2 * dim_out, time_emb_dim)`` then bias
       ``(2 * dim_out,)``; its output is split into a (scale, shift) pair,
    2. first ``block`` (``dim -> dim_out``), which receives the scale/shift,
    3. second ``block`` (``dim_out -> dim_out``),
    4. only when ``dim != dim_out``: a 1x1 conv (weight then bias) that
       projects the residual input to ``dim_out`` channels.

    Returns the second block's output plus the (possibly projected) input.
    """
    cursor = 0

    def take(count):
        # hand out the next `count` entries of the flat parameter tensor
        nonlocal cursor
        piece = params[cursor : cursor + count]
        cursor += count
        return piece

    # time conditioning: SiLU -> linear -> (scale, shift)
    mlp_weight = take(2 * dim_out * time_emb_dim).reshape(2 * dim_out, time_emb_dim)
    mlp_bias = take(2 * dim_out)
    modulation = F.linear(F.silu(time_emb), mlp_weight, mlp_bias)
    modulation = rearrange(modulation, "b c -> b c 1 1")
    scale_shift = modulation.chunk(2, dim=1)

    # first conv block, modulated by the time embedding
    first_params = take(dim_out * dim * 9 + 3 * dim_out)
    hidden = block(x, first_params, dim, dim_out, scale_shift=scale_shift, groups=groups)

    # second conv block
    second_params = take(dim_out * dim_out * 9 + 3 * dim_out)
    hidden = block(hidden, second_params, dim_out, dim_out, groups=groups)

    if dim == dim_out:
        return hidden + x

    # channel counts differ: project the residual with a 1x1 conv
    proj_weight = take(dim_out * dim).reshape(dim_out, dim, 1, 1)
    proj_bias = take(dim_out)
    return hidden + F.conv2d(x, proj_weight, proj_bias)


def linear_attention(x, params, dim, heads=4, dim_head=32, num_mem_kv=4):
    """Linear attention over spatial positions, driven by a flat parameter tensor.

    ``params`` is consumed front to back in this order:

    1. memory key/values, reshaped to ``(2, heads, dim_head, num_mem_kv)``,
    2. gain of the input RMS norm (``dim`` entries),
    3. bias-free 1x1 conv producing q, k, v (``3 * heads * dim_head`` channels),
    4. output 1x1 conv weight and bias (back to ``dim`` channels),
    5. gain of the output RMS norm (``dim`` entries).

    Returns a tensor with the same shape as ``x``; the residual connection is
    left to the caller.
    """
    batch, _, height, width = x.shape
    inner_dim = dim_head * heads
    cursor = 0

    def take(count):
        # hand out the next `count` entries of the flat parameter tensor
        nonlocal cursor
        piece = params[cursor : cursor + count]
        cursor += count
        return piece

    # slice out every weight up front, in layout order
    memory = take(2 * heads * dim_head * num_mem_kv).reshape(
        2, heads, dim_head, num_mem_kv
    )
    pre_gain = take(dim).reshape(1, dim, 1, 1)
    qkv_weight = take(dim * inner_dim * 3).reshape(inner_dim * 3, dim, 1, 1)
    out_weight = take(inner_dim * dim).reshape(dim, inner_dim, 1, 1)
    out_bias = take(dim)
    post_gain = take(dim).reshape(1, dim, 1, 1)

    normed = rms_norm(x, dim, pre_gain)

    projections = F.conv2d(normed, qkv_weight).chunk(3, dim=1)
    q, k, v = (
        rearrange(part, "b (h c) x y -> b h c (x y)", h=heads) for part in projections
    )

    # prepend the learned memory slots to keys and values
    mem_k, mem_v = (repeat(slot, "h c n -> b h c n", b=batch) for slot in memory)
    k = torch.cat((mem_k, k), dim=-1)
    v = torch.cat((mem_v, v), dim=-1)

    q = q.softmax(dim=-2) * (dim_head**-0.5)
    k = k.softmax(dim=-1)

    context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
    attended = torch.einsum("b h d e, b h d n -> b h e n", context, q)
    attended = rearrange(
        attended, "b h c (x y) -> b (h c) x y", h=heads, x=height, y=width
    )

    projected = F.conv2d(attended, out_weight, out_bias)
    return rms_norm(projected, dim, post_gain)


def attention(x, params, dim, heads=4, dim_head=32, num_mem_kv=4, flash=False):
    """Full attention over spatial positions, driven by a flat parameter tensor.

    ``params`` is consumed front to back in this order:

    1. memory key/values, reshaped to ``(2, heads, num_mem_kv, dim_head)``,
    2. gain of the input RMS norm (``dim`` entries),
    3. bias-free 1x1 conv producing q, k, v (``3 * heads * dim_head`` channels),
    4. output 1x1 conv weight and bias (back to ``dim`` channels).

    ``flash`` is forwarded to ``Attend``. Returns a tensor with the same shape
    as ``x``; the residual connection is left to the caller.
    """
    batch, _, height, width = x.shape
    inner_dim = dim_head * heads
    attend = Attend(flash=flash)
    cursor = 0

    def take(count):
        # hand out the next `count` entries of the flat parameter tensor
        nonlocal cursor
        piece = params[cursor : cursor + count]
        cursor += count
        return piece

    # slice out every weight up front, in layout order
    memory = take(2 * heads * dim_head * num_mem_kv).reshape(
        2, heads, num_mem_kv, dim_head
    )
    gain = take(dim).reshape(1, dim, 1, 1)
    qkv_weight = take(dim * inner_dim * 3).reshape(inner_dim * 3, dim, 1, 1)
    out_weight = take(inner_dim * dim).reshape(dim, inner_dim, 1, 1)
    out_bias = take(dim)

    normed = rms_norm(x, dim, gain)

    projections = F.conv2d(normed, qkv_weight).chunk(3, dim=1)
    q, k, v = (
        rearrange(part, "b (h c) x y -> b h (x y) c", h=heads) for part in projections
    )

    # prepend the learned memory slots to keys and values
    mem_k, mem_v = (repeat(slot, "h n d -> b h n d", b=batch) for slot in memory)
    k = torch.cat((mem_k, k), dim=-2)
    v = torch.cat((mem_v, v), dim=-2)

    attended = attend(q, k, v)
    attended = rearrange(attended, "b h (x y) d -> b (h d) x y", x=height, y=width)

    return F.conv2d(attended, out_weight, out_bias)


class FrozenUnet(nn.Module):
    """U-Net that owns no weights: every weight is read from a flat tensor.

    ``forward`` receives a 1-D ``params`` tensor and slices it front to back
    in this order:

    1. input 7x7 conv (weight, bias),
    2. time MLP: linear ``dim -> 4 * dim``, GELU, linear ``4 * dim -> 4 * dim``,
    3. for every resolution level of the encoder: two residual blocks, an
       attention layer (linear attention on all but the last level, full
       attention on the last), then a downsample (or a 3x3 conv on the last
       level),
    4. bottleneck: residual block, full attention, residual block,
    5. for every resolution level of the decoder: two residual blocks (each
       fed a skip connection), an attention layer (full attention on the
       first decoder level, linear attention afterwards), then an upsample
       (or a 3x3 conv on the last level),
    6. final residual block and final 1x1 conv to ``out_dim`` channels.

    Constructor options that are stored but not consulted by ``forward``:
    ``full_attn``, ``learned_sinusoidal_cond``, ``random_fourier_features``,
    ``learned_sinusoidal_dim`` and ``sinusoidal_pos_emb_theta``.
    """

    # Number of learned memory key/value slots per attention layer. It is
    # passed explicitly to the attention functions so the slice sizes computed
    # here always agree with what they consume.
    _NUM_MEM_KV = 4

    def __init__(
        self,
        dim,
        device,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=True,
        resnet_block_groups=8,
        learned_variance=False,
        learned_sinusoidal_cond=False,
        random_fourier_features=False,
        learned_sinusoidal_dim=16,
        sinusoidal_pos_emb_theta=10000,
        attn_dim_head=32,
        attn_heads=4,
        full_attn=None,  # defaults to full attention only for inner most layer
        flash_attn=False,
    ):
        super().__init__()
        self.dim = dim
        self.device = device
        self.init_dim = init_dim
        self.dim_mults = dim_mults
        self.channels = channels
        self.self_condition = self_condition
        self.resnet_block_groups = resnet_block_groups
        self.learned_variance = learned_variance
        self.learned_sinusoidal_cond = learned_sinusoidal_cond
        self.random_fourier_features = random_fourier_features
        self.learned_sinusoidal_dim = learned_sinusoidal_dim
        self.sinusoidal_pos_emb_theta = sinusoidal_pos_emb_theta
        self.attn_dim_head = attn_dim_head
        self.attn_heads = attn_heads
        self.full_attn = full_attn
        self.flash_attn = flash_attn

        self.n_blocks = len(dim_mults)

        # twice the channels when the variance is predicted as well
        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        self.random_or_learned_sinusoidal_cond = (
            learned_sinusoidal_cond or random_fourier_features
        )

    @staticmethod
    def _resnet_n_params(dim_in, dim_out, time_dim):
        """Number of flat parameters one ``resnetblock`` call consumes."""
        count = (
            2 * dim_out * time_dim  # time linear weight
            + 2 * dim_out  # time linear bias
            + dim_out * dim_in * 9  # block 1 conv weight
            + 3 * dim_out  # block 1 conv bias + group-norm weight/bias
            + dim_out * dim_out * 9  # block 2 conv weight
            + 3 * dim_out  # block 2 conv bias + group-norm weight/bias
        )
        if dim_in != dim_out:
            # 1x1 residual projection (weight + bias), only read in this case
            count += dim_out * dim_in + dim_out
        return count

    def _attn_n_params(self, dim, linear):
        """Number of flat parameters one attention call consumes.

        ``linear=True`` sizes ``linear_attention`` (which has one extra RMS
        norm on its output); ``linear=False`` sizes ``attention``.
        """
        hidden = self.attn_dim_head * self.attn_heads
        count = (
            2 * hidden * self._NUM_MEM_KV  # memory keys and values
            + dim  # input RMS-norm gain
            + dim * hidden * 3  # q, k, v projection
            + hidden * dim  # output projection weight
            + dim  # output projection bias
        )
        if linear:
            count += dim  # output RMS-norm gain
        return count

    def forward(self, x, t, cond, params):
        """Run the U-Net using the weights stored in ``params``.

        Args:
            x: input image batch.
            t: 1-D tensor of timesteps, one per batch element.
            cond: conditioning tensor; it is concatenated in front of ``x``
                along the channel axis unconditionally.
            params: flat tensor holding every weight, in the order described
                in the class docstring.

        Returns:
            Tensor with ``self.out_dim`` channels.
        """
        init_dim = default(self.init_dim, self.dim)
        input_channels = self.channels * (2 if self.self_condition else 1)
        time_dim = self.dim * 4
        dims = [init_dim, *(self.dim * mult for mult in self.dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        last_level = self.n_blocks - 1
        attn_kwargs = dict(
            heads=self.attn_heads,
            dim_head=self.attn_dim_head,
            num_mem_kv=self._NUM_MEM_KV,
        )

        offset = 0

        def take(count, *shape):
            # hand out the next `count` entries, optionally reshaped
            nonlocal offset
            chunk = params[offset : offset + count]
            offset += count
            return chunk.reshape(shape) if shape else chunk

        # append condition
        x = torch.cat((cond, x), dim=1)

        # input conv layer; the slice is sized with init_dim so that it
        # matches the reshape even when init_dim differs from dim
        conv_weight = take(
            input_channels * init_dim * 7 * 7, init_dim, input_channels, 7, 7
        )
        conv_bias = take(init_dim)
        x = F.conv2d(x, conv_weight, conv_bias, padding=3)
        r = x.clone()

        # time embedding, built on the timestep tensor's own device so it
        # stays valid after the inputs have been moved
        t = sin_pos_embed(t, self.dim, t.device)
        linear_weight = take(self.dim * time_dim, time_dim, self.dim)
        linear_bias = take(time_dim)
        t = F.gelu(F.linear(t, linear_weight, linear_bias))
        linear_weight = take(time_dim * time_dim, time_dim, time_dim)
        linear_bias = take(time_dim)
        t = F.linear(t, linear_weight, linear_bias)

        def res_block(inp, dim_in, dim_out):
            block_params = take(self._resnet_n_params(dim_in, dim_out, time_dim))
            return resnetblock(
                inp,
                block_params,
                t,
                dim_in,
                dim_out,
                time_dim,
                groups=self.resnet_block_groups,
            )

        def full_attn(inp, dim):
            attn_params = take(self._attn_n_params(dim, linear=False))
            return (
                attention(inp, attn_params, dim, flash=self.flash_attn, **attn_kwargs)
                + inp
            )

        def lin_attn(inp, dim):
            attn_params = take(self._attn_n_params(dim, linear=True))
            return linear_attention(inp, attn_params, dim, **attn_kwargs) + inp

        # encoder
        h = []
        for j, (dim_in, dim_out) in enumerate(in_out):
            # two residual blocks, each output kept as a skip connection
            for _ in range(2):
                x = res_block(x, dim_in, dim_in)
                h.append(x)

            if j < last_level:
                x = lin_attn(x, dim_in)
                downsample_params = take(dim_in * 4 * dim_out + dim_out)
                x = downsample(x, downsample_params, dim_in, dim_out)
            else:
                # innermost level: full attention and a plain 3x3 conv
                x = full_attn(x, dim_in)
                conv_weight = take(dim_in * dim_out * 9, dim_out, dim_in, 3, 3)
                conv_bias = take(dim_out)
                x = F.conv2d(x, conv_weight, conv_bias, padding=1)

        # bottleneck
        mid_dim = dims[-1]
        x = res_block(x, mid_dim, mid_dim)
        x = full_attn(x, mid_dim)
        x = res_block(x, mid_dim, mid_dim)

        # decoder
        for j, (dim_in, dim_out) in enumerate(reversed(in_out)):
            cat_dim_in = dim_in + dim_out
            for _ in range(2):
                x = torch.cat((x, h.pop()), dim=1)
                x = res_block(x, cat_dim_in, dim_out)

            # full attention on the first decoder level, linear afterwards;
            # both are sized from attn_heads / attn_dim_head
            x = full_attn(x, dim_out) if j == 0 else lin_attn(x, dim_out)

            if j < last_level:
                upsample_params = take(dim_in * dim_out * 9 + dim_in)
                x = upsample(x, upsample_params, dim_out, dim_in)
            else:
                conv_weight = take(dim_in * dim_out * 3 * 3, dim_in, dim_out, 3, 3)
                conv_bias = take(dim_in)
                x = F.conv2d(x, conv_weight, conv_bias, padding=1)

        # both halves of this concatenation carry init_dim channels
        x = torch.cat((x, r), dim=1)

        # final residual block
        x = res_block(x, init_dim * 2, self.dim)

        # final 1x1 conv to the configured number of output channels
        conv_weight = take(self.dim * self.out_dim, self.out_dim, self.dim, 1, 1)
        conv_bias = take(self.out_dim)
        return F.conv2d(x, conv_weight, conv_bias)


def extract(a, t, x_shape):
    """Look up ``a`` at the indices ``t`` and shape the result for broadcasting.

    The gathered values are reshaped to ``(len(t), 1, ..., 1)`` with as many
    dimensions as ``x_shape`` has entries.
    """
    batch, *_rest = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_dims)


def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper

    Returns ``timesteps`` evenly spaced float64 betas; both endpoints
    (0.0001 and 0.02) are multiplied by ``1000 / timesteps``.
    """
    rescale = 1000 / timesteps
    first, last = rescale * 0.0001, rescale * 0.02
    return torch.linspace(first, last, timesteps, dtype=torch.float64)


def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Returns ``timesteps`` float64 betas clipped to [0, 0.999]; ``s`` is the
    offset added to the normalised time before the cosine is taken.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    cumulative = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)


def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """
    sigmoid schedule
    proposed in https://arxiv.org/abs/2212.11972 - Figure 8
    better for images > 64x64, when used during training

    Returns ``timesteps`` float64 betas clipped to [0, 0.999]. ``clamp_min``
    is accepted but not used by this implementation.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    sig_start = torch.tensor(start / tau).sigmoid()
    sig_end = torch.tensor(end / tau).sigmoid()

    logits = (grid * (end - start) + start) / tau
    cumulative = (-logits.sigmoid() + sig_end) / (sig_end - sig_start)
    cumulative = cumulative / cumulative[0]

    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)


class Block(nn.Module):
    """3x3 conv, group norm and SiLU, with optional scale/shift modulation."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply conv -> norm -> (optional ``x * (scale + 1) + shift``) -> SiLU."""
        normed = self.norm(self.proj(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        return self.act(normed)


class HyperDiffusion(nn.Module):
    """Gaussian diffusion whose denoiser weights come from a hypernetwork.

    A small MLP (``self.hypernetwork``) maps an input vector ``in_vec`` to a
    flat tensor of ``n_params`` values. That tensor is handed to ``model``
    together with the noisy image, the timestep and a conditioning tensor, i.e.
    ``model(x, t, cond, params)``.

    Args:
        model: denoiser called as ``model(x, t, cond, params)``. It must expose
            ``channels``, ``out_dim``, ``self_condition`` and
            ``random_or_learned_sinusoidal_cond``.
        image_size: height and width expected by ``forward``.
        in_dim: width of the hypernetwork input vector.
        out_dim: stored on the instance; not otherwise used by this class.
        n_params: number of values the hypernetwork emits.
        timesteps: number of diffusion steps.
        sampling_timesteps: number of steps used when sampling. DDIM sampling
            is selected when this is smaller than ``timesteps``.
        objective: one of ``"pred_noise"``, ``"pred_x0"``, ``"pred_v"``.
        beta_schedule: one of ``"linear"``, ``"cosine"``, ``"sigmoid"``.
        schedule_fn_kwargs: extra keyword arguments for the schedule function
            (``None`` means no extra arguments).
        ddim_sampling_eta: eta used by ``ddim_sample``.
        auto_normalize: map inputs from [0, 1] to [-1, 1] in ``forward`` and
            back again after sampling.
        offset_noise_strength: default offset-noise strength for ``p_losses``.
        min_snr_loss_weight: clamp the SNR at ``min_snr_gamma`` when deriving
            the per-timestep loss weight.
        min_snr_gamma: clamp value used when ``min_snr_loss_weight`` is set.
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        in_dim=8,
        out_dim=64,
        n_params=35716289,
        timesteps=1000,
        sampling_timesteps=None,
        objective="pred_v",
        beta_schedule="sigmoid",
        schedule_fn_kwargs=None,
        ddim_sampling_eta=0.0,
        auto_normalize=True,
        offset_noise_strength=0.0,  # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight=False,  # https://arxiv.org/abs/2303.09556
        min_snr_gamma=5,
    ):
        super().__init__()
        assert not (type(self) is HyperDiffusion and model.channels != model.out_dim)
        assert not model.random_or_learned_sinusoidal_cond

        # a fresh dict per instance instead of a shared mutable default
        schedule_fn_kwargs = {} if schedule_fn_kwargs is None else schedule_fn_kwargs

        self.hypernetwork = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_params),
        )
        # initialize network weights to zero
        # for p in self.hypernetwork.parameters():
        # p.data.zero_()
        self.n_params = n_params
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.model = model

        self.channels = self.model.channels
        self.self_condition = self.model.self_condition

        self.image_size = image_size

        self.objective = objective

        assert objective in {
            "pred_noise",
            "pred_x0",
            "pred_v",
        }, "objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])"

        if beta_schedule == "linear":
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == "cosine":
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == "sigmoid":
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters

        self.sampling_timesteps = default(
            sampling_timesteps, timesteps
        )  # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(
            name, val.to(torch.float32)
        )

        register_buffer("betas", betas)
        register_buffer("alphas_cumprod", alphas_cumprod)
        register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
        register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register_buffer(
            "sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer("posterior_variance", posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_buffer(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_buffer(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
        )

        # offset noise strength - in blogpost, they claimed 0.1 was ideal

        self.offset_noise_strength = offset_noise_strength

        # derive loss weight
        # snr - signal noise ratio

        snr = alphas_cumprod / (1 - alphas_cumprod)

        # https://arxiv.org/abs/2303.09556

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=min_snr_gamma)

        if objective == "pred_noise":
            register_buffer("loss_weight", maybe_clipped_snr / snr)
        elif objective == "pred_x0":
            register_buffer("loss_weight", maybe_clipped_snr)
        elif objective == "pred_v":
            register_buffer("loss_weight", maybe_clipped_snr / (snr + 1))

        # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False

        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity

    @property
    def device(self):
        """Device of the registered ``betas`` buffer."""
        return self.betas.device

    @staticmethod
    def _require_in_vec(in_vec, caller):
        """Raise a clear error when a sampling entry point gets no ``in_vec``."""
        if in_vec is None:
            raise ValueError(
                f"{caller} requires `in_vec`: the hypernetwork input that "
                "produces the denoiser weights"
            )

    def _resolve_cond(self, conds, x_start):
        """Pick the conditioning tensor for one sampling step.

        An explicit ``conds`` always wins. Without it, the previous ``x_start``
        estimate is used when the model is self-conditioned, else ``None``.
        """
        if conds is not None:
            return conds
        return x_start if self.self_condition else None

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover the x_0 estimate from ``x_t`` and a noise prediction."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise estimate from ``x_t`` and an x_0 prediction."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_v(self, x_start, t, noise):
        """Compute the v-parameterisation target from x_0 and the noise."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
            - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """Recover the x_0 estimate from ``x_t`` and a v prediction."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(
        self, x, t, x_self_cond, in_vec, clip_x_start=False, rederive_pred_noise=False
    ):
        """Run the denoiser and return ``ModelPrediction(pred_noise, pred_x_start)``.

        The denoiser weights are generated from ``in_vec`` by the
        hypernetwork. ``x_self_cond`` is passed to the model as its
        conditioning input. With ``clip_x_start`` the x_0 estimate is clamped
        to [-1, 1]; for the ``pred_noise`` objective the noise is then
        re-derived from the clamped estimate only if ``rederive_pred_noise``.
        """
        params = self.hypernetwork(in_vec)
        model_output = self.model(x, t, x_self_cond, params)
        maybe_clip = (
            partial(torch.clamp, min=-1.0, max=1.0) if clip_x_start else identity
        )

        if self.objective == "pred_noise":
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == "pred_x0":
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == "pred_v":
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, x_self_cond, in_vec, clip_denoised=True):
        """Posterior mean/variance/log-variance for one reverse step, plus x_0."""
        preds = self.model_predictions(x, t, x_self_cond, in_vec)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1.0, 1.0)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_start, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond, in_vec):
        """Take one ancestral sampling step from timestep ``t``.

        Returns ``(x_{t-1}, x_start_estimate)``.
        """
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device=device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x,
            t=batched_times,
            x_self_cond=x_self_cond,
            in_vec=in_vec,
            clip_denoised=True,
        )
        noise = torch.randn_like(x) if t > 0 else 0.0  # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.inference_mode()
    def p_sample_loop(self, shape, conds, in_vec, return_all_timesteps=False):
        """Ancestral sampling over all timesteps, starting from Gaussian noise.

        ``conds`` is passed to the model as conditioning at every step.
        Returns the final (unnormalised) image, or every intermediate image
        stacked along dim 1 when ``return_all_timesteps`` is set.
        """
        img = torch.randn(shape, device=self.device)
        imgs = [img]

        for t in reversed(range(0, self.num_timesteps)):
            # the external condition is used instead of self-conditioning
            img, _ = self.p_sample(img, t, conds, in_vec)
            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim=1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def ddim_sample(self, shape, return_all_timesteps=False, conds=None, in_vec=None):
        """DDIM sampling with ``self.sampling_timesteps`` steps.

        Args:
            shape: shape of the image batch to generate.
            return_all_timesteps: return all intermediate images stacked along
                dim 1 instead of only the final one.
            conds: conditioning passed to the model at every step. When
                omitted, the previous x_0 estimate is used if the model is
                self-conditioned, otherwise ``None``.
            in_vec: hypernetwork input; required.

        Raises:
            ValueError: if ``in_vec`` is not provided.
        """
        self._require_in_vec(in_vec, "ddim_sample")

        batch, device, total_timesteps, sampling_timesteps, eta = (
            shape[0],
            self.device,
            self.num_timesteps,
            self.sampling_timesteps,
            self.ddim_sampling_eta,
        )

        times = torch.linspace(
            -1, total_timesteps - 1, steps=sampling_timesteps + 1
        )  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(
            zip(times[:-1], times[1:])
        )  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device=device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc="sampling loop time step"):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            cond = self._resolve_cond(conds, x_start)
            pred_noise, x_start, *_ = self.model_predictions(
                img,
                time_cond,
                cond,
                in_vec,
                clip_x_start=True,
                rederive_pred_noise=True,
            )

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = (
                eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            )
            c = (1 - alpha_next - sigma**2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim=1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def sample(
        self, batch_size=16, return_all_timesteps=False, conds=None, in_vec=None
    ):
        """Generate ``batch_size`` square images of side ``self.image_size``.

        Uses ``ddim_sample`` when ``self.is_ddim_sampling`` is set, otherwise
        ``p_sample_loop``; ``conds`` and ``in_vec`` are forwarded to whichever
        is selected.

        Raises:
            ValueError: if ``in_vec`` is not provided.
        """
        self._require_in_vec(in_vec, "sample")

        shape = (batch_size, self.channels, self.image_size, self.image_size)
        if self.is_ddim_sampling:
            return self.ddim_sample(
                shape,
                return_all_timesteps=return_all_timesteps,
                conds=conds,
                in_vec=in_vec,
            )
        return self.p_sample_loop(
            shape, conds, in_vec, return_all_timesteps=return_all_timesteps
        )

    @torch.inference_mode()
    def interpolate(self, x1, x2, t=None, lam=0.5, conds=None, in_vec=None):
        """Blend two images in noised space and denoise the mixture.

        Both images are noised to timestep ``t`` (default: the last timestep),
        mixed as ``(1 - lam) * x1_t + lam * x2_t``, then denoised with
        ``p_sample`` from ``t - 1`` down to 0.

        Args:
            conds: conditioning passed to the model at every step. When
                omitted, the previous x_0 estimate is used if the model is
                self-conditioned, otherwise ``None``.
            in_vec: hypernetwork input; required.

        Raises:
            ValueError: if ``in_vec`` is not provided.
        """
        self._require_in_vec(in_vec, "interpolate")

        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.full((b,), t, device=device, dtype=torch.long)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2

        x_start = None

        for i in tqdm(
            reversed(range(0, t)), desc="interpolation sample time step", total=t
        ):
            cond = self._resolve_cond(conds, x_start)
            img, x_start = self.p_sample(img, i, cond, in_vec)

        return img

    @autocast(enabled=False)
    def q_sample(self, x_start, t, noise=None):
        """Sample x_t from q(x_t | x_0); fresh Gaussian noise unless given."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(
        self, x_start, t, cond, in_vec, noise=None, offset_noise_strength=None
    ):
        """Weighted MSE training loss for one batch at timesteps ``t``.

        The denoiser weights are generated from ``in_vec``; ``cond`` is the
        model's conditioning input. The target depends on ``self.objective``
        and each sample's loss is scaled by ``self.loss_weight[t]``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise

        offset_noise_strength = default(
            offset_noise_strength, self.offset_noise_strength
        )

        if offset_noise_strength > 0.0:
            offset_noise = torch.randn(x_start.shape[:2], device=self.device)
            # out-of-place so a caller-supplied noise tensor is left untouched
            noise = noise + offset_noise_strength * rearrange(
                offset_noise, "b c -> b c 1 1"
            )

        # noise sample

        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        params = self.hypernetwork(in_vec)

        model_out = self.model(x, t, cond, params)

        if self.objective == "pred_noise":
            target = noise
        elif self.objective == "pred_x0":
            target = x_start
        elif self.objective == "pred_v":
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = F.mse_loss(model_out, target, reduction="none")
        loss = reduce(loss, "b ... -> b", "mean")

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        """Draw random timesteps and return the training loss for ``img``.

        Extra positional/keyword arguments are forwarded to ``p_losses`` after
        the image and the timesteps (so ``cond`` and ``in_vec`` come first).
        """
        (
            b,
            c,
            h,
            w,
            device,
            img_size,
        ) = (
            *img.shape,
            img.device,
            self.image_size,
        )
        assert (
            h == img_size and w == img_size
        ), f"height and width of image must be {img_size}"
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = self.normalize(img)
        return self.p_losses(img, t, *args, **kwargs)
