from tabsyn.tabsyn.model import Classifier, MLPDiffusion, Model
import os
import argparse
import warnings
import torch
import time
import numpy as np
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tabsyn.tabsyn.latent_utils import get_input_generate, get_input_train, pipeline_get_input_train, recover_data, split_num_cat_target
from tabsyn.tabsyn.diffusion_utils import sample, sample_step
from torch.utils.data import DataLoader, TensorDataset, random_split
from tabsyn.tabsyn import logger

warnings.filterwarnings('ignore')

def compute_top_k(logits, labels, k, reduction="mean"):
    """Top-``k`` hit indicator / accuracy for class logits.

    Args:
        logits: class scores; the top-k is taken over the last dimension
            (assumed shape (N, C) -- confirm for other callers).
        labels: integer class indices, one per row of ``logits`` (shape (N,)).
        k: number of highest-scoring classes that count as a hit.
        reduction: ``"mean"`` returns the top-k accuracy as a Python float;
            ``"none"`` returns a float tensor of shape (N,) holding 1.0 where
            the label is among the top-k predictions and 0.0 otherwise.

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"none"``.
    """
    _, top_ks = torch.topk(logits, k, dim=-1)
    # Top-k indices within a row are distinct, so each row sums to 0 or 1.
    hits = (top_ks == labels[:, None]).float().sum(dim=-1)
    if reduction == "mean":
        return hits.mean().item()
    if reduction == "none":
        return hits
    raise ValueError(f"unknown reduction {reduction!r}; expected 'mean' or 'none'")


