import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tools import builder
from utils import misc, dist_utils
from datasets import build_dataset_from_cfg
from evals.classifier.classifier import PointNetClassifier, pc_norm
from utils.config import *
from utils.logger import *
from utils.AverageMeter import AverageMeter
from torchvision import transforms
from evals.classifier.text_queries import generate_all_queries
from evals.fid_is import compute_statistics, compute_inception_score
from modules.voxelization import voxel_to_point

# Training-time point-cloud augmentation pipeline. The transform list is empty,
# so ``train_transforms`` currently hands its input back unchanged.
# Commented-out candidates (from a ``data_transforms`` module, which this file
# does not import — add the import before enabling any of them):
#   PointcloudScale, PointcloudRotate, PointcloudRotatePerturbation,
#   PointcloudTranslate, PointcloudJitter, PointcloudRandomInputDropout,
#   PointcloudScaleAndTranslate.
train_transforms = transforms.Compose([])


class Acc_Metric:
    """Wrapper around a single accuracy value supporting comparison and serialisation.

    ``acc`` may be a plain number (or scalar tensor) or a mapping holding an
    ``'acc'`` key, e.g. the output of :meth:`state_dict`.
    """

    def __init__(self, acc=0.):
        # isinstance (instead of comparing the type's name to 'dict') also accepts
        # dict subclasses such as OrderedDict / EasyDict, which previously fell
        # through to the else-branch and stored the whole mapping as ``acc``.
        if isinstance(acc, dict):
            self.acc = acc['acc']
        else:
            self.acc = acc

    def better_than(self, other):
        """Return True if this accuracy is strictly greater than ``other.acc``."""
        # bool() keeps the result a plain Python bool even when acc is a tensor scalar.
        return bool(self.acc > other.acc)

    def state_dict(self):
        """Return ``{'acc': acc}``, a form the constructor accepts back."""
        return {'acc': self.acc}


def run_net(args, config, train_writer=None, val_writer=None):
    """Train the voxel generator and optionally validate it after every epoch.

    Builds the train/val dataloaders and the model from ``config``, wraps the
    model in ``nn.DataParallel`` and trains for epochs ``start_epoch`` ..
    ``config.max_epoch`` (inclusive). After each epoch a ``ckpt-last``
    checkpoint is saved; when ``config`` has a ``validate`` section,
    :func:`validate` is run and a ``ckpt-best`` checkpoint is saved whenever
    the classifier accuracy improves.

    Args:
        args: command-line arguments (reads ``log_name``, ``use_gpu``,
            ``local_rank``; also forwarded to the builder / checkpoint helpers
            and to :func:`validate`).
        config: experiment config (reads ``dataset``, ``model``,
            ``with_color``, ``max_epoch`` and optionally ``validate``).
        train_writer: optional summary writer for per-batch / per-epoch
            training scalars; closed at the end.
        val_writer: optional summary writer for validation metrics; closed at
            the end.
    """
    logger = get_logger(args.log_name)
    # build dataset (the val loader is built but not used in this function;
    # validate() builds its own reference dataset)
    (train_sampler, train_dataloader), (_, test_dataloader) = (
        builder.dataset_builder(args, config.dataset.train),
        builder.dataset_builder(args, config.dataset.val),
    )
    # build model
    config.model.with_color = config.with_color
    base_model = builder.model_builder(config.model)
    if args.use_gpu:
        base_model.to(args.local_rank)

    # parameter setting
    start_epoch = 0
    best_metrics = Acc_Metric(0.)
    metrics = Acc_Metric(0.)

    print_log('Using Data parallel ...', logger=logger)
    base_model = nn.DataParallel(base_model).cuda()

    # optimizer & scheduler
    optimizer, scheduler = builder.build_opti_sche(base_model, config)

    # Loop invariant: expected number of points per training sample.
    npoints = config.dataset.train.others.npoints

    # training
    base_model.zero_grad()
    for epoch in range(start_epoch, config.max_epoch + 1):
        base_model.train()

        epoch_start_time = time.time()
        batch_start_time = time.time()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter(['Loss'])

        n_batches = len(train_dataloader)
        for idx, (_taxonomy_ids, _model_ids, data, img, text, _) in enumerate(train_dataloader):
            n_itr = epoch * n_batches + idx
            data_time.update(time.time() - batch_start_time)

            points = data.cuda()
            img = img.cuda()
            assert points.size(1) == npoints
            points = train_transforms(points)

            base_model.zero_grad()
            loss, _voxel, _decoded_voxel = base_model.module.training_voxel_generator(points, img, text)
            loss.backward()
            optimizer.step()

            # Read the scalar once: each .item() call synchronises with the device.
            loss_value = loss.item()
            # NOTE: the meter (console / epoch scalar) records loss * 1000, while the
            # per-batch writer scalar below records the raw loss.
            losses.update([loss_value * 1000])

            if train_writer is not None:
                train_writer.add_scalar('Loss/Batch/Loss', loss_value, n_itr)
                train_writer.add_scalar('Loss/Batch/LR', optimizer.param_groups[0]['lr'], n_itr)

            batch_time.update(time.time() - batch_start_time)
            batch_start_time = time.time()

            if idx % 20 == 0:
                print_log('[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s lr = %.6f' %
                          (epoch, config.max_epoch, idx + 1, n_batches, batch_time.val(), data_time.val(),
                           ['%.4f' % l for l in losses.val()], optimizer.param_groups[0]['lr']), logger=logger)

        scheduler.step(epoch)
        epoch_end_time = time.time()

        if train_writer is not None:
            train_writer.add_scalar('Loss/Epoch/Loss_1', losses.avg(0), epoch)
        print_log('[Training] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %
                  (epoch, epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()],
                   optimizer.param_groups[0]['lr']), logger=logger)

        # Must happen before validate(): it points its inference model at
        # <experiment_path>/ckpt-last.pth.
        builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-last', args, logger=logger)

        if "validate" in config.keys():
            metrics, p_fid, p_is = validate(args, config, with_color=config.with_color)
            # Previously val_writer was accepted (and closed) but never written to.
            if val_writer is not None:
                val_writer.add_scalar('Metric/ACC', metrics.acc, epoch)
                val_writer.add_scalar('Metric/P-FID', p_fid, epoch)
                val_writer.add_scalar('Metric/P-IS', p_is, epoch)
            # Save checkpoints
            if metrics.better_than(best_metrics):
                best_metrics = metrics
                builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-best',
                                        args, logger=logger)
                print_log(
                    "--------------------------------------------------------------------------------------------",
                    logger=logger)
            print_log(f"Acc: {metrics.acc} Best_Acc: {best_metrics.acc} P-FID: {p_fid} P-IS: {p_is}", logger=logger)

    if train_writer is not None:
        train_writer.close()
    if val_writer is not None:
        val_writer.close()


