import os
import argparse
import pickle as pkl
import time

# third party
import numpy as np
import nibabel as nib
import torch
from tqdm import tqdm
from torch.nn import functional as F
# import voxelmorph with pytorch backend
os.environ['NEURITE_BACKEND'] = 'pytorch'
os.environ['VXM_BACKEND'] = 'pytorch'
import voxelmorph as vxm   # nopep8
import evalutils
from kleindataloader import KleinDatasets

# parse commandline args
parser = argparse.ArgumentParser()
# The checkpoint is read with torch.load() and fed to load_state_dict() below,
# so it must be a PyTorch state dict. The path is also inspected for the
# substrings "vxm" / "dice" to choose the U-Net feature widths.
parser.add_argument('--model', required=True,
                    help='pytorch state-dict checkpoint of the nonlinear registration model')
parser.add_argument('--warp', help='output warp deformation filename')
parser.add_argument('--dry_run', action='store_true', help='dry run')
# '-1' (or an empty string) selects the CPU; any other value selects CUDA.
parser.add_argument('--gpu', type=str, default="0",
                    help='GPU selector: -1 runs on the CPU, any other value uses CUDA (default: 0)')
parser.add_argument('--multichannel', action='store_true',
                    help='specify that data has multiple channels')
args = parser.parse_args()

# device handling
# An empty --gpu value or the literal '-1' means "run on the CPU"; every other
# value selects CUDA.
# NOTE(review): the numeric id passed via --gpu is not used to pick a specific
# card here -- 'cuda' resolves to PyTorch's current default CUDA device.
# Confirm whether callers rely on CUDA_VISIBLE_DEVICES for GPU selection.
device = 'cpu' if (not args.gpu or args.gpu == '-1') else 'cuda'

# load model
# model = vxm.networks.VxmDense.load(args.model, input_model=None)
# weights = model.get_weights()
# import importlib
# import torch
# os.environ['NEURITE_BACKEND'] = 'pytorch'
# os.environ['VXM_BACKEND'] = 'pytorch'
# importlib.reload(vxm)

# Choose the U-Net [encoder, decoder] feature widths from the checkpoint path:
#   - path contains "vxm"  -> the 16/32-channel layout
#   - path contains "dice" -> 256 channels everywhere
#   - anything else        -> 64 channels everywhere
if "vxm" in args.model:
    feat = [[16, 32, 32, 32], [32, 32, 32, 32, 32, 16, 16]]
elif 'dice' in args.model:
    feat = [[256] * 4, [256] * 6]
else:
    feat = [[64] * 4, [64] * 6]

# build pytorch model
# The network is constructed once with a fixed input shape and the feature
# widths selected above, then populated from the checkpoint given by --model.
reg_args = dict(inshape=(160, 192, 224),
                int_steps=5,
                int_downsize=2,
                unet_half_res=True,
                nb_unet_features=feat)
model = vxm.networks.VxmDense(**reg_args)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved from a GPU can still be loaded when running on CPU.
model.load_state_dict(torch.load(args.model, map_location=device))
model.to(device)
# inference only: switch layers such as dropout/batch-norm to eval behaviour
model.eval()
# i = 0
# i_max = len(list(model.named_parameters()))
# torchparam = model.state_dict()
# for k, v in torchparam.items():
#     if i < i_max:
#         print("{:20s} {}".format(k, v.shape))
#         if k.split('.')[-1] == 'weight':
#             # torchparam[k] = torch.tensor(weights[i].T)
#             # torchparam[k] = torch.tensor(torch.movedim(weights[i], (-1, -2), (0, 1)))
#             # print(weights[i].shape)
#             # torchparam[k] = torch.tensor(np.movedim(weights[i], (-1, -2), (0, 1)))
#             # torchparam[k] = torch.movedim(torch.movedim(torch.tensor(weights[i]), -1, -2), 0, 1)
#             torchparam[k] = torch.permute(torch.tensor(weights[i]), (4, 3, 0, 1, 2))
#         else:
#             torchparam[k] = torch.tensor(weights[i])
#         i += 1
# model.load_state_dict(torchparam)
# model.to(device)
# model.eval()
# torch.save(model.state_dict(), args.model.replace(".h5", ".pth"))
# exit()

