from ast import arg
from functools import partial
import sys, os

import argparse
from datetime import datetime
from tkinter.tix import Tree
from turtle import window_height
import torch
from tensorboardX import SummaryWriter
import random
import numpy as np
import signal
import math
import json
import torch.nn as nn
from timeit import default_timer as timer
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import config as cfg
from utils import gradient_utils
from utils.time_statistics import TimeStatistics
from nnutils.geometry import augment_grid
from nnutils.learningrate import adjust_learning_rate, StepLearningRateSchedule, get_learning_rates

from dataset.dataset_deform4d_animal_abitraryflow import MeshDataset as Dataset
from flow.model_arbitrary import FlowArbitrary
from flow.loss_arbitrary import LossFlow
from flow.evaluate_arbitrary import evaluate


def _parse_args():
    """Build this script's command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    # --data and --experiment are joined into filesystem paths in main(); a
    # missing value would otherwise surface later as an opaque TypeError.
    parser.add_argument('--data', action='store', dest='data', required=True, help='Provide a subfolder with training data')
    parser.add_argument('--experiment', action='store', dest='experiment', required=True, help='Provide an experiment name')
    parser.add_argument('--interval', type=int, default=3, help='the interval of pair')
    parser.add_argument('--ngpus', type=int, default=1, help='the number of gpus')
    parser.add_argument('--dataug', action='store_true', help='if use data augmentation')
    parser.add_argument('--aug_type', type=str, default='rotation', help='augmentation type')
    parser.add_argument('--partial_range', type=float, default=0.1, help='partial range')
    parser.add_argument('--num_surf_samples', type=int, default=5000, help='the number of sampled point cloud as input')

    parser.add_argument('--model_cano', type=str, default='', help='')
    parser.add_argument('--model_deform', type=str, default='', help='')
    parser.add_argument('--use_normals', action='store_true', help='if use surface normals')
    parser.add_argument('--no_transformer', action='store_true', help='if not use transformer as encoder in airnet model')
    parser.add_argument('--interp_dec', action='store_true', help='if use interpolate decoder in airnet model')
    parser.add_argument('--global_field', action='store_true', help='if only use global latent code in airnet model')

    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--init_lr', type=float, default=0.001, help='initial learning rate')
    parser.add_argument('--step', type=int, default=100, help='decay steps')
    parser.add_argument('--load_model', action='store_true', help='if load pretrained model')
    return parser.parse_args()


def _unwrap_model(model):
    """Return the wrapped module if *model* is an ``nn.DataParallel``, else *model*."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _strip_data_parallel_prefix(state_dict, prefix="module."):
    """Drop the ``nn.DataParallel`` key prefix from *state_dict* if every key has it.

    Older checkpoints from this script were saved from the DataParallel
    wrapper, so their keys read ``module.<name>`` and do not load into the
    bare model. Dicts whose keys do not all carry the prefix are returned
    unchanged.
    """
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _save_checkpoint(path, model, optimizer, epoch, best_val_loss, iteration_number):
    """Write the training state to *path* with ``torch.save``.

    Weights are taken from the unwrapped model so the file loads into a bare
    model regardless of how many GPUs produced it.
    """
    torch.save(
        {
            'epoch': epoch,
            'iteration': iteration_number,
            'optimizer_state_dict': optimizer.state_dict(),
            'model_state_dict': _unwrap_model(model).state_dict(),
            'best_val_flow_loss': best_val_loss,
        },
        path,
    )