def train_classifier(args):
    """Train (or load) a noise-conditional classifier on TabSyn latents.

    Each training batch is corrupted with Gaussian noise whose per-sample
    scale ``sigma`` is drawn log-normally using the pretrained diffusion
    model's ``loss_fn.P_mean`` / ``loss_fn.P_std``. The classifier is then
    trained with cross-entropy to predict ``labels`` from ``(noisy_x, sigma)``.

    Args:
        args: dict-like config. Keys read here: ``'device'``,
            ``'classifier_save_path'``, ``'label_path'``, ``'batch_size'``,
            ``'classifier_train_split_ratio'``, ``'classifier_epochs'`` and
            ``'read_ckpt'``. The dict is also forwarded to
            ``pipeline_get_input_train``.

    Returns:
        int: number of classes, ``max(labels) + 1``.

    Side effects:
        Writes ``model.pt`` (best epoch-average training loss so far) and
        ``model_{epoch}.pt`` every 100 epochs into ``classifier_save_path``.
        When ``read_ckpt`` is set and ``model.pt`` loads successfully,
        training is skipped.
    """
    device = args['device']
    train_z, ckpt_path = pipeline_get_input_train(args)
    classifier_ckpt_path = args['classifier_save_path']
    os.makedirs(classifier_ckpt_path, exist_ok=True)

    labels = np.load(args['label_path'])

    start_time = time.time()

    in_dim = train_z.shape[1]

    # NOTE(review): latents are centered and divided by 2 (not by std);
    # presumably this mirrors the diffusion model's input normalization --
    # confirm before changing.
    mean = train_z.mean(0)
    train_z = (train_z - mean) / 2
    train_data = TensorDataset(
        train_z,
        torch.from_numpy(labels).long().reshape(-1, 1)
    )

    batch_size = args['batch_size']
    split_ratio = args['classifier_train_split_ratio']
    val_loader = None
    if split_ratio == 1:
        train_dataset = train_data
    else:
        train_size = int(split_ratio * len(train_data))
        val_size = len(train_data) - train_size
        train_dataset, val_dataset = random_split(train_data, [train_size, val_size])
        # An empty validation split would divide by zero when computing accuracy.
        if val_size > 0:
            val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=4  # Adjust as per your environment
            )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4  # Adjust as per your environment
    )

    num_epochs = args['classifier_epochs']

    denoise_fn = MLPDiffusion(in_dim, 1024).to(device)
    print(denoise_fn)

    num_params = sum(p.numel() for p in denoise_fn.parameters())
    print("the number of parameters", num_params)

    # The pretrained diffusion model is used below only for its noise
    # distribution parameters (loss_fn.P_mean / loss_fn.P_std).
    model = Model(denoise_fn=denoise_fn, hid_dim=in_dim).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on CPU (and vice versa).
    model.load_state_dict(torch.load(f'{ckpt_path}/model.pt', map_location=device))
    model.eval()

    hidden_sizes = [256, 512, 1024, 2048, 1024, 512, 256]
    num_classes = int(np.max(labels)) + 1
    classifier = Classifier(
        d_in=in_dim,
        d_out=num_classes,
        dim_t=256,
        hidden_sizes=hidden_sizes
    ).to(device)

    if args['read_ckpt']:
        classifier_path = f'{classifier_ckpt_path}/model.pt'
        if os.path.exists(classifier_path):
            try:
                classifier.load_state_dict(torch.load(classifier_path, map_location=device))
            except Exception as err:  # corrupt file / architecture mismatch -> retrain
                print('Model loading failed, train model:', err)
            else:
                print('Model loaded from', classifier_path)
                return num_classes

    classifier_optimizer = torch.optim.AdamW(classifier.parameters(), lr=0.0001)
    scheduler = ReduceLROnPlateau(classifier_optimizer, mode='min', factor=0.9, patience=20, verbose=True)

    classifier.train()
    best_loss = float('inf')
    patience = 0
    for epoch in tqdm(range(num_epochs)):
        batch_loss = 0.0
        len_input = 0
        pbar = tqdm(train_loader, total=len(train_loader))
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")
        for batch, batch_labels in pbar:
            batch = batch.to(device)
            # view(-1) instead of squeeze(): a final batch of size 1 must stay 1-D
            # so cross_entropy receives a (B,) target.
            batch_labels = batch_labels.long().view(-1).to(device)

            # Per-sample log-normal noise level.
            rnd_normal = torch.randn(batch.shape[0], device=device)
            sigma = (rnd_normal * model.loss_fn.P_std + model.loss_fn.P_mean).exp()
            noisy_x = batch + torch.randn_like(batch) * sigma.unsqueeze(1)

            logits = classifier(noisy_x, sigma)

            classifier_optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(logits, batch_labels, reduction="mean")
            batch_loss += loss.item() * batch.shape[0]
            len_input += batch.shape[0]
            loss.backward()
            classifier_optimizer.step()

            pbar.set_postfix({"Loss": loss.item()})

        # Epoch-average training loss drives LR scheduling, checkpointing and early stopping.
        curr_loss = batch_loss / len_input
        scheduler.step(curr_loss)

        if curr_loss < best_loss:
            best_loss = curr_loss
            patience = 0
            torch.save(classifier.state_dict(), f'{classifier_ckpt_path}/model.pt')
        else:
            patience += 1
            if patience == 500:
                print('Early stopping')
                break

        if (epoch + 1) % 100 == 0:
            torch.save(classifier.state_dict(), f'{classifier_ckpt_path}/model_{epoch}.pt')

        if val_loader is not None and (epoch + 1) % 20 == 0:
            # Accuracy on un-noised validation latents (sigma / timesteps = 0).
            classifier.eval()
            correct = 0
            with torch.no_grad():
                for test_x, test_y in val_loader:
                    test_x = test_x.to(device)
                    test_y = test_y.long().view(-1).to(device)
                    pred = classifier(test_x, timesteps=torch.zeros(test_x.shape[0], device=device))
                    correct += (pred.argmax(dim=1) == test_y).sum().item()
            classifier.train()

            acc = correct / len(val_loader.dataset)
            print(acc)

    end_time = time.time()
    print('Time: ', end_time - start_time)
    return num_classes


