import math
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import wandb
from tqdm import tqdm

def flat_mean(x, start_dim=1):
    """Average ``x`` over every dimension from ``start_dim`` onward.

    With the default ``start_dim=1`` this yields one value per leading (batch)
    entry, e.g. a per-sample mean squared error.

    Args:
        x: Input tensor.
        start_dim: First dimension to reduce. Negative values count from the
            end, as in ``torch.flatten``.

    Returns:
        Tensor of shape ``x.shape[:start_dim]``. If there are no dimensions at
        or after ``start_dim``, ``x`` is returned unchanged.
    """
    if start_dim < 0:
        start_dim += x.ndim
    reduce_dims = list(range(start_dim, x.ndim))
    if not reduce_dims:
        # torch treats an empty ``dim`` list as "reduce over everything", which
        # would collapse e.g. a (B,) input to a scalar. Averaging over zero
        # dimensions is the identity, so return the input as-is.
        return x
    return torch.mean(x, dim=reduce_dims)

def get_indices_of_multi_active_dims(samples, mu, active_threshold=0.1, outlier_distance=5):
    """Find samples with more than one active dimension, and the outliers among them.

    A dimension is *active* for a sample when its absolute value is not below
    ``active_threshold`` (NaN entries count as active). Among samples with more
    than one active dimension, a sample is flagged as an outlier when at least
    one of its *active* dimensions lies strictly more than ``outlier_distance``
    from the matching entry of ``mu``. Inactive dimensions are ignored.

    Args:
        samples: 2-D array-like, one sample per row.
        mu: Per-dimension reference values, one per column of ``samples``.
        active_threshold: Magnitude below which a coordinate is inactive.
        outlier_distance: Distance from ``mu`` beyond which an active
            coordinate marks its sample as an outlier.

    Returns:
        ``(multi_dim_active, indices)``: ``multi_dim_active`` is an integer
        array of row indices with more than one active dimension; ``indices``
        is a list of the subset of those rows flagged as outliers, in
        ascending order.
    """
    samples = np.asarray(samples)
    mu = np.asarray(mu)

    # ``~(|x| < t)`` instead of ``|x| >= t`` so NaN entries count as active.
    active = ~(np.abs(samples) < active_threshold)
    multi_dim_active = np.where(np.count_nonzero(active, axis=1) > 1)[0]

    # Only the multi-active rows need a distance check.
    far_from_mu = np.abs(samples[multi_dim_active] - mu) > outlier_distance
    is_outlier = np.any(active[multi_dim_active] & far_from_mu, axis=1)
    indices = multi_dim_active[is_outlier].tolist()
    return multi_dim_active, indices

def plot_histogram(single_active, multi_active, name, args):
    """Overlay two score histograms (log y-axis) and save them as a PNG.

    The figure is written to ``<args.chkpt_dir>/<args.store_name>/<name>.png``
    (the directory is created if missing). When ``args.log_results`` is truthy,
    the saved image is also logged to wandb under the key ``name``.

    Args:
        single_active: Scores for the "In-Support" population.
        multi_active: Scores for the "Hall/Out-of-Support" population.
        name: Base file name (without extension) and wandb key.
        args: Needs ``chkpt_dir``, ``store_name`` and ``log_results``.
    """
    out_dir = os.path.join(args.chkpt_dir, args.store_name)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"{name}.png")

    # Use a dedicated figure instead of pyplot's implicit "current figure" so
    # no stale, still-open figure gets mixed into this plot; always close it.
    fig, ax = plt.subplots()
    try:
        ax.hist(single_active, bins=10, alpha=0.5, label='In-Support')
        ax.hist(multi_active, bins=10, alpha=0.5, label='Hall/Out-of-Support')
        ax.set_yscale('log')
        ax.legend(loc='upper right')
        fig.savefig(out_path)
    finally:
        plt.close(fig)

    if args.log_results:
        wandb.log({f"{name}": wandb.Image(out_path)})


def compute_metrics(gen_dataset, diffusion, model, evaluator, gen, args):
    """Score generated samples and plot single- vs multi-active loss histograms.

    Samples with more than one active dimension (per
    ``get_indices_of_multi_active_dims``) are the "Hall/Out-of-Support" group;
    all others are "In-Support". For both the reconstruction (noise-MSE) and
    BPD losses, one histogram of the mean over timesteps is plotted, then one
    per evaluated timestep.

    Args:
        gen_dataset: NumPy array of generated samples, one per row.
        diffusion, model, gen: Forwarded to ``reconstruction_loss``.
        evaluator: Must expose ``mu``; also forwarded to ``reconstruction_loss``.
        args: Run configuration, forwarded to the helpers.

    Returns:
        ``(recon_loss, bpd_loss)`` exactly as returned by ``reconstruction_loss``.
    """
    multi_active, _outlier_indices = get_indices_of_multi_active_dims(gen_dataset, evaluator.mu)
    recon_loss, bpd_loss = reconstruction_loss(gen_dataset, diffusion, model, evaluator, gen, args)

    # Boolean membership masks over the scored rows. The ``< n_scored`` guard
    # keeps indexing valid should reconstruction_loss return fewer rows than
    # gen_dataset has samples.
    n_scored = len(recon_loss)
    is_multi = np.zeros(n_scored, dtype=bool)
    is_multi[multi_active[multi_active < n_scored]] = True
    is_single = ~is_multi

    # Mean over all evaluated timesteps.
    mean_recon_loss = np.mean(recon_loss, axis=1)
    mean_bpd_loss = np.mean(bpd_loss, axis=1)
    plot_histogram(mean_recon_loss[is_single], mean_recon_loss[is_multi], "recon_loss_mean", args)
    plot_histogram(mean_bpd_loss[is_single], mean_bpd_loss[is_multi], "bpd_loss_mean", args)

    # Per-timestep histograms. These timesteps must match (in order) the ones
    # reconstruction_loss evaluates, since column ``col`` belongs to ``t``.
    for col, t in enumerate(range(50, 500, 100)):
        plot_histogram(recon_loss[is_single, col], recon_loss[is_multi, col], f"t_{t}_recon_loss", args)
        plot_histogram(bpd_loss[is_single, col], bpd_loss[is_multi, col], f"t_{t}_bpd_loss", args)

    return recon_loss, bpd_loss

