import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from piq import LPIPS
from torchvision.transforms import RandomCrop

class KarrasDenoiser:
    """Preconditioning and sampling helpers for a Karras-style denoiser.

    The object stores the noise-schedule constants and exposes:

    * scaling coefficients (``c_skip``, ``c_out``, ``c_in``) in two flavours,
    * a single Heun integration step,
    * a sampler for ordered pairs of integer time indices,
    * ``denoise``, which blends a network output with the noisy input.
    """

    def __init__(
        self,
        sigma_data: float = 0.5,
        sigma_max=80.0,
        sigma_min=0.002,
        rho=7.0,
        distillation=True,
        device=None
    ):
        """Store the schedule constants.

        Args:
            sigma_data: Constant used in the preconditioning formulas.
            sigma_max: Upper noise level; divides ``sigma`` in
                ``get_scalings``.
            sigma_min: Lower noise level; the boundary-condition scalings
                are centred on it.
            rho: Stored as ``self.rho``; not read by any method of this
                class.
            distillation: When true, ``denoise`` uses
                ``get_scalings_for_boundary_condition``; otherwise it uses
                ``get_scalings``.
            device: Accepted for interface compatibility; currently unused.
        """
        self.sigma_data = sigma_data
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.distillation = distillation
        self.rho = rho
        # Fixed value; not configurable through the constructor.
        self.num_timesteps = 40

    def get_scalings(self, sigma):
        """Return ``(c_skip, c_out, c_in)`` using a linear skip/out blend.

        ``c_out = sigma / sigma_max`` and ``c_skip = 1 - c_out``, so the two
        always sum to one. ``c_in = 1 / sqrt(sigma^2 + sigma_data^2)``.

        Args:
            sigma: Noise level(s); a number or a tensor (the arithmetic is
                element-wise).
        """
        # An alternative pair of formulas is intentionally left disabled:
        #   c_skip = sigma_data^2 / (sigma^2 + sigma_data^2)
        #   c_out  = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
        c_out = sigma / self.sigma_max
        c_skip = 1 - c_out
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in

    def get_scalings_for_boundary_condition(self, sigma):
        """Return ``(c_skip, c_out, c_in)`` centred on ``sigma_min``.

        At ``sigma == sigma_min`` this yields ``c_skip == 1`` and
        ``c_out == 0``, so ``denoise`` returns its input unchanged there.

        Args:
            sigma: Noise level(s); a number or a tensor (the arithmetic is
                element-wise).
        """
        shifted = sigma - self.sigma_min
        data_var = self.sigma_data**2
        # Shared by c_out and c_in: 1 / sqrt(sigma^2 + sigma_data^2).
        inv_norm = 1 / (sigma**2 + data_var) ** 0.5

        c_skip = data_var / (shifted**2 + data_var)
        c_out = shifted * self.sigma_data * inv_norm
        c_in = inv_norm
        return c_skip, c_out, c_in

    @th.no_grad()
    def heun_solver(self, samples, t, next_t, dx_fn):
        """Advance ``samples`` from ``t`` to ``next_t`` with one Heun step.

        An Euler predictor is taken first, then the slope is re-evaluated at
        the predicted point and the two slopes are averaged. Runs with
        gradients disabled.

        Args:
            samples: Current state tensor.
            t: Current time tensor; must broadcast against ``samples``.
            next_t: Target time tensor; must broadcast against ``samples``.
            dx_fn: Callable ``dx_fn(x, time)`` returning the derivative.
                It receives the time cast to an integer tensor via
                ``.long()``.

        Returns:
            The state after the step.
        """
        x = samples
        step = next_t - t

        # Predictor: plain Euler step using the slope at the start point.
        slope_start = dx_fn(x, t.long())
        predicted = x + slope_start * step

        # Corrector: average the start slope with the slope at the prediction.
        slope_end = dx_fn(predicted, next_t.long())
        return x + (slope_start + slope_end) * step / 2

    def get_random_times(self, batchsize, num_scales=100):
        """Sample ordered pairs of integer time indices.

        Two independent draws from ``[2, 100]`` (inclusive) are made per
        sample and then sorted so that ``t >= t_next`` element-wise.

        Args:
            batchsize: Number of pairs to draw.
            num_scales: Currently unused; the range is fixed at ``[2, 100]``.
                NOTE(review): the default of 100 matches the fixed upper
                bound, so this looks intended to control it -- confirm
                before wiring it in, since that would change results for
                callers passing another value.

        Returns:
            ``(t, t_next)``, two integer tensors of shape
            ``(batchsize, 1, 1, 1)``.
        """
        shape = (batchsize, 1, 1, 1)
        first = th.randint(2, 101, shape)
        second = th.randint(2, 101, shape)

        # Element-wise sort of each pair; replaces a per-sample Python loop.
        t = th.maximum(first, second)
        t_next = th.minimum(first, second)
        return t, t_next

    def denoise(self, model, x_t, input, sigmas, x0_pred=None):
        """Blend a prediction with ``x_t`` using the preconditioning scalings.

        Computes ``c_out * prediction + c_skip * x_t`` where the scalings come
        from ``get_scalings_for_boundary_condition`` when
        ``self.distillation`` is true and from ``get_scalings`` otherwise.
        ``c_in`` is computed by those helpers but not applied here.

        Args:
            model: Network used when ``x0_pred`` is not given. If it exposes a
                ``module`` attribute (as parallel wrappers do), that inner
                module's ``forward`` is called; otherwise the model's own
                ``forward`` is called.
            x_t: Noisy input tensor.
            input: Unused; kept so existing call sites keep working.
            sigmas: Noise-level tensor. It is squeezed before being passed to
                the network -- presumably shaped ``(batch, 1, 1, 1)`` so that
                it broadcasts against ``x_t``; confirm against callers.
            x0_pred: Optional precomputed prediction; when given, the network
                is not evaluated.

        Returns:
            The blended (denoised) tensor.
        """
        if self.distillation:
            c_skip, c_out, _ = self.get_scalings_for_boundary_condition(sigmas)
        else:
            c_skip, c_out, _ = self.get_scalings(sigmas)

        # Identity check: ``== None`` on a tensor goes through tensor
        # comparison instead of testing for a missing argument.
        if x0_pred is None:
            if sigmas.size(0) != 1:
                flat_sigmas = th.squeeze(sigmas)
            else:
                # squeeze() would collapse a batch of one to a 0-d tensor.
                flat_sigmas = sigmas.view(1)
            # Unwrap parallel wrappers, but also accept a bare model.
            network = getattr(model, "module", model)
            x0_pred = network.forward(x_t, flat_sigmas)

        return c_out * x0_pred + c_skip * x_t