def validate(args, train_config, with_color=False, multiple=10, gt_batch_size=64):
    """Score the latest generator checkpoint with a pretrained PointNet classifier.

    An inference model is built from ``train_config.validate.inference_cfg``.
    It reuses the training ``vqgan_config`` / ``voxel_config``, with
    ``voxel_config.ckpt_path`` set to ``<args.experiment_path>/ckpt-last.pth``.
    The model generates clouds for every text query ``multiple`` times. The
    classifier then yields:

    * accuracy of its predictions against the query labels,
    * P-FID between features of generated clouds and of the classifier's
      validation split,
    * P-IS over its predictions on the generated clouds.

    Args:
        args: needs ``experiment_path``.
        train_config: training config; reads ``validate`` (``npoints``,
            ``inference_cfg``, ``classifier_cfg``) and
            ``model.vqgan_config`` / ``model.voxel_config``. Not mutated.
        with_color: forwarded as ``normal_channel`` to the classifier.
        multiple: number of generation rounds over the full query set.
        gt_batch_size: batch size of the reference-data loader.

    Returns:
        ``(Acc_Metric(acc), p_fid, p_is)``.

    Raises:
        ValueError: if ``multiple`` < 1 (nothing would be generated).
    """
    import copy
    import os

    if multiple < 1:
        raise ValueError(f"multiple must be >= 1, got {multiple}")

    # NOTE(review): unused below; kept in case get_logger initialises the
    # "classifier" logger for other code.
    logger = get_logger("classifier")

    config = train_config.validate
    text_queries, text_labels = generate_all_queries(prefix="a")
    npoints = config.npoints

    inf_config = cfg_from_yaml_file(config.inference_cfg)
    # Deep copies: the previous plain assignment aliased the caller's training
    # config, so setting ckpt_path below silently mutated train_config.
    inf_config.model.vqgan_config = copy.deepcopy(train_config.model.vqgan_config)
    inf_config.model.voxel_config = copy.deepcopy(train_config.model.voxel_config)
    inf_config.model.voxel_config.ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')
    # build the generator under evaluation
    base_model = builder.model_builder(inf_config.model)
    base_model.cuda()
    base_model.eval()

    cls_config = cfg_from_yaml_file(config.classifier_cfg)
    test_dataset = build_dataset_from_cfg(cls_config.dataset.val._base_, cls_config.dataset.val.others)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=gt_batch_size, shuffle=False,
                                                  drop_last=False, num_workers=8)

    classifier = PointNetClassifier(normal_channel=with_color)
    classifier.load_model_from_ckpt(cls_config.ckpt_path)
    classifier.cuda()
    classifier.eval()

    with torch.no_grad():
        # Generated samples: `multiple` passes over the whole query set.
        gen_features, gen_predictions = [], []
        for _ in range(multiple):
            gen_points = base_model.text_condition_generation(text_queries)
            gen_points = pc_norm(misc.fps(gen_points, npoints))
            gen_feature, gen_prediction = classifier.features_and_preds(gen_points)
            gen_features.append(gen_feature)
            gen_predictions.append(gen_prediction)
        gen_features = torch.cat(gen_features, dim=0)
        gen_predictions = torch.cat(gen_predictions, dim=0)
        # Labels repeat in the same order as the concatenated generation rounds;
        # place them on the device of the predictions they are scored against.
        gen_labels = torch.tensor(list(text_labels) * multiple, dtype=torch.long,
                                  device=gen_predictions.device)
        _, acc = classifier.get_loss_acc(gen_predictions, gen_labels)

        # Reference samples from the classifier config's validation split.
        gt_features = []
        for _taxonomy_ids, _model_ids, data, _, _, _ in test_dataloader:
            gt_points = pc_norm(misc.fps(data.cuda(), npoints))
            features, _ = classifier.features_and_preds(gt_points)
            gt_features.append(features)
        gt_features = torch.cat(gt_features, dim=0)

    stats_1 = compute_statistics(gen_features.cpu().numpy())
    stats_2 = compute_statistics(gt_features.cpu().numpy())
    p_fid = stats_1.frechet_distance(stats_2)
    p_is = compute_inception_score(gen_predictions.cpu().numpy())

    return Acc_Metric(acc), p_fid, p_is