def main(args):
    """Train a noise-conditional classifier on TabSyn latents (CLI entry point).

    Each training batch is corrupted with Gaussian noise whose per-sample
    scale ``sigma`` is drawn log-normally using the pretrained diffusion
    model's ``loss_fn.P_mean`` / ``loss_fn.P_std``. The classifier is trained
    with cross-entropy on ``(noisy_x, sigma)``. A fixed 80/20 train/validation
    split is used, and validation accuracy on un-noised latents is printed
    every 20 epochs.

    Args:
        args: namespace. Attributes read here: ``device``,
            ``classifier_save_path``, ``label_path``, ``batch_size`` and
            ``classifier_epochs``. The namespace is also forwarded to
            ``get_input_train``.

    Side effects:
        Writes ``model.pt`` (best epoch-average training loss so far) and
        ``model_{epoch}.pt`` every 100 epochs into ``classifier_save_path``.
    """
    device = args.device
    train_z, _, _, ckpt_path, _ = get_input_train(args)
    classifier_ckpt_path = args.classifier_save_path
    os.makedirs(classifier_ckpt_path, exist_ok=True)

    labels = np.load(args.label_path)

    start_time = time.time()

    in_dim = train_z.shape[1]

    # NOTE(review): latents are centered and divided by 2 (not by std);
    # presumably this mirrors the diffusion model's input normalization --
    # confirm before changing.
    mean = train_z.mean(0)
    train_z = (train_z - mean) / 2
    train_data = TensorDataset(
        train_z,
        torch.from_numpy(labels).long().reshape(-1, 1)
    )

    train_size = int(0.8 * len(train_data))
    val_size = len(train_data) - train_size
    train_dataset, val_dataset = random_split(train_data, [train_size, val_size])

    batch_size = args.batch_size
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4  # Adjust as per your environment
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4  # Adjust as per your environment
    )

    num_epochs = args.classifier_epochs

    denoise_fn = MLPDiffusion(in_dim, 1024).to(device)
    print(denoise_fn)

    num_params = sum(p.numel() for p in denoise_fn.parameters())
    print("the number of parameters", num_params)

    # The pretrained diffusion model is used below only for its noise
    # distribution parameters (loss_fn.P_mean / loss_fn.P_std).
    model = Model(denoise_fn=denoise_fn, hid_dim=in_dim).to(device)
    # map_location lets a checkpoint saved on GPU be loaded on CPU (and vice versa).
    model.load_state_dict(torch.load(f'{ckpt_path}/model.pt', map_location=device))
    model.eval()

    hidden_sizes = [256, 512, 1024, 2048, 1024, 512, 256]
    classifier = Classifier(
        d_in=in_dim,
        d_out=int(np.max(labels)) + 1,
        dim_t=256,
        hidden_sizes=hidden_sizes
    ).to(device)

    classifier_optimizer = torch.optim.AdamW(classifier.parameters(), lr=0.0001)
    scheduler = ReduceLROnPlateau(classifier_optimizer, mode='min', factor=0.9, patience=20, verbose=True)

    classifier.train()
    best_loss = float('inf')
    patience = 0
    for epoch in tqdm(range(num_epochs)):
        batch_loss = 0.0
        len_input = 0
        pbar = tqdm(train_loader, total=len(train_loader))
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")
        for batch, batch_labels in pbar:
            batch = batch.to(device)
            # view(-1) instead of squeeze(): a final batch of size 1 must stay 1-D
            # so cross_entropy receives a (B,) target.
            batch_labels = batch_labels.long().view(-1).to(device)

            # Per-sample log-normal noise level.
            rnd_normal = torch.randn(batch.shape[0], device=device)
            sigma = (rnd_normal * model.loss_fn.P_std + model.loss_fn.P_mean).exp()
            noisy_x = batch + torch.randn_like(batch) * sigma.unsqueeze(1)

            logits = classifier(noisy_x, sigma)

            classifier_optimizer.zero_grad()
            loss = torch.nn.functional.cross_entropy(logits, batch_labels, reduction="mean")
            batch_loss += loss.item() * batch.shape[0]
            len_input += batch.shape[0]
            loss.backward()
            classifier_optimizer.step()

            pbar.set_postfix({"Loss": loss.item()})

        # Epoch-level bookkeeping. These steps previously ran inside the batch
        # loop: the scheduler stepped and model.pt was rewritten every batch,
        # and the early-stopping `break` only left the batch loop.
        curr_loss = batch_loss / len_input
        scheduler.step(curr_loss)

        if curr_loss < best_loss:
            best_loss = curr_loss
            patience = 0
            torch.save(classifier.state_dict(), f'{classifier_ckpt_path}/model.pt')
        else:
            patience += 1
            if patience == 500:
                print('Early stopping')
                break

        if (epoch + 1) % 100 == 0:
            torch.save(classifier.state_dict(), f'{classifier_ckpt_path}/model_{epoch}.pt')

        if val_size > 0 and (epoch + 1) % 20 == 0:
            # Accuracy on un-noised validation latents (sigma / timesteps = 0).
            classifier.eval()
            correct = 0
            with torch.no_grad():
                for test_x, test_y in val_loader:
                    test_x = test_x.to(device)
                    test_y = test_y.long().view(-1).to(device)
                    pred = classifier(test_x, timesteps=torch.zeros(test_x.shape[0], device=device))
                    correct += (pred.argmax(dim=1) == test_y).sum().item()
            classifier.train()

            acc = correct / len(val_loader.dataset)
            print(acc)

    end_time = time.time()
    print('Time: ', end_time - start_time)


if __name__ == '__main__':

    # Command-line options for this script.
    parser = argparse.ArgumentParser(description='Training of TabSyn')
    parser.add_argument('--dataname', type=str, default='adult', help='Name of dataset.')
    parser.add_argument('--gpu', type=int, default=0, help='GPU index.')
    args = parser.parse_args()

    # Use the requested GPU only when it is not disabled (--gpu -1) and CUDA
    # is available; otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available() and args.gpu != -1
    args.device = f'cuda:{args.gpu}' if use_cuda else 'cpu'