import numpy as np
import torch

# Scheduler settings. The commented-out `_class_name` suggests they target the
# diffusers PNDMScheduler. Each trailing comment records the default value and
# a short description, so the overridden settings are easy to spot.
scheduler_config = dict(
    # _class_name="PNDMScheduler",
    # _diffusers_version="0.6.0",
    num_train_timesteps=1000,  # default: 1000
    beta_start=0.00085,  # default: 0.0001
    beta_end=0.012,  # default: 0.012
    beta_schedule="scaled_linear",  # default: `linear`. One of `linear`, `scaled_linear`, `squaredcos_cap_v2`.
    # trained_betas=None,  # default: None. An array of betas passed directly, bypassing `beta_start` and `beta_end`.
    clip_sample=False,  # default: True. Clip the predicted sample for numerical stability.
    clip_sample_range=1.0,  # default: 1.0. Maximum magnitude allowed when clipping samples.
    set_alpha_to_one=False,  # default: True. If True, the previous alpha product is fixed to 1; otherwise the alpha value at step 0 is used.
    steps_offset=1,  # default: 0. Timestep offset used with the leading timestep mode.
    prediction_type="epsilon",  # default: `epsilon` (predicts the noise of the diffusion process);
                                # alternatives: `sample` (predicts the sample directly) or `v_prediction`
                                # (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)).
    thresholding=False,  # default: False. Enables "dynamic thresholding"; unsuitable for latent-space models such as Stable Diffusion.
    dynamic_thresholding_ratio=0.995,  # default: 0.995. Ratio used by dynamic thresholding.
    sample_max_value=1.0,  # default: 1.0. Threshold value for dynamic thresholding.
    timestep_spacing="leading",  # default: `leading`. How inference timesteps are spaced; see Table 2 of
                                 # [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891).
    rescale_betas_zero_snr=False,  # default: False. Rescale betas to zero terminal SNR, allowing very bright and very dark samples instead of only medium brightness.
    skip_prk_steps=True,  # default: False. Skip the Runge-Kutta warm-up steps that otherwise precede the PLMS steps.
)


def normalize(data, new_min, new_max):
    """
    Linearly rescale data so its minimum maps to new_min and its maximum to new_max.

    Parameters:
    data (array-like): Input data to be normalized; anything `np.array` accepts.
    new_min (float): New minimum value of the range.
    new_max (float): New maximum value of the range.

    Returns:
    torch.Tensor: Normalized data as float32, with the same shape as `data`.
    If every element of `data` is equal (zero range), every output element is
    `new_min` instead of the NaN that a 0/0 division would produce.

    Raises:
    ValueError: if `data` is empty (raised by `np.min`).
    """
    # Convert data to a numpy array if it's not already
    data = np.array(data)

    data_min = np.min(data)
    data_max = np.max(data)
    data_range = data_max - data_min

    # Normalize to 0-1
    if data_range == 0:
        # Constant input: avoid 0/0 -> NaN; place every value at the bottom of the range.
        data_normalized = np.zeros(data.shape, dtype=np.float64)
    else:
        data_normalized = (data - data_min) / data_range

    # Scale to new range [new_min, new_max]
    data_scaled = data_normalized * (new_max - new_min) + new_min

    return torch.tensor(data_scaled, dtype=torch.float32)