def _seg_to_one_hot(seg, maxlabel):
    """Turn an integer label volume into float one-hot channels.

    The one-hot axis is created with ``maxlabel + 1`` classes, class 0 is
    dropped, and the remaining class axis is moved from the last position to
    dimension 1 (the permute assumes the one-hot result is 5-D).
    """
    onehot = F.one_hot(seg, num_classes=maxlabel + 1)[..., 1:]
    return onehot.permute(0, 4, 1, 2, 3).float()


@torch.no_grad()
def main(dataset, isotropic, crop, savepath):
    """Evaluate the module-level registration ``model`` on one dataset setup.

    For every item yielded by ``KleinDatasets`` the moving image is registered
    to the fixed image, the moving segmentation is warped with the predicted
    deformation, and ``evalutils.compute_metrics`` is run on the result. The
    per-pair metrics are collected in a dict keyed by ``(fid, mid)`` (taken
    from ``pair_ids``) and pickled to ``savepath``.

    Args:
        dataset: dataset name, forwarded to ``KleinDatasets``.
        isotropic: forwarded to ``KleinDatasets``.
        crop: forwarded to ``KleinDatasets``.
        savepath: output pickle path. If it already exists, a message is
            printed and the function returns without doing any work.

    Returns:
        None. Results are written to ``savepath``.

    Relies on the module-level ``args``, ``device`` and ``model``.
    """
    if os.path.exists(savepath):
        print("Save path already exists")
        return
    print("Loading dataset...", dataset, isotropic, crop)
    pairs = KleinDatasets(dataset=dataset, isotropic=isotropic, crop=crop, dry_run=args.dry_run)
    results_dict = {}
    # run results
    for i, batch in tqdm(enumerate(pairs), total=len(pairs)):
        # each tensor is [1, H, W, D] according to the original author's note
        moving_img, fixed_img, moving_seg, fixed_seg = batch
        moving_img, fixed_img, moving_seg, fixed_seg = (
            t.to(device) for t in (moving_img, fixed_img, moving_seg, fixed_seg)
        )
        # add a leading dimension to the images before feeding the network
        moving_img, fixed_img = moving_img[None], fixed_img[None]

        # Both segmentations must share the same number of channels, so the
        # class count comes from the largest label found in either volume.
        maxlabel = int(max(moving_seg.max(), fixed_seg.max()))
        moving_seg = _seg_to_one_hot(moving_seg, maxlabel)
        fixed_seg = _seg_to_one_hot(fixed_seg, maxlabel)

        # predict the deformation and apply it to the moving label channels
        _, warp = model(moving_img, fixed_img, registration=True)
        moved_seg = model.transformer(moving_seg, warp)
        # threshold the warped channels at 0.5 to get binary masks again
        moved_seg = (moved_seg >= 0.5).float()

        fid, mid = pairs.pair_ids[i]
        ret = evalutils.compute_metrics(moved_seg, fixed_seg, warp, onlydice=True, labelmax=maxlabel)
        results_dict[(fid, mid)] = ret
        print({k: (np.mean(v), np.array(v).shape) for k, v in ret.items()})

    # save results
    print(f"Saving results to {savepath}.")
    # Write to a temporary file first and rename it into place: savepath then
    # only ever exists as a complete pickle, so an interrupted or failed dump
    # cannot be mistaken for a finished run by the exists-check above.
    tmp_path = f"{savepath}.tmp"
    with open(tmp_path, 'wb') as fi:
        pkl.dump(results_dict, fi)
    os.replace(tmp_path, savepath)


if __name__ == "__main__":
    import traceback

    # Every run writes into the same "resultsklein" directory regardless of
    # which checkpoint was passed via --model, and main() skips any save path
    # that already exists. Move or rename the directory before evaluating a
    # different model, otherwise the earlier results are kept as-is.
    os.makedirs("resultsklein", exist_ok=True)
    for dataset in ['IBSR18', 'CUMC12', 'MGH10', 'LPBA40']:
        for isotropic in [True, False]:
            for crop in [True, False]:
                isostr = "isotropic" if isotropic else "anisotropic"
                cropstr = "crop" if crop else "nocrop"
                savepath = f"resultsklein/{dataset}_{isostr}_{cropstr}.pkl"
                # Best effort: one failing configuration must not stop the
                # remaining ones, but report the full traceback so the cause
                # can be diagnosed afterwards.
                try:
                    main(dataset, isotropic=isotropic, crop=crop, savepath=savepath)
                except Exception as e:
                    print("Failed", dataset, isotropic, crop, e)
                    traceback.print_exc()
