import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
from utils import *
from models.model import *

# Command-line interface for the training script. Parsing happens at import
# time, so `args` is available to everything defined below.
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, choices=['mnist', 'cifar10'], default='mnist',
                    help='dataset to train and evaluate on')
parser.add_argument('--downsampling-method', type=str, default='conv', choices=['conv', 'res'],
                    help='downsampling method forwarded to the model constructor')
parser.add_argument('--nepochs', type=int, default=160,
                    help='number of training epochs')
# NOTE(review): type=eval executes arbitrary text taken from the command line.
# It is kept so existing invocations (`--data_aug True` / `--data_aug False`)
# behave exactly as before; consider replacing it with a dedicated bool parser.
parser.add_argument('--data_aug', type=eval, default=True, choices=[True, False],
                    help='data-augmentation switch forwarded to the data loader factory (True or False)')
parser.add_argument('--lr', type=float, default=0.1,
                    help='base learning rate for the optimizer and the decay schedule')
parser.add_argument('--batch_size', type=int, default=128,
                    help='training batch size')
parser.add_argument('--test_batch_size', type=int, default=1000,
                    help='evaluation batch size')
parser.add_argument('--optim', type=str, default='sgd', choices=['adam', 'sgd'],
                    help='optimizer to use')

parser.add_argument('--save', type=str, default='./experiment1',
                    help='directory for the log file and the best model checkpoint')
parser.add_argument('--debug', action='store_true',
                    help='zero out the sampled time t during training')
parser.add_argument('--gpu', type=int, default=0,
                    help='CUDA device index (CPU is used when CUDA is unavailable)')
parser.add_argument('--criterion', type=str, default='mse', choices=['ce', 'mse'],
                    help="training loss; only 'mse' is accepted by the training script")
parser.add_argument('--label_transform', type=str, default='no', choices=[
    'no', 'onehot+zero', 'onehot+zero+spatial', 'onehot+random', 'onehot+random+spatial', 'tile'],
    help='label encoding mode forwarded to label_transform and label_parse')
parser.add_argument('--detach_f_target', action='store_true', help='detach f in target from the computational graph')
parser.add_argument('--detach_f_in', action='store_true', help='detach f for input to model from the computational graph')
parser.add_argument('--num_inference_steps', type=int, default=100,
                    help='number of steps passed to model.inference during evaluation')
args = parser.parse_args()


def accuracy(model, dataset_loader, mode='no', num_dim=2304, num_steps=100):
    """Return the fraction of samples in ``dataset_loader`` predicted correctly.

    Args:
        model: Object exposing ``parameters()`` and
            ``inference(x, num_steps=...)``; the latter must return a tensor.
        dataset_loader: Iterable yielding ``(x, y)`` tensor batches and
            exposing a ``dataset`` attribute whose length is the denominator.
            ``y`` is presumably a vector of integer class indices in
            ``[0, 10)`` (it is one-hot encoded with 10 classes) -- confirm
            against the loader.
        mode: Forwarded to ``label_parse`` as ``mode``.
        num_dim: Forwarded to ``label_parse`` as ``num_dim``.
        num_steps: Forwarded to ``model.inference`` as ``num_steps``.

    Returns:
        Number of correct predictions divided by ``len(dataset_loader.dataset)``.
    """
    # Take the device from the model itself instead of relying on the module
    # global `device`, which is only bound when this file runs as a script.
    # The global remains the fallback for a model without parameters.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else device

    total_correct = 0
    for x, y in dataset_loader:
        x = x.to(eval_device)
        y = one_hot(np.array(y.numpy()), 10)

        # Recover the class index from the one-hot rows.
        target_class = np.argmax(y, axis=1)
        # Predictions are moved to the CPU so they can be compared with NumPy.
        model_pred = model.inference(x, num_steps=num_steps).cpu().detach()
        model_pred = Flatten()(model_pred)
        predicted_class = label_parse(model_pred, mode=mode, num_dim=num_dim).numpy()
        total_correct += np.sum(predicted_class == target_class)
    return total_correct / len(dataset_loader.dataset)