def reconstruction_loss(gen_dataset, diffusion, model, evaluator, gen, args,
                        batch_size=10000, timesteps=tuple(range(50, 500, 100))):
    """Per-sample denoising MSE and BPD loss at a fixed set of diffusion timesteps.

    For each batch of ``gen_dataset`` and each ``t`` in ``timesteps``, the batch
    is noised with ``diffusion.q_sample`` using fresh Gaussian noise. It is then
    scored two ways:

    * the per-sample mean squared error between that noise and
      ``model(x_t, t)``;
    * ``diffusion._loss_term_bpd``.

    Args:
        gen_dataset: NumPy array of samples, one per row (fed to
            ``torch.from_numpy``).
        diffusion: Object providing ``q_sample`` and ``_loss_term_bpd``.
        model: Network called as ``model(x_t, t)``.
        evaluator, gen: Not used here; kept for interface compatibility.
        args: Must provide ``device``.
        batch_size: Number of samples scored per forward pass.
        timesteps: Diffusion timesteps to evaluate; one output column each.

    Returns:
        ``(recon, bpd)``: arrays of shape ``(len(gen_dataset), len(timesteps))``
        where row ``i`` is sample ``i`` and column ``j`` is ``timesteps[j]``.
    """
    timesteps = list(timesteps)
    recon_batches = []
    bpd_batches = []
    n_samples = len(gen_dataset)
    # Pure evaluation: skip building autograd graphs.
    with torch.no_grad():
        # Stepping by batch_size (rather than len // batch_size batches) also
        # scores the final partial batch.
        for start in tqdm(range(0, n_samples, batch_size)):
            x_0 = torch.from_numpy(gen_dataset[start:start + batch_size]).to(args.device)
            recon_cols = []
            bpd_cols = []
            for t in timesteps:
                noise = torch.randn_like(x_0)
                time_step = torch.tensor([t]).to(args.device)
                x_t = diffusion.q_sample(x_0, time_step, noise=noise)
                bpd_loss = diffusion._loss_term_bpd(model, x_0, x_t, t=time_step, clip_denoised=False, return_pred=False)
                model_out = model(x_t, time_step)
                recon_loss = flat_mean((noise - model_out).pow(2))
                recon_cols.append(recon_loss.cpu().numpy().reshape(-1))
                # Assumes _loss_term_bpd returns one value per sample -- TODO confirm.
                bpd_cols.append(bpd_loss.cpu().numpy().reshape(-1))
            # Timesteps become columns: (batch, len(timesteps)). A plain
            # reshape of the (timesteps, batch) layout would scramble them.
            recon_batches.append(np.stack(recon_cols, axis=1))
            bpd_batches.append(np.stack(bpd_cols, axis=1))

    if recon_batches:
        all_recon_list = np.concatenate(recon_batches, axis=0)
        all_bpd_list = np.concatenate(bpd_batches, axis=0)
    else:
        all_recon_list = np.empty((0, len(timesteps)))
        all_bpd_list = np.empty((0, len(timesteps)))
    print(all_recon_list.shape, all_bpd_list.shape)
    return all_recon_list, all_bpd_list

def filter_variance(x_gen, all_predx0):
    """Compute a variance statistic of ``all_predx0`` and argsort it.

    Presumably meant to rank samples by the variance of their predicted x0
    (TODO confirm intent).

    NOTE(review): in the lines visible here, ``x_gen`` is never read, and
    ``variance`` is collapsed to a single scalar before the argsort. The
    ordering therefore cannot rank individual samples; confirm the intended
    reduction before relying on ``sorted_indices``.
    """
    start_timestep = 50
    # end_timestep = 100
    # NOTE(review): ``-start_timestep`` is a single integer index, so this picks
    # one entry along axis 1 (the 50th from the end), not a range. A slice such
    # as ``[:, -start_timestep:]`` may have been intended -- confirm.
    # ``np.var(..., axis=1)`` on that selection requires all_predx0 to be at
    # least 3-D, and ``np.mean`` with no axis then reduces everything to a scalar.
    variance = np.mean(np.var(all_predx0[:, -start_timestep], axis=1))
    # Sort the variance
    # NOTE(review): ``variance`` is a scalar at this point, so there is nothing
    # meaningful to sort.
    sorted_indices = np.argsort(variance)
