import os
import os.path as osp
import numpy as np
from sklearn.preprocessing import StandardScaler
import torch
from torch.utils.data import Dataset, Subset
import random

def get_idx_splits(data_len, frac, seed=0):
    """Randomly partition ``range(data_len)`` into train and validation indices.

    The indices are shuffled with a private ``random.Random(seed)`` instance, so
    the split is reproducible and the global ``random`` state is untouched.
    The first ``int(frac * data_len)`` shuffled indices become the validation
    set; the remainder become the training set.

    Returns:
        ``(train_idcs, val_idcs)`` as two lists of ints.
    """
    rng = random.Random(seed)
    order = list(range(data_len))
    rng.shuffle(order)
    n_val = int(frac * data_len)
    val_idcs = order[:n_val]
    train_idcs = order[n_val:]
    return train_idcs, val_idcs

class fMRI2latent(Dataset):
    """Paired (fMRI features, target features) dataset for one subject.

    On first use, the per-ROI beta arrays for ``regions`` are loaded, stacked
    column-wise and standardized with a ``StandardScaler`` fitted on the
    *training* betas. The test split reuses those training statistics. The
    resulting inputs and the matching targets are cached as ``.npy`` files under
    ``{subject}/preprocessed_data/`` and reloaded on later constructions.

    NOTE(review): cache file names depend only on ``target`` and the split,
    not on ``regions`` -- changing ``regions`` reuses a stale cache. Delete the
    cache directory when changing ROIs.

    Args:
        subject: subject identifier; also the prefix of the cache directory.
        regions: iterable of ROI names whose beta files are concatenated.
        target: name of the target feature set (part of every file name).
        train: if True use the ``*_betas_tr`` / ``*_each_*_tr`` files,
            otherwise the ``*_betas_ave_te`` / ``*_ave_*_te`` files.

    Raises:
        AssertionError: if no complete cache exists and the raw-data
            directories are missing.
    """

    def __init__(self, subject, regions, target, train=True):
        split = "tr" if train else "te"
        cache_dir = f'{subject}/preprocessed_data'
        x_cache = f'{cache_dir}/{target}_fmri_feats_{split}.npy'
        y_cache = f'{cache_dir}/{target}_feats_{split}.npy'

        # Trust the cache only when BOTH files exist; a half-written cache
        # (X saved, Y missing) would otherwise crash on every later run.
        if osp.exists(x_cache) and osp.exists(y_cache):
            X = np.load(x_cache)
            Y = np.load(y_cache)
        else:
            X, Y = self._build(subject, regions, target, train)
            # Save only after both arrays were built successfully.
            os.makedirs(cache_dir, exist_ok=True)
            np.save(x_cache, X)
            np.save(y_cache, Y)

        self.X = torch.from_numpy(X)
        self.Y = torch.from_numpy(Y)

    @staticmethod
    def _build(subject, regions, target, train):
        """Load raw betas/targets and return the normalized ``(X, Y)`` arrays."""
        mridir = f'../../../mrifeat/{subject}'
        featdir = '../../../nsdfeat/subjfeat/'
        # Validate both input locations before any (slow) loading.
        assert osp.exists(mridir), 'point to mrifeat/subject'
        assert osp.exists(featdir), 'point to nsdfeat/subjfeat'

        # The scaler is always fitted on the training betas, so the test split
        # is normalized with training statistics.
        X = np.hstack([
            np.load(f'{mridir}/{subject}_{croi}_betas_tr.npy').astype("float32")
            for croi in regions
        ])
        scaler = StandardScaler(with_mean=True, with_std=True)
        scaler.fit(X)

        if not train:
            # Replace the training inputs with the test inputs (after fitting).
            X = np.hstack([
                np.load(f'{mridir}/{subject}_{croi}_betas_ave_te.npy').astype("float32")
                for croi in regions
            ])
        X = scaler.transform(X)

        if train:
            y_path = f'{featdir}/{subject}_each_{target}_tr.npy'
        else:
            y_path = f'{featdir}/{subject}_ave_{target}_te.npy'
        # Flatten targets to 2-D with one row per input sample.
        Y = np.load(y_path).astype("float32").reshape([X.shape[0], -1])
        return X, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

def setup_dataset(subject, regions, target, split_fraction=0.1, seed=0):
    """Build train / validation / test datasets for one subject.

    The training split of ``fMRI2latent`` is randomly partitioned into train
    and validation subsets. A fraction ``split_fraction`` of its samples goes
    to validation. The test split is returned whole.

    Args:
        subject, regions, target: forwarded to ``fMRI2latent``.
        split_fraction: fraction of training samples held out for validation.
        seed: seed for the train/validation shuffle; the default ``0``
            reproduces the previous fixed split.

    Returns:
        ``(ds_tr, ds_va, ds_te)``: two ``Subset`` views over the training
        dataset, and the test dataset.
    """
    ds = fMRI2latent(subject, regions, target)
    ds_te = fMRI2latent(subject, regions, target, train=False)
    tr_idcs, va_idcs = get_idx_splits(len(ds), split_fraction, seed=seed)
    ds_tr = Subset(ds, tr_idcs)
    ds_va = Subset(ds, va_idcs)
    return ds_tr, ds_va, ds_te