import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import argparse
import numpy as np
import voxelmorph as vxm
import tensorflow as tf
import evalutils
from kleindataloader import KleinDatasets
from tqdm import tqdm
from torch.nn import functional as F
import torch
import os
from os import path as osp
import pickle as pkl
import time
# Command-line interface. Only --model is mandatory; the remaining options
# are optional switches with defaults.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model',
    required=True,
    help='keras model for nonlinear registration',
)
parser.add_argument(
    '--warp',
    help='output warp deformation filename',
)
parser.add_argument(
    '--dry_run',
    action='store_true',
    help='dry run',
)
parser.add_argument(
    '--multichannel',
    action='store_true',
    help='specify that data has multiple channels',
)
parser.add_argument(
    '--gpu_id',
    type=int,
    default=0,
)
args = parser.parse_args()

# Hand the requested GPU index to voxelmorph's TensorFlow device setup; the
# returned device is later used with `tf.device(...)`.
device, nb_devices = vxm.tf.utils.setup_device(args.gpu_id)

# load moving and fixed images
# add_feat_axis = not args.multichannel
# # moving = vxm.py.utils.load_volfile(args.moving, add_batch_axis=True, add_feat_axis=add_feat_axis)
# fixed, fixed_affine = vxm.py.utils.load_volfile(
#     args.fixed, add_batch_axis=True, add_feat_axis=add_feat_axis, ret_affine=True)

def main(dataset, isotropic, crop, savepath):
    """Register every pair of one Klein dataset and pickle the per-pair metrics.

    For each batch yielded by ``KleinDatasets`` the moving image is registered
    to the fixed image with the VoxelMorph model given by ``args.model``, the
    one-hot moving segmentation is warped with the predicted deformation, and
    the result is scored against the fixed segmentation through
    ``evalutils.compute_metrics`` (called with ``onlydice=True``).

    Relies on the module-level ``args`` (for ``model`` and ``dry_run``) and
    ``device`` set up at import time.

    Args:
        dataset: Dataset name forwarded to ``KleinDatasets``.
        isotropic: Forwarded to ``KleinDatasets``.
        crop: Forwarded to ``KleinDatasets``.
        savepath: Destination pickle file. If it already exists the function
            prints a notice and returns without doing any work.

    Returns:
        None. The results dictionary, keyed by the ``(fid, mid)`` tuples taken
        from ``dataset.pair_ids``, is written to ``savepath``.
    """
    if osp.exists(savepath):
        print("Save path already exists")
        return
    print("Loading dataset...", dataset, isotropic, crop)
    dataset = KleinDatasets(dataset=dataset, isotropic=isotropic, crop=crop, dry_run=args.dry_run)
    results_dict = {}
    inshape = dataset.getimgsize()
    print(device)
    config = dict(inshape=inshape, input_model=None)
    with tf.device(device):
        model = vxm.networks.VxmDense.load(args.model, **config)
        # One warping model per channel count. The number of one-hot channels
        # is derived from each pair's own maximum label, so it can differ
        # between pairs; reusing a model built for another channel count
        # would feed it an input of the wrong shape.
        transforms = {}
        for i, batch in tqdm(enumerate(dataset), total=len(dataset)):
            # NOTE(review): presumably here to limit Keras state growth over
            # the run; the models created earlier are still reused afterwards.
            tf.keras.backend.clear_session()
            # Images are expected as [1, H, W, D] (per the original note); a
            # trailing feature axis is appended for the network.
            moving_img, fixed_img, moving_seg, fixed_seg = batch
            moving_img, fixed_img = moving_img.numpy()[..., None], fixed_img.numpy()[..., None]
            # One-hot encode both label maps with a shared class count and
            # drop channel 0 (presumably background - confirm with dataset).
            maxlabel = int(max(moving_seg.max(), fixed_seg.max()))
            moving_seg = F.one_hot(moving_seg, num_classes=maxlabel+1)[..., 1:]
            fixed_seg = F.one_hot(fixed_seg, num_classes=maxlabel+1)[..., 1:]
            nb_feats = moving_seg.shape[-1]
            transform = transforms.get(nb_feats)
            if transform is None:
                transform = vxm.networks.Transform(inshape, nb_feats=nb_feats)
                transforms[nb_feats] = transform
            # Time registration plus segmentation warping for this pair.
            a = time.time()
            warp = model.register(moving_img, fixed_img)
            moved_seg = transform.predict([moving_seg.numpy(), warp])
            b = time.time()
            print(b - a)
            # Re-binarise the interpolated one-hot channels.
            moved_seg = (torch.from_numpy(moved_seg)>=0.5).float()
            fid, mid = dataset.pair_ids[i]
            # Move the channel axis from last to second position for the
            # metric code.
            moved_seg = moved_seg.permute(0, 4, 1, 2, 3)
            fixed_seg = fixed_seg.permute(0, 4, 1, 2, 3)
            ret = evalutils.compute_metrics(moved_seg, fixed_seg, warp, onlydice=True, labelmax=maxlabel)
            results_dict[(fid, mid)] = ret
            print({k: (np.mean(v), np.array(v).shape) for k, v in ret.items()})

        # Write to a temporary file first and rename it into place, so an
        # interrupted dump never leaves a truncated pickle at `savepath` that
        # the existence check above would later mistake for a finished run.
        print(f"Saving results to {savepath}.")
        tmp_path = savepath + ".tmp"
        with open(tmp_path, 'wb') as fi:
            pkl.dump(results_dict, fi)
        os.replace(tmp_path, savepath)


if __name__ == "__main__":
    import traceback

    # Every configuration writes its pickle into this directory.
    results_dir = "resultsklein"
    os.makedirs(results_dir, exist_ok=True)
    for dataset in ['IBSR18', 'CUMC12', 'MGH10', 'LPBA40']:
        for isotropic in [True, False]:
            # Only the cropped variant is evaluated; use [True, False] to
            # sweep the uncropped variant as well.
            for crop in [True]:
                isostr = "isotropic" if isotropic else "anisotropic"
                cropstr = "crop" if crop else "nocrop"
                savepath = f"{results_dir}/{dataset}_{isostr}_{cropstr}.pkl"
                try:
                    main(dataset, isotropic, crop, savepath)
                except Exception as e:
                    # Best-effort sweep: one failing configuration must not
                    # stop the others, but keep the full traceback so the
                    # failure can actually be diagnosed.
                    print("Failed", dataset, isotropic, crop, e)
                    traceback.print_exc()
                else:
                    print("Done", dataset, isotropic, crop)
