import os
from glob import glob
import torch, sys
from torch.utils.data import Dataset
try:
    from .data_utils import pkload
except ImportError:
    # Not imported as part of a package (e.g. the file is run as a script):
    # fall back to a top-level import. Only ImportError is caught so that
    # unrelated failures (and KeyboardInterrupt/SystemExit) are not masked.
    from data_utils import pkload
import matplotlib.pyplot as plt
import random
import numpy as np
import nibabel as nib
import SimpleITK as sitk

class OASISNiftiDataset(Dataset):
    """Pairs of OASIS volumes (neurite directory layout) for registration.

    Subject directories matching ``OASIS_OAS1*_MR1`` under ``data_root`` are
    globbed for ``aligned_norm.nii.gz`` (image) and ``aligned_seg35.nii.gz``
    (segmentation). Only the first 414 entries (sorted by path) are used.

    Splits:
        ``'train'``:  all but the last ``val_split`` entries; the moving
            volume is a randomly chosen *different* entry of the split.
        ``'val'``:    the last ``val_split`` entries; sample ``i`` pairs entry
            ``i`` (fixed) with entry ``i + 1`` (moving).
        ``'valmax'``: every entry, paired consecutively like ``'val'``.
    """
    # keep last 50 images for validation
    val_split = 50

    def __init__(self, data_root, split='train', transforms=None):
        """Collect the (image, segmentation) path pairs for ``split``.

        Args:
            data_root: directory containing the ``OASIS_OAS1*_MR1`` folders.
            split: one of ``'train'``, ``'val'`` or ``'valmax'``.
            transforms: optional callable applied to ``[img, seg]`` that
                returns the transformed pair; ``None`` applies no transform.

        Raises:
            ValueError: if ``split`` is not a known split name.
        """
        super().__init__()
        ### Only use the first 414 images for train and val
        images = sorted(glob(os.path.join(data_root, 'OASIS_OAS1*_MR1', 'aligned_norm.nii.gz')))[:414]
        labels = sorted(glob(os.path.join(data_root, 'OASIS_OAS1*_MR1', 'aligned_seg35.nii.gz')))[:414]
        self.split = split
        self.transforms = transforms
        if self.split == 'train':
            self.paths = list(zip(images[:-self.val_split], labels[:-self.val_split]))
        elif self.split == 'val':
            self.paths = list(zip(images[-self.val_split:], labels[-self.val_split:]))
        elif self.split == 'valmax':
            # keep every image (useful to check for overfitting inside the training set)
            self.paths = list(zip(images, labels))
        else:
            raise ValueError('Invalid split name')
        self.N = len(self.paths)
        print(f"Dataset has {self.N} samples with split {self.split}.")

    def __len__(self):
        """Return the number of pairs that ``__getitem__`` can produce."""
        if self.split == 'train':
            return len(self.paths)
        # Non-train splits pair entry i with entry i + 1, so the last entry
        # cannot start a pair. Clamp at zero: a negative value would make
        # len() raise on an empty split.
        return max(len(self.paths) - 1, 0)

    def one_hot(self, img, C):
        """One-hot encode a label volume into ``C`` channels.

        Args:
            img: label array indexed as ``[1, H, W, D]``.
            C: number of classes; channel ``i`` marks voxels equal to ``i``.

        Returns:
            float64 array of shape ``(C, H, W, D)`` holding 0/1 values.
        """
        out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
        for i in range(C):
            out[i, ...] = img == i
        return out

    def load(self, img_path, seg_path):
        """Read an image/segmentation pair from NIfTI files.

        Singleton axes are squeezed away and a leading axis is added, so 3-D
        volumes come back as ``[1, H, W, D]``.
        """
        img = nib.load(img_path).get_fdata().squeeze()[None]
        seg = nib.load(seg_path).get_fdata().squeeze()[None]
        return img, seg

    def __getitem__(self, index):
        """Return ``(moving_img, fixed_img, moving_seg, fixed_seg)`` tensors.

        The entry at ``index`` is the fixed volume. In the train split the
        moving volume is a random other entry; otherwise it is entry
        ``index + 1``.

        Raises:
            ValueError: in the train split when fewer than two samples exist,
                since no distinct moving volume can be drawn.
        """
        fixed_img, fixed_seg = self.paths[index]
        if self.split == 'train':
            if self.N < 2:
                raise ValueError('Need at least two samples to draw a training pair')
            # A random offset in [1, N - 1] guarantees moving != fixed.
            moving_img, moving_seg = self.paths[(np.random.randint(self.N - 1) + 1 + index) % self.N]
        else:
            moving_img, moving_seg = self.paths[(index + 1)]
        # for this dataset, x is the moving image
        fixed_img, fixed_seg = self.load(fixed_img, fixed_seg)
        moving_img, moving_seg = self.load(moving_img, moving_seg)
        # apply transforms (skipped when none were given, matching the
        # constructor's ``transforms=None`` default)
        if self.transforms is not None:
            fixed_img, fixed_seg = self.transforms([fixed_img, fixed_seg])
            moving_img, moving_seg = self.transforms([moving_img, moving_seg])
        fixed_img, fixed_seg = np.ascontiguousarray(fixed_img), np.ascontiguousarray(fixed_seg)
        moving_img, moving_seg = np.ascontiguousarray(moving_img), np.ascontiguousarray(moving_seg)
        fixed_img, fixed_seg, moving_img, moving_seg = [torch.from_numpy(x) for x in [fixed_img, fixed_seg, moving_img, moving_seg]]
        return moving_img, fixed_img, moving_seg, fixed_seg