def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its final step has zero signal-to-noise ratio.

    Implements Algorithm 1 of "Common Diffusion Noise Schedules and Sample Steps
    are Flawed" (https://arxiv.org/pdf/2305.08891.pdf). sqrt(alpha_bar) is shifted
    so its last entry becomes 0, then scaled so its first entry is unchanged. The
    result is converted back to per-step betas.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas; the last one is 1 because alpha_bar
        reaches 0 at the final step.
    """
    sqrt_ab = torch.cumprod(1.0 - betas, dim=0).sqrt()
    head = sqrt_ab[0]
    tail = sqrt_ab[-1]

    # Shift so the final entry is exactly zero, then stretch so the first entry is restored.
    sqrt_ab = (sqrt_ab - tail) * (head / (head - tail))

    # Undo the square root, then the cumulative product, to recover per-step alphas.
    ab = sqrt_ab**2
    step_alphas = torch.cat([ab[0:1], ab[1:] / ab[:-1]])
    return 1 - step_alphas

def normalize_tensor(X, new_min, new_max):
    """
    Linearly rescale X so its minimum maps to new_min and its maximum to new_max.

    Parameters:
    X (numpy.ndarray or torch.Tensor): The original tensor. Only `.min()`,
        `.max()` and arithmetic operators are used, so both types work and the
        result has the same type as X.
    new_min (float): The lower bound of the new range.
    new_max (float): The upper bound of the new range.

    Returns:
    numpy.ndarray or torch.Tensor: The normalized tensor. If X is constant
    (max == min), every element is new_min instead of the NaN a 0/0 division
    would produce.
    """
    X_min = X.min()
    X_max = X.max()
    X_range = X_max - X_min
    if X_range == 0:
        # Constant input: X - X_min is all zeros, so this yields new_min everywhere
        # while keeping X's container type.
        return X - X_min + new_min
    X_norm = new_min + (X - X_min) * (new_max - new_min) / X_range
    return X_norm

# Given betas, compute the corresponding alpha_bar (cumulative product of 1 - beta).
def alphas_bar_for_betas(betas):
    """Return alpha_bar[t] = prod(1 - betas[i] for i <= t), taken cumulatively along dim 0."""
    return torch.cumprod(1.0 - betas, dim=0)

def get_betas_from_all_alphabar(alphabar):
    """
    Calculate betas given a sequence of alphabar values for DDPM.

    With 0-based indexing, alphabar[t] = prod(alpha[i] for i in range(0, t + 1)).
    This is inverted as:
        alpha[0] = alphabar[0]
        alpha[t] = alphabar[t] / alphabar[t - 1]   for t >= 1
        beta[t]  = 1 - alpha[t]

    Parameters:
    alphabar (torch.Tensor or array-like): 1-D sequence of cumulative alpha
        products. Non-tensor input is converted with `torch.as_tensor`; integer
        input is promoted to torch's default float dtype so ratios are not truncated.

    Returns:
    torch.Tensor: Beta values, same length as `alphabar` (empty in, empty out).
    """
    alphabar = torch.as_tensor(alphabar)
    if not torch.is_floating_point(alphabar):
        alphabar = alphabar.to(torch.get_default_dtype())
    # Ratio of consecutive cumulative products, vectorised instead of a Python loop.
    alphas = torch.cat([alphabar[:1], alphabar[1:] / alphabar[:-1]])
    return 1 - alphas

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t), t in [0, 1], into per-step betas.

    Step i covers the interval [i / N, (i + 1) / N], and its beta is
    1 - alpha_bar(end) / alpha_bar(start), capped at max_beta.

    :param num_diffusion_timesteps: N, the number of betas to produce.
    :param alpha_bar: a callable mapping t in [0, 1] to the cumulative product
                      of (1 - beta) up to that point of the diffusion process.
    :param max_beta: upper cap for each beta; values below 1 avoid singularities.
    :return: a numpy array of N betas.
    """
    steps = num_diffusion_timesteps
    intervals = [(i / steps, (i + 1) / steps) for i in range(steps)]
    return np.array(
        [min(1 - alpha_bar(end) / alpha_bar(start), max_beta) for start, end in intervals]
    )

