#!/usr/bin/env python
import numpy as np
from glob import glob
from itertools import product
import SimpleITK as sitk

# Subject numbers to evaluate: 1 through 18 inclusive.
image_ids = range(1, 19)

# Segmentation label values for which a Dice score is computed.
labels_all = np.array(
    [
        0, 2, 3, 4, 5, 7, 8,
        10, 11, 12, 13, 14, 15, 16, 17, 18,
        26, 28,
        41, 42, 43, 44, 46, 47, 49,
        50, 51, 52, 53, 54, 58,
        60,
    ]
)

# Accumulator: one list of per-label Dice scores per (fixed, moving) pair.
all_dices = []

def dice_score(p, q, empty_value=1.0):
    """Return the Dice coefficient 2*sum(p*q) / (sum(p) + sum(q)).

    Parameters
    ----------
    p, q : array_like
        Masks (boolean or numeric) of broadcast-compatible shapes.
    empty_value : float, optional
        Value returned when ``sum(p) + sum(q)`` is zero, i.e. the structure
        is absent from both inputs. Defaults to 1.0 (two empty masks agree
        perfectly); pass ``float("nan")`` to mark such cases as undefined.

    Returns
    -------
    float
        The Dice coefficient, or ``empty_value`` for a zero denominator.
    """
    p = np.asarray(p)
    q = np.asarray(q)
    den = p.sum() + q.sum()
    if den == 0:
        # Avoid 0/0 -> NaN (plus a RuntimeWarning), which would otherwise
        # propagate into every mean computed from the returned scores.
        return empty_value
    num = (2.0 * p * q).sum()
    return num / den

# Visit every ordered (fixed, moving) pair of subjects.
for fixed_id, moving_id in product(image_ids, image_ids):
    if fixed_id == moving_id:
        # Same subject on both sides: nothing to compare.
        continue
    fixed_path = "../IBSR_{:02d}/IBSR_{:02d}_seg_ana.nii.gz".format(fixed_id, fixed_id)
    moved_path = "deformed_{:02d}_{:02d}_seg_ana.nii.gz".format(fixed_id, moving_id)
    # Read both segmentation volumes and convert them to numpy label arrays.
    fixed_array = sitk.GetArrayFromImage(sitk.ReadImage(fixed_path))
    moved_array = sitk.GetArrayFromImage(sitk.ReadImage(moved_path))
    # One Dice score per label, comparing the binary masks of that label.
    pair_dices = [
        dice_score(fixed_array == lab, moved_array == lab)
        for lab in labels_all
    ]
    # Progress output: mean Dice over all labels for this pair.
    print(np.mean(pair_dices))
    all_dices.append(pair_dices)

# full results
# Stack the per-pair score lists into a 2-D array:
# rows are (fixed, moving) pairs, columns follow the order of labels_all.
all_dices = np.array(all_dices)
# Per-label mean and standard deviation across pairs. Both reduce along
# axis 0 so the two printed vectors line up label-for-label.
print(np.mean(all_dices, 0), np.std(all_dices, 0))
# Reduce each pair to its mean Dice over labels, then summarize across pairs.
all_dices_0 = np.mean(all_dices, 1)
print(np.mean(all_dices_0), np.std(all_dices_0))

# Persist the full pairs-by-labels score matrix for later analysis.
np.save("all_dices.npy", all_dices)


