import torch

import argparse
import warnings
import time
import numpy as np

from tabsyn.tabsyn.model import MLPDiffusion, Model, Classifier
from tabsyn.tabsyn.latent_utils import get_input_generate, pipeline_get_input_generate, recover_data, split_num_cat_target
from tabsyn.tabsyn.diffusion_utils import sample

warnings.filterwarnings('ignore')


def cond_sample(args):
    """Sample synthetic rows from a trained latent diffusion model and decode them.

    When ``args['classifier_ckpt_path']`` is set, sampling is guided by a
    classifier towards the requested class labels (one sample per label) and
    the classifier's accuracy on the generated latents is printed. Otherwise
    one sample is drawn per training latent row.

    Args:
        args: dict of options. Keys read here: 'device', 'save_path',
            'classifier_ckpt_path', 'sample_batch_size', 'num_steps' (passed to
            ``sample`` as its 4th positional argument) and 'info'. When
            'classifier_ckpt_path' is not None, also 'num_classes',
            'classifier_scale', and either 'labels' (array-like of class ids)
            or 'label_path' (a ``.npy`` file of class ids). The whole dict is
            also forwarded to ``pipeline_get_input_generate``.

    Returns:
        The table built by ``recover_data`` with columns renamed through
        ``args['info']['idx_name_mapping']``. It is also written to
        ``args['save_path']`` as CSV when that path is not None.

    Raises:
        ValueError: if there are no samples to generate.
    """
    device = args['device']
    save_path = args['save_path']
    classifier_ckpt_path = args['classifier_ckpt_path']
    batch_size = args['sample_batch_size']
    info = args['info']

    train_z, ckpt_path, num_inverse, cat_inverse = pipeline_get_input_generate(args)
    in_dim = train_z.shape[1]
    # Sampled latents are mapped back as x * 2 + mean (see the loop below);
    # move the mean to the device once instead of on every batch.
    mean = train_z.mean(0).to(device)

    denoise_fn = MLPDiffusion(in_dim, 1024).to(device)
    model = Model(denoise_fn=denoise_fn, hid_dim=in_dim).to(device)
    # map_location lets checkpoints saved on a GPU load on the CPU as well.
    model.load_state_dict(torch.load(f'{ckpt_path}/model.pt', map_location=device))
    model.eval()

    classifier = None
    y = None
    if classifier_ckpt_path is not None:
        labels = args['labels'] if 'labels' in args else np.load(args['label_path'])
        # as_tensor accepts numpy arrays, Python lists and tensors alike.
        y = torch.as_tensor(labels).long().to(device)

        # IMPORTANT:
        # max(y) here does not necessarily match the max(y) during training:
        # some classes may be in the training labels but never sampled here.
        # The classifier's output width must therefore come from the caller.
        classifier = Classifier(
            d_in=in_dim,
            d_out=args['num_classes'],
            dim_t=256,
            hidden_sizes=[256, 512, 1024, 2048, 1024, 512, 256],
        ).to(device)
        classifier.load_state_dict(
            torch.load(f'{classifier_ckpt_path}/model.pt', map_location=device)
        )
        classifier.eval()

    '''
        Generating samples    
    '''
    start_time = time.time()

    # One sample per requested label when guided, else one per training row.
    num_samples = y.shape[0] if y is not None else train_z.shape[0]
    if num_samples == 0:
        raise ValueError('No samples to generate (empty labels or training latents).')
    num_batches = (num_samples + batch_size - 1) // batch_size

    syn_data_batches = []
    correct = 0
    for i in range(num_batches):
        print('Batch:', i + 1, '/', num_batches)
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, num_samples)
        current_batch_size = end_idx - start_idx  # smaller on the last batch

        guidance = {}
        if classifier is not None:
            y_batch = y[start_idx:end_idx]
            guidance = dict(
                classifier=classifier,
                y=y_batch,
                classifier_scale=args['classifier_scale'],
            )
        x_next = sample(
            model.denoise_fn_D,
            current_batch_size,
            in_dim,
            args['num_steps'],
            **guidance,
        )

        if classifier is not None:
            # Score the final samples at timestep 0 to report guidance accuracy.
            with torch.no_grad():
                pred = classifier(x_next, timesteps=torch.zeros(x_next.shape[0], device=device))
            # reshape(-1) keeps (N, 1) label arrays from broadcasting to (B, B).
            correct += (pred.argmax(dim=1) == y_batch.reshape(-1)).sum().item()

        x_next = x_next * 2 + mean
        syn_data = x_next.float().cpu().numpy()
        if np.isnan(syn_data).any():
            # The batch is kept so that output rows stay aligned with the labels.
            print(f'Found NaN values in the synthetic data (batch {i + 1}); keeping batch.')
        syn_data_batches.append(syn_data)

    if classifier is not None:
        print(f'acc: {correct / num_samples}')

    # Batches are already sized exactly, so decode each one and stitch the
    # numerical / categorical / target parts back together column-group-wise.
    decoded = [
        split_num_cat_target(batch, info, num_inverse, cat_inverse)
        for batch in syn_data_batches
    ]
    final_syn_num, final_syn_cat, final_syn_target = (
        np.concatenate(parts, axis=0) for parts in zip(*decoded)
    )

    syn_df = recover_data(final_syn_num, final_syn_cat, final_syn_target, info)
    idx_name_mapping = {int(key): value for key, value in info['idx_name_mapping'].items()}
    syn_df.rename(columns=idx_name_mapping, inplace=True)

    if save_path is not None:
        syn_df.to_csv(save_path, index=False)

    end_time = time.time()
    print('Time:', end_time - start_time)

    if save_path is not None:
        print('Saving sampled data to {}'.format(save_path))
    return syn_df