def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Return a pre-defined beta schedule, as a numpy array, for the given name.

    Supported names are "linear" and "cosine". Both keep a similar shape as
    num_diffusion_timesteps changes. Schedules may be added, but existing ones
    should not be removed or changed, to preserve backwards compatibility.

    Raises NotImplementedError for any other name.
    """
    if schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2,
        )
    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Ho et al.'s linear schedule; the endpoints are scaled relative to 1000
    # steps so any number of diffusion steps works.
    step_scale = 1000 / num_diffusion_timesteps
    first = step_scale * 0.0001
    last = step_scale * 0.02
    return np.linspace(first, last, num_diffusion_timesteps, dtype=np.float64)

def logistic_alpha_bar(L=1, k=0.015, t0=300, timesteps=1000):
    """
    Build an alpha_bar (cumulative alpha product) curve shaped like a logistic sigmoid.

    The sigmoid L / (1 + exp(-k * (t - t0))) is evaluated at
    t = timesteps, timesteps - 1, ..., 1. Because t descends, the curve decreases
    along the returned index when k > 0. It is then min-max rescaled to
    [1e-9, 1 - 1e-9] with `normalize_tensor`.

    Args:
        L: Plateau level of the sigmoid. Any non-zero L cancels out in the
            min-max rescaling.
        k: Growth rate (steepness) of the sigmoid.
        t0: Midpoint of the sigmoid, in timestep units.
        timesteps: Number of alpha_bar values to produce.

    Returns:
        torch.Tensor with `timesteps` elements.

    NOTE: in float32, 1 - 1e-9 rounds to 1.0, so the first value is exactly 1.0.
    """
    t = torch.arange(timesteps, 0, -1, dtype=torch.get_default_dtype())
    # Use torch.exp rather than np.exp: np.exp on a tensor only comes back as a
    # tensor via torch's __array_wrap__ interop hook, which is version-dependent.
    logistic = L / (1 + torch.exp(-k * (t - t0)))
    return normalize_tensor(logistic, 1e-9, 1.0 - 1e-9)

def geometric_schedule(beta_start, beta_end, num_steps):
    """Create a geometric (constant-ratio) schedule for beta values.

    betas[i] = beta_start * ratio**i, with ratio chosen so that the last value
    equals beta_end (up to float rounding). Returns a float tensor of
    `num_steps` values. For num_steps == 1 the single value is beta_start;
    previously this case raised ZeroDivisionError.
    """
    if num_steps <= 1:
        # No interval to spread over, so there is no ratio to solve for (avoids 1 / 0).
        ratio = 1.0
    else:
        ratio = (beta_end / beta_start) ** (1 / (num_steps - 1))
    betas = beta_start * (ratio ** torch.arange(num_steps))
    return betas

def exponential_schedule(beta_start, beta_end, num_steps):
    """Build betas spaced evenly in log-space from beta_start to beta_end.

    Returns a float32 tensor of `num_steps` values.
    """
    log_first, log_last = np.log(beta_start), np.log(beta_end)
    log_betas = np.linspace(log_first, log_last, num_steps)
    betas_f64 = np.exp(log_betas)
    return torch.tensor(betas_f64, dtype=torch.float32)

def hyperbolic_schedule(beta_start, beta_end, num_steps):
    """Create an S-shaped (tanh) schedule of betas running from beta_start to beta_end.

    tanh is sampled on [-pi/2, pi/2] at `num_steps` evenly spaced points and
    rescaled so the first beta is exactly beta_start and the last is exactly
    beta_end. Returns a float32 tensor.
    """
    x = np.linspace(-np.pi / 2, np.pi / 2, num_steps)
    # tanh(+-pi/2) is only about +-0.917. Dividing by tanh(pi/2) makes the curve
    # span exactly [-1, 1], so after the shift below it covers [0, 1]; without
    # this the schedule never reaches beta_start / beta_end.
    unit = (np.tanh(x) / np.tanh(np.pi / 2) + 1) / 2  # Scaled to range [0, 1]
    betas = beta_start + (beta_end - beta_start) * unit  # Scale to range [start, end]
    return torch.tensor(betas, dtype=torch.float32)

def get_betas(beta_start, beta_end, mode, num_steps=1000, **kwargs):
    """
    Build a 1-D beta schedule of `num_steps` values for the requested `mode`.

    Parameters:
    beta_start (float): First beta. Ignored by 'cosine' and 'logistic'.
    beta_end (float): Last beta. Ignored by 'cosine' and 'logistic'.
    mode (str): One of 'geometric', 'exponential', 'hyperbolic', 'cosine',
        'logistic', 'linear', 'scaled_linear'.
    num_steps (int): Number of diffusion steps.
    **kwargs: For mode='logistic', `k` (growth rate) and `t0` (midpoint) are
        both required; `t0=0` is a valid value.

    Returns:
    torch.Tensor: The betas. 'cosine' yields float64 because it is converted
    from a float64 numpy array; 'linear' and 'scaled_linear' are float32.

    Raises:
    KeyError: mode='logistic' without both `k` and `t0`.
    NotImplementedError: Unknown `mode`.
    """
    if mode == 'geometric':
        betas = geometric_schedule(beta_start, beta_end, num_steps)
    elif mode == 'exponential':
        betas = exponential_schedule(beta_start, beta_end, num_steps)
    elif mode == 'hyperbolic':
        betas = hyperbolic_schedule(beta_start, beta_end, num_steps)
    elif mode == 'cosine':
        betas = get_named_beta_schedule('cosine', num_diffusion_timesteps=num_steps)
        betas = torch.tensor(betas)
    elif mode == 'logistic':
        # Require BOTH parameters, and test for presence rather than truthiness,
        # so a legitimate t0=0 is accepted.
        if kwargs.get('k') is None or kwargs.get('t0') is None:
            raise KeyError('scale `k` and center `t0` are required to initialize sigmoid-like betas')
        alphas_cumprod = logistic_alpha_bar(L=1, k=kwargs['k'], t0=kwargs['t0'], timesteps=num_steps)
        betas = get_betas_from_all_alphabar(alphas_cumprod)
    elif mode == 'scaled_linear':
        # Linear in sqrt(beta), then squared.
        betas = (
            torch.linspace(beta_start**0.5, beta_end**0.5, num_steps, dtype=torch.float32) ** 2
        )
    elif mode == 'linear':
        betas = torch.linspace(beta_start, beta_end, num_steps, dtype=torch.float32)
    else:
        raise NotImplementedError(
            f"Scheduler mode {mode!r} not in "
            "'geometric, exponential, hyperbolic, cosine, logistic, linear, scaled_linear'"
        )
    return betas


