from data.glm_subject_data import GLMSubjectData
from datasets import register_dataset
import numpy as np

@register_dataset("glm_dataset")
class GLMDataset:
    """Dataset that turns GLMSubjectData word tables into predictor matrices.

    For every DataFrame in ``GLMSubjectData.words`` the rows whose ``pos``
    tag is selected by the predictors list are kept, the remaining predictor
    columns are mean-centred (and optionally scaled to unit variance), and
    the ``pos`` column is carried along.
    """

    # Added to the standard deviation so a constant column does not cause a
    # division by zero when ``unit_variance`` is enabled.
    EPSILON = 1e-8

    def __init__(self, cfg, task_cfg=None) -> None:
        """Load the predictors list and build one predictor matrix per word table.

        Args:
            cfg: data config. It must expose ``predictors_list_path`` as an
                attribute and support ``.copy()`` and ``.get()``
                (presumably an OmegaConf ``DictConfig`` -- confirm with callers).
            task_cfg: accepted for interface compatibility; not used here.
        """
        super().__init__()
        # Store a copy rather than the caller's config object.
        self.cfg = cfg.copy()
        pred_lst = self.get_pred_lst(cfg.predictors_list_path)
        self.subj_data = GLMSubjectData(self.cfg)
        # One entry per DataFrame in ``self.subj_data.words``: the centred
        # predictor matrix and the original row labels it was built from.
        self.X_dfs, self.samp_idxs = self.get_all_predictors_dataframe(pred_lst)

    def get_pred_lst(self, pred_lst_path):
        """Read predictor names from a text file, one per line.

        Surrounding whitespace is stripped and blank lines are skipped
        (an empty name would otherwise be looked up as a column later).

        Args:
            pred_lst_path: path to the predictors list file.

        Returns:
            List of predictor name strings, in file order.
        """
        with open(pred_lst_path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    def get_all_predictors_dataframe(self, pred_lst):
        """Apply :meth:`get_predictors_dataframe` to every word table.

        Args:
            pred_lst: predictor names, as returned by :meth:`get_pred_lst`.

        Returns:
            ``(X_dfs, idxs)``: two tuples aligned with ``self.subj_data.words``
            holding each table's predictor matrix and kept row labels. Both are
            empty tuples when there are no word tables.
        """
        results = [
            self.get_predictors_dataframe(pred_lst, pred_df)
            for pred_df in self.subj_data.words
        ]
        if not results:
            # ``zip(*[])`` yields nothing, so unpacking it into two names would fail.
            return (), ()
        X_dfs, idxs = zip(*results)
        return X_dfs, idxs

    def get_predictors_dataframe(self, pred_lst, pred_df):
        """Build the centred predictor matrix for a single word table.

        Args:
            pred_lst: predictor names. Entries containing ``"pos-"`` (e.g.
                ``"pos-noun"``) select rows by part-of-speech tag, matched
                case-insensitively against ``pred_df["pos"]``; the tag is the
                text between the first and second ``-``. Every other entry
                names a column of ``pred_df`` used as a predictor.
            pred_df: word table with a ``"pos"`` column and the predictor
                columns. It is not modified.

        Returns:
            ``(X_df, samp_idxs)``. ``X_df`` holds the centred (optionally
            unit-variance) predictor columns plus ``"pos"``, re-indexed
            0..n-1. ``samp_idxs`` holds the original index labels of the kept
            rows in ``pred_df``.

        Raises:
            ValueError: if no ``pos-`` entry is given, if a non-POS predictor
                contains ``"prev"``, or if a predictor's mean is NaN after POS
                filtering (e.g. no rows matched, or the column is all-NaN).
        """
        # Work on a copy to avoid pandas warnings about setting on a slice.
        pred_df = pred_df.copy()

        pos_tags = [p.split("-")[1].lower() for p in pred_lst if "pos-" in p]
        predictor_cols = [p for p in pred_lst if "pos-" not in p]
        if not pos_tags:
            raise ValueError(
                "POS must be present: the predictors list needs at least one "
                "'pos-<tag>' entry"
            )

        # Both sides are lower-cased so 'pos-NOUN' and 'pos-noun' behave alike.
        pred_df = pred_df.loc[pred_df["pos"].str.lower().isin(pos_tags)]

        # NOTE(review): 'prev' predictors were rejected by the original code
        # too; presumably because POS row filtering breaks the alignment with
        # the previous word -- confirm.
        prev_cols = [p for p in predictor_cols if "prev" in p]
        if prev_cols:
            raise ValueError(
                f"'prev' predictors are not supported with POS filtering: {prev_cols}"
            )

        predictors = pred_df[predictor_cols]
        means = predictors.mean()
        if means.isnull().any():
            bad = list(means.index[means.isnull()])
            raise ValueError(
                f"Predictor mean is NaN after POS filtering for columns: {bad}"
            )

        X_df = (predictors - means).dropna()
        if self.cfg.get("unit_variance", False):
            X_df = X_df / (X_df.std() + self.EPSILON)
        # Aligned on the index, so only the rows kept by dropna() receive a value.
        X_df["pos"] = pred_df["pos"]

        samp_idxs = X_df.index
        X_df = X_df.reset_index(drop=True)
        return X_df, samp_idxs