def main():
    """Train ``FlowArbitrary`` and write TensorBoard logs and checkpoints.

    Reads training data from ``cfg.data_deform4d_seq_root_dir/<--data>`` and
    writes everything under ``cfg.experiments_dir/<--experiment>``. Two
    pretrained sub-models (``--model_cano``, ``--model_deform``) are loaded
    and combined. With ``--load_model`` training resumes from the
    experiment's ``checkpoints/lastest_model.pt``.

    Each epoch runs one training pass and one validation pass. Afterwards the
    latest state is saved to ``lastest_model.pt``. An ``EpochXXXX_model.pt``
    snapshot is written whenever the validation flow loss improves.
    """
    torch.set_num_threads(cfg.num_threads)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    args = _parse_args()

    data = args.data                    # dataset subfolder name
    experiment_name = args.experiment
    use_augmentation = args.dataug
    augmentation_type = args.aug_type
    partial_range = args.partial_range
    use_normals = args.use_normals
    inverse = False                     # fixed; the corresponding CLI flag is disabled
    num_surf_samples = args.num_surf_samples

    # Apply the CLI override before reporting it. Previously this message was
    # printed from the config default and only overwritten afterwards.
    cfg.initialize_from_other = args.load_model
    if cfg.initialize_from_other:
        print("Will initialize from provided checkpoint")
    else:
        print("Will train from scratch")
    print()

    cfg.print_hyperparams(data, experiment_name)
    print()

    #####################################################################################
    # Creating tf writers and folders
    #####################################################################################
    data_dir = os.path.join(cfg.data_deform4d_seq_root_dir, data)
    experiment_dir = os.path.join(cfg.experiments_dir, experiment_name)
    log_dir = os.path.join(experiment_dir, "tf_run")

    train_log_dir = os.path.join(log_dir, "train")
    os.makedirs(train_log_dir, exist_ok=True)
    train_writer = SummaryWriter(train_log_dir)

    val_log_dir = os.path.join(log_dir, "val")
    os.makedirs(val_log_dir, exist_ok=True)
    val_writer = SummaryWriter(val_log_dir)

    checkpoints_dir = os.path.join(experiment_dir, "checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)
    # Written every epoch and read back by --load_model. The misspelled file
    # name is kept so existing checkpoints still resume.
    latest_checkpoint_path = os.path.join(checkpoints_dir, "lastest_model.pt")

    # We count the execution time between evaluations.
    time_statistics = TimeStatistics()

    #####################################################################################
    # Create datasets and dataloaders
    #####################################################################################
    train_dataset = Dataset(
        data_dir, cfg.flow_num_point_samples,
        cache_data=False, use_augmentation=use_augmentation,
        augmentation_type=augmentation_type,
        partial_range=partial_range,
        iden_split=cfg.iden_seen_split,
        split=cfg.train_split,
        interval=args.interval,
        use_normals=use_normals,
        inverse=inverse,
        num_surf_samples=num_surf_samples,
    )
    val_dataset = Dataset(
        data_dir, cfg.flow_num_point_samples,
        cache_data=False, use_augmentation=False,
        partial_range=partial_range,
        iden_split=cfg.iden_seen_split,
        split=cfg.val_unseen_motion_split,
        interval=args.interval,
        use_normals=use_normals,
        inverse=inverse,
        num_surf_samples=num_surf_samples,
    )

    # --batch_size is per GPU; the loader batch is split across GPUs by DataParallel.
    cfg.flow_batch_size = args.batch_size * args.ngpus
    print('Real flow batch size', cfg.flow_batch_size)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=cfg.flow_batch_size, shuffle=cfg.shuffle, num_workers=cfg.num_worker_threads, pin_memory=False
    )
    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset, batch_size=cfg.flow_batch_size, shuffle=False, num_workers=cfg.num_worker_threads, pin_memory=False
    )

    print("Num. training samples: {0}".format(len(train_dataset)))
    print()

    if len(train_dataset) < cfg.flow_batch_size:
        print()
        print("Reduce the batch_size, since we only have {} training samples but you indicated a batch_size of {}".format(
            len(train_dataset), cfg.flow_batch_size)
        )
        sys.exit(1)

    #####################################################################################
    # Initializing: model, criterion, optimizer...
    #####################################################################################
    iteration_number = 0

    # NOTE(review): `Flow` is not imported anywhere in this file, so the two
    # constructions below raise NameError as written. Add the import for the
    # `Flow` model class at the top of the file (its module is not visible here).
    model_canonicalize = Flow(use_normals=use_normals, no_input_corr=True,
                              no_transformer=args.no_transformer, interp_dec=args.interp_dec,
                              global_field=args.global_field).cuda()
    pretrained_dict = torch.load(args.model_cano)['model_state_dict']
    model_canonicalize.load_state_dict(pretrained_dict)
    print('load model canonicalize from :', args.model_cano)

    model_deform = Flow(use_normals=use_normals, no_input_corr=False,
                        no_transformer=args.no_transformer, interp_dec=args.interp_dec,
                        global_field=args.global_field).cuda()
    pretrained_dict = torch.load(args.model_deform)['model_state_dict']
    model_deform.load_state_dict(pretrained_dict)
    print('load model deform from :', args.model_deform)
    model = FlowArbitrary(model_canonicalize=model_canonicalize, model_deform=model_deform).cuda()

    # Optionally resume. The checkpoint file is read once and reused for the
    # model weights, the optimizer state, and the bookkeeping values.
    resume_checkpoint = None
    start_epoch = 0
    best_val_loss = float("inf")
    if cfg.initialize_from_other:
        # Same path the epoch loop writes to. The old hard-coded
        # "out/experiments/..." path only worked for one cfg.experiments_dir.
        cfg.saved_model_path = latest_checkpoint_path
        print("Initializing from model: ", cfg.saved_model_path)
        print()

        resume_checkpoint = torch.load(cfg.saved_model_path)
        # Load before DataParallel wrapping; tolerate older checkpoints saved
        # from the wrapper (keys prefixed with "module.").
        model.load_state_dict(_strip_data_parallel_prefix(resume_checkpoint['model_state_dict']))

        # The stored epoch was already trained and evaluated, so continue with
        # the next one instead of repeating it.
        start_epoch = resume_checkpoint['epoch'] + 1
        iteration_number = resume_checkpoint.get('iteration', cfg.saved_model_iteration + 1)
        best_val_loss = resume_checkpoint.get('best_val_flow_loss', best_val_loss)

    # multiple-GPUs
    if args.ngpus > 1:
        model = nn.DataParallel(model)
        print("Using Multiple GPU :", args.ngpus)
    else:
        print("Using Single GPU")

    criterion = LossFlow()

    # Count parameters.
    n_all_model_params = sum(p.numel() for p in model.parameters())
    n_trainable_model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of parameters: {0} / {1}".format(n_trainable_model_params, n_all_model_params))
    print()

    # Set up optimizer with a step-decay schedule (x0.1 every --step epochs).
    interval = args.step
    print('learning rate decay interval :', interval)
    factor = 0.1
    init_lr = args.init_lr
    lr_schedule_network = StepLearningRateSchedule({
        "type": "step",
        "initial": init_lr,
        "interval": interval,
        "factor": factor,
    })
    lr_schedules = [lr_schedule_network]

    optimizer = torch.optim.Adam([
        {
            "params": model.parameters(),
            "lr": lr_schedule_network.get_learning_rate(0),
        }]
    )

    if resume_checkpoint is not None:
        if 'optimizer_state_dict' in resume_checkpoint:
            # Restore Adam's running moments. adjust_learning_rate() is called
            # at the start of every epoch and presumably re-applies the schedule.
            optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])
        del resume_checkpoint  # release the loaded tensors

    train_writer.add_text("hyperparams",
                          "Training data: " + data
                          + ",\nBatch size: " + str(cfg.flow_batch_size)
                          + ",\nLearning rate:" + str(init_lr)
                          + ",\nEpochs: " + str(cfg.epochs))

    complete_cycle_start = timer()

    for epoch in range(start_epoch, cfg.epochs):
        adjust_learning_rate(lr_schedules, optimizer, epoch)
        model.train()

        loss_total_sum = 0.0
        loss_flow_sum = 0.0
        count = 0

        #####################################################################################
        ####################################### Train #######################################
        #####################################################################################
        for batch in train_dataloader:
            # Only the surface samples, flow samples and normals are used below.
            (uniform_samples, near_surface_samples, surface_samples, flow_samples, grid,
             surface_normals, _, rotated2gaps, bbox_lower, bbox_upper, temp_idx, sample_idx) = batch
            surface_normals = train_dataset.unpack(surface_normals).cuda()
            surface_samples = train_dataset.unpack(surface_samples).cuda()
            flow_samples = train_dataset.unpack(flow_samples).cuda()

            train_batch_start = timer()

            # Forward pass.
            train_batch_forward_pass = timer()
            model_inputs = (flow_samples, surface_samples[:, 0:2], surface_samples[:, 2])
            if use_normals:
                model_inputs += (surface_normals,)
            try:
                _, _, query_points1_deformed = model(*model_inputs)
            except RuntimeError as e:
                print("Runtime error!", e)
                print("Exiting...")
                raise  # keep the traceback and a non-zero exit status
            time_statistics.forward_duration += (timer() - train_batch_forward_pass)

            # Loss.
            train_batch_loss_eval = timer()
            loss_total, loss_flow = criterion(flow_samples[:, 1], query_points1_deformed)
            # Read each scalar once: every .item() call synchronizes with the device.
            loss_total_value = loss_total.item()
            loss_flow_value = loss_flow.item()
            loss_total_sum += loss_total_value
            loss_flow_sum += loss_flow_value
            count += 1

            if iteration_number % 10 == 0:
                sys.stdout.write("\r############# Train iteration: {0} (of Epoch {1}) || Loss: {2} || Loss_flow: {3} || Experiment: {4}".format(
                    iteration_number, epoch, loss_total_value, loss_flow_value, experiment_name)
                )
                sys.stdout.flush()

            if cfg.detect_anomaly and not math.isfinite(loss_total_value):
                print("Non-finite loss: {}".format(loss_total_value))
                sys.exit(1)

            time_statistics.loss_eval_duration += (timer() - train_batch_loss_eval)

            # Backprop.
            train_batch_backprop = timer()
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            time_statistics.backward_duration += (timer() - train_batch_backprop)
            time_statistics.train_duration += (timer() - train_batch_start)

            iteration_number = iteration_number + 1

        # Epoch-average training losses. count > 0 because the dataset holds at
        # least one full batch (checked above).
        loss_total_avg = loss_total_sum / float(count)
        loss_flow_avg = loss_flow_sum / float(count)
        train_writer.add_scalar('Loss', loss_total_avg, epoch)
        train_writer.add_scalar('Flow', loss_flow_avg, epoch)

        print()
        print()
        print("Epoch number {0}, Iteration number {1}".format(epoch, iteration_number))
        print("{:<50} {}".format("Current TRAIN Loss   TOTAL", loss_total_avg))
        print("{:<50} {}".format("Current TRAIN Loss   FLOW", loss_flow_avg))
        print()

        #####################################################################################
        #################################### Evaluation #####################################
        #####################################################################################
        model.eval()

        eval_start = timer()

        # Enough batches to cover the whole validation set.
        num_eval_batches = math.ceil(len(val_dataset) / cfg.flow_batch_size)

        eval_losses, _ = evaluate(model, criterion, val_dataloader, num_eval_batches, use_normals=use_normals)

        # Keep a snapshot whenever the validation flow loss improves.
        if eval_losses["flow"] < best_val_loss:
            best_val_loss = eval_losses["flow"]
            output_checkpoint_path = os.path.join(checkpoints_dir, "Epoch{0:04d}_model.pt".format(epoch))
            _save_checkpoint(output_checkpoint_path, model, optimizer, epoch, best_val_loss, iteration_number)

        # Always refresh the resume checkpoint (saved unwrapped so a
        # multi-GPU run can be resumed).
        _save_checkpoint(latest_checkpoint_path, model, optimizer, epoch, best_val_loss, iteration_number)

        val_writer.add_scalar('Loss', eval_losses["total"], epoch)
        val_writer.add_scalar('Flow', eval_losses["flow"], epoch)

        print()
        print()
        print("Epoch number {0}, Iteration number {1}".format(epoch, iteration_number))
        print("{:<50} {}".format("Current EVAL Loss   TOTAL", eval_losses["total"]))
        print("{:<50} {}".format("Current EVAL Loss   FLOW", eval_losses["flow"]))
        print()

        time_statistics.eval_duration = timer() - eval_start

        # IO time is the full cycle time minus all measured processing time.
        time_statistics.io_duration += (timer() - complete_cycle_start - time_statistics.train_duration - time_statistics.eval_duration)

        # Set CUDA_LAUNCH_BLOCKING=1 environmental variable for reliable timings.
        print("Cycle duration (s): {0:3f} (IO: {1:3f}, TRAIN: {2:3f}, EVAL: {3:3f})".format(
            timer() - time_statistics.start_time, time_statistics.io_duration, time_statistics.train_duration, time_statistics.eval_duration
        ))
        print("FORWARD: {0:3f}, LOSS: {1:3f}, BACKWARD: {2:3f}".format(
            time_statistics.forward_duration, time_statistics.loss_eval_duration, time_statistics.backward_duration
        ))

        print()

        time_statistics = TimeStatistics()
        complete_cycle_start = timer()

        sys.stdout.flush()

    train_writer.close()
    val_writer.close()

    print()
    print("I'm done")


if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()