class OASISBrainDataset(Dataset):
    """Training pairs built from per-subject files readable by ``pkload``.

    Each file is expected to unpack into ``(volume, segmentation)``. Sample
    ``index`` pairs that subject with a randomly chosen *other* subject.
    """

    def __init__(self, data_path, transforms):
        """Store the sample paths and the paired transform.

        Args:
            data_path: sequence of file paths, one per subject.
            transforms: callable applied to ``[volume, segmentation]`` that
                returns the transformed pair.
        """
        self.paths = data_path
        self.transforms = transforms

    def one_hot(self, img, C):
        """One-hot encode a label volume into ``C`` channels.

        Args:
            img: label array indexed as ``[1, H, W, D]``.
            C: number of classes; channel ``i`` marks voxels equal to ``i``.

        Returns:
            float64 array of shape ``(C, H, W, D)`` holding 0/1 values.
        """
        out = np.zeros((C, img.shape[1], img.shape[2], img.shape[3]))
        for i in range(C):
            out[i, ...] = img == i
        return out

    def _random_other_path(self, index):
        """Return the path of a uniformly random entry other than ``index``."""
        n = len(self.paths)
        if n < 2:
            raise IndexError('Need at least two paths to draw a target sample')
        # A random offset in [1, n - 1] can never land back on ``index``;
        # this avoids copying and shuffling the whole list for every sample.
        return self.paths[(index + random.randrange(1, n)) % n]

    def __getitem__(self, index):
        """Return ``(x, y, x_seg, y_seg)`` tensors for one random pair.

        ``x``/``x_seg`` come from the file at ``index``; ``y``/``y_seg`` come
        from a randomly selected different file.

        Raises:
            IndexError: if ``index`` is out of range or fewer than two paths
                are available.
        """
        path = self.paths[index]
        tar_file = self._random_other_path(index)
        # NOTE(review): pkload presumably unpickles the file; only point the
        # dataset at trusted data -- confirm against data_utils.
        x, x_seg = pkload(path)
        y, y_seg = pkload(tar_file)
        # add a leading axis in front of each array
        x, y = x[None, ...], y[None, ...]
        x_seg, y_seg = x_seg[None, ...], y_seg[None, ...]
        x, x_seg = self.transforms([x, x_seg])
        y, y_seg = self.transforms([y, y_seg])
        # torch.from_numpy below is fed contiguous arrays
        x = np.ascontiguousarray(x)
        y = np.ascontiguousarray(y)
        x_seg = np.ascontiguousarray(x_seg)
        y_seg = np.ascontiguousarray(y_seg)
        x, y, x_seg, y_seg = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(x_seg), torch.from_numpy(y_seg)
        return x, y, x_seg, y_seg

    def __len__(self):
        """Return the number of subject files."""
        return len(self.paths)


class OASISBrainInferDataset(Dataset):
    """Inference pairs stored one per file, read through ``pkload``.

    Every file is unpacked into four arrays ``(x, y, x_seg, y_seg)``, so the
    pairing is fixed by the file rather than sampled at access time.
    """

    def __init__(self, data_path, transforms):
        """Remember the pair files and the transform applied to each pair."""
        self.transforms = transforms
        self.paths = data_path

    def __len__(self):
        """Return the number of stored pair files."""
        return len(self.paths)

    def one_hot(self, img, C):
        """Expand label volume ``img`` into ``C`` binary channels.

        Channel ``k`` of the float result is 1 where ``img == k``, else 0.
        """
        spatial = [img.shape[axis] for axis in (1, 2, 3)]
        encoded = np.zeros((C, *spatial))
        for label in range(C):
            encoded[label, ...] = (img == label)
        return encoded

    def __getitem__(self, index):
        """Load pair ``index`` and return ``(x, y, x_seg, y_seg)`` tensors."""
        # Unpack the four stored arrays, giving each a new leading axis.
        first, second, first_seg, second_seg = (
            arr[None, ...] for arr in pkload(self.paths[index])
        )
        # Each volume is transformed together with its own segmentation.
        first, first_seg = self.transforms([first, first_seg])
        second, second_seg = self.transforms([second, second_seg])
        # Make the arrays contiguous before wrapping them as tensors.
        return tuple(
            torch.from_numpy(np.ascontiguousarray(arr))
            for arr in (first, second, first_seg, second_seg)
        )


if __name__ == '__main__':
    # Smoke test: print every (image, label) path pair of the training split.
    # An optional first command-line argument overrides the default data root.
    data_root = sys.argv[1] if len(sys.argv) > 1 else "/mnt/anon_data2/neurite-OASIS"
    dataset = OASISNiftiDataset(data_root)
    for img, lab in dataset.paths:
        print(img, lab)