def main(args):
    """Generate synthetic rows from a trained latent diffusion model and save them.

    Args:
        args: argparse-style namespace. Reads ``device``, ``steps`` and
            ``sample_batch_size``. Optionally reads the following:

            - ``save_path``: CSV destination.
            - ``classifier_save_path``: classifier checkpoint directory. It
              enables label-conditioned sampling with labels loaded from
              ``label_path``.
            - ``classifier_scale``: defaults to 0, the value previously
              hard-coded.
            - ``num_classes``: the classifier output width.

            The namespace is also forwarded to ``get_input_generate``.

    Returns:
        The table built by ``recover_data`` with columns renamed through
        ``info['idx_name_mapping']``; also written to ``save_path`` when set.

    Raises:
        ValueError: if there are no samples to generate.
    """
    device = args.device
    steps = args.steps
    batch_size = args.sample_batch_size
    # These options are not registered by the CLI parser in this file, so
    # read them tolerantly.
    save_path = getattr(args, 'save_path', None)
    classifier_ckpt_path = getattr(args, 'classifier_save_path', None)
    classifier_scale = getattr(args, 'classifier_scale', 0)

    train_z, _, _, ckpt_path, info, num_inverse, cat_inverse = get_input_generate(args)
    in_dim = train_z.shape[1]
    # Sampled latents are mapped back as x * 2 + mean (see the loop below);
    # move the mean to the device once instead of on every batch.
    mean = train_z.mean(0).to(device)

    denoise_fn = MLPDiffusion(in_dim, 1024).to(device)
    model = Model(denoise_fn=denoise_fn, hid_dim=in_dim).to(device)
    # map_location lets checkpoints saved on a GPU load on the CPU as well.
    model.load_state_dict(torch.load(f'{ckpt_path}/model.pt', map_location=device))
    model.eval()

    y = None
    if classifier_ckpt_path is not None:
        y = torch.from_numpy(np.load(args.label_path)).long().to(device)

    # One sample per label when a classifier is used (matching cond_sample),
    # otherwise one per training latent row.
    num_samples = y.shape[0] if y is not None else train_z.shape[0]
    if num_samples == 0:
        raise ValueError('No samples to generate (empty labels or training latents).')

    classifier = None
    if y is not None:
        # max(y) + 1 under-counts when some classes seen during classifier
        # training are absent from these labels; prefer an explicit width.
        num_classes = getattr(args, 'num_classes', None)
        if num_classes is None:
            num_classes = int(y.max().item()) + 1
        classifier = Classifier(
            d_in=in_dim,
            d_out=num_classes,
            dim_t=256,
            hidden_sizes=[256, 512, 1024, 2048, 1024, 512, 256],
        ).to(device)
        classifier.load_state_dict(
            torch.load(f'{classifier_ckpt_path}/model.pt', map_location=device)
        )
        classifier.eval()

    '''
        Generating samples    
    '''
    start_time = time.time()

    num_batches = (num_samples + batch_size - 1) // batch_size
    syn_data_batches = []
    correct = 0
    for i in range(num_batches):
        print('Batch:', i + 1, '/', num_batches)
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, num_samples)
        current_batch_size = end_idx - start_idx  # smaller on the last batch

        sample_args = (model.denoise_fn_D, current_batch_size, in_dim)
        if steps is not None:
            # Honour --steps (4th positional argument, as in cond_sample);
            # when unset, sample() keeps its own default.
            sample_args += (steps,)

        if classifier is not None:
            y_batch = y[start_idx:end_idx]
            x_next = sample(
                *sample_args,
                classifier=classifier,
                y=y_batch,
                classifier_scale=classifier_scale,
            )
            # Score the final samples at timestep 0 to report classifier accuracy.
            with torch.no_grad():
                pred = classifier(x_next, timesteps=torch.zeros(x_next.shape[0], device=device))
            # reshape(-1) keeps (N, 1) label arrays from broadcasting to (B, B).
            correct += (pred.argmax(dim=1) == y_batch.reshape(-1)).sum().item()
        else:
            x_next = sample(*sample_args)

        x_next = x_next * 2 + mean
        syn_data_batches.append(x_next.float().cpu().numpy())

    if classifier is not None:
        print(f'acc: {correct / num_samples}')

    # Batches are sized exactly, so the concatenation holds num_samples rows.
    all_syn_data = np.concatenate(syn_data_batches, axis=0)
    syn_num, syn_cat, syn_target = split_num_cat_target(all_syn_data, info, num_inverse, cat_inverse)

    syn_df = recover_data(syn_num, syn_cat, syn_target, info)
    idx_name_mapping = {int(key): value for key, value in info['idx_name_mapping'].items()}
    syn_df.rename(columns=idx_name_mapping, inplace=True)

    if save_path is not None:
        syn_df.to_csv(save_path, index=False)

    end_time = time.time()
    print('Time:', end_time - start_time)

    if save_path is not None:
        print('Saving sampled data to {}'.format(save_path))
    return syn_df

if __name__ == '__main__':
    # Command-line entry point: parse generation options and pick a device.
    # NOTE(review): within the lines visible here, main(args) is never
    # invoked. main() also looks up args.save_path, args.classifier_save_path
    # and args.sample_batch_size (plus args.label_path when a classifier is
    # used), and none of these are registered on this parser. Confirm whether
    # the file continues past this point or whether these options are missing.

    parser = argparse.ArgumentParser(description='Generation')

    parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
    # --epoch is not read directly by main(); presumably it is consumed via
    # get_input_generate(args) -- verify.
    parser.add_argument('--epoch', type=int, default=None, help='Epoch.')
    parser.add_argument('--steps', type=int, default=None, help='Number of function evaluations.')

    args = parser.parse_args()

    # check cuda: --gpu -1, or no CUDA device available, selects the CPU.
    if args.gpu != -1 and torch.cuda.is_available():
        args.device = f'cuda:{args.gpu}'
    else:
        args.device = 'cpu'