if __name__ == '__main__':
    # Script entry point: build the model, train it with the MSE objective
    # below, evaluate periodically, and keep the best checkpoint in args.save.

    fix_random_seeds(42)

    makedirs(args.save)
    logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
    logger.info(args)

    # `device` is also read as a module global by accuracy().
    device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
    channel_in = 1 if args.data == 'mnist' else 3

    # Only the MSE objective is implemented. Raise explicitly instead of using
    # `assert`, which is stripped under `python -O` and would let
    # `--criterion ce` silently train with MSE.
    if args.criterion != 'mse':
        raise ValueError("Unsupported criterion '{}': only 'mse' is implemented.".format(args.criterion))

    model = OurModel(downsampling_method=args.downsampling_method, detach_f_in=args.detach_f_in, channel_in=channel_in, device=device).to(device)

    logger.info(model)
    logger.info('Number of parameters: {}'.format(count_parameters(model)))

    criterion = nn.MSELoss().to(device)

    if args.data == 'mnist':
        train_loader, test_loader, train_eval_loader = get_mnist_loaders(
            args.data_aug, args.batch_size, args.test_batch_size
        )
    elif args.data == 'cifar10':
        train_loader, test_loader, train_eval_loader = get_cifar_loaders(
            args.data_aug, args.batch_size, args.test_batch_size
        )
    else:
        raise ValueError('Unknown dataset.')

    data_gen = inf_generator(train_loader)
    batches_per_epoch = len(train_loader)

    # Learning-rate schedule as a function of the global iteration index.
    lr_fn = learning_rate_with_decay(
        args.batch_size, batch_denom=128, batches_per_epoch=batches_per_epoch, boundary_epochs=[60, 100, 140],
        decay_rates=[1, 0.1, 0.01, 0.001], base_lr=args.lr
    )

    if args.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    elif args.optim == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    else:
        raise ValueError('Unknown optimizer.')

    best_acc = 0
    batch_time_meter = RunningAverageMeter()
    # NOTE(review): the two NFE meters are never updated in this script; they
    # are kept only because their `.avg` is part of the log line below.
    f_nfe_meter = RunningAverageMeter()
    b_nfe_meter = RunningAverageMeter()
    end = time.time()
    # Size used by label_transform / label_parse; presumably the flattened
    # feature size (an earlier note gave BCHW=(128, 64, 6, 6) before the fc
    # layer) -- confirm against the model.
    num_dim = 64*6*6 if args.data == 'mnist' else 64*7*7
    for itr in range(args.nepochs * batches_per_epoch):

        # Apply the scheduled learning rate for this iteration.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_fn(itr)
        optimizer.zero_grad()
        x, y = next(data_gen)
        x = x.to(device)
        # One uniform random time per sample, reshaped to (-1, 1, 1, 1).
        t = torch.rand_like(y, dtype=torch.float32, requires_grad=False).to(device).reshape(-1, 1, 1, 1)
        if args.debug:
            t = t * 0.
        y = label_transform(y, num_dim=num_dim, mode=args.label_transform).to(device)
        pred = model(t, x)

        f_x = Flatten()(model.downsampling_layers(x))
        if args.detach_f_target:
            f_x = f_x.detach()
        pred = Flatten()(pred)
        # flow matching objective: regress the prediction onto (y - f(x))
        loss = criterion(pred, y-f_x)
        loss.backward()
        optimizer.step()

        batch_time_meter.update(time.time() - end)
        end = time.time()

        # NOTE(review): evaluation runs whenever itr is a multiple of
        # batches_per_epoch, i.e. right after the first step of each epoch.
        # Unless batches_per_epoch is 1, the weights produced by the final
        # iteration are never evaluated.
        if itr % batches_per_epoch == 0:
            with torch.no_grad():
                train_acc = accuracy(model, train_eval_loader, mode=args.label_transform, num_dim=num_dim, num_steps=args.num_inference_steps)
                val_acc = accuracy(model, test_loader, mode=args.label_transform, num_dim=num_dim, num_steps=args.num_inference_steps)
                # Keep only the checkpoint with the best test accuracy so far.
                if val_acc > best_acc:
                    torch.save({'state_dict': model.state_dict(), 'args': args}, os.path.join(args.save, 'model.pth'))
                    best_acc = val_acc
                logger.info(
                    "Epoch {:04d} | Time {:.3f} ({:.3f}) | NFE-F {:.1f} | NFE-B {:.1f} | "
                    "Train Acc {:.4f} | Test Acc {:.4f}".format(
                        itr // batches_per_epoch, batch_time_meter.val, batch_time_meter.avg, f_nfe_meter.avg,
                        b_nfe_meter.avg, train_acc, val_acc
                    )
                )
    print(best_acc)
    logger.info(f"Best Acc: {best_acc}")
