# Example invocation:
# python3 -m plotting.write_mean_stats +data=finetuning_template +exp=make_mean_stats ++data.duration=0.5 ++data.interval_duration=1.0 ++data.name="sentence_position_finetuning" +test=all_trials ++data.delta=0 ++data.movie_transcripts_dir="/storage/czw/seeg_decoding/updated_word_features" +plot=make_mean_stats
import numpy as np
from omegaconf import DictConfig, OmegaConf
import hydra
import logging
import os
from data.electrode_selection import get_clean_electrodes
import json
from pathlib import Path
import pandas as pd
from data.subject_data import SubjectData

def get_mean_stats(idxs, neural_data, data_cfg):
    """Return the grand mean of the absolute signal for the selected entries.

    Args:
        idxs: Positions along axis 1 of ``neural_data`` to include.
        neural_data: 3-D array; axis 1 is the axis indexed by ``idxs``.
        data_cfg: Not used by this function.

    Returns:
        A scalar: absolute values averaged over axis 1, then averaged over
        everything that remains.
    """
    assert neural_data.ndim == 3
    rectified = abs(neural_data[:, idxs])
    # Two-step average: first across the selected entries, then across the rest.
    per_entry_average = rectified.mean(axis=1)
    return per_entry_average.mean()

def get_peak_stats(idxs, neural_data, data_cfg, sampling_rate=2048):
    """Return the peak, its time, and the minimum of the averaged signal.

    The entries selected by ``idxs`` along axis 1 are averaged, leaving a 2-D
    array indexed by (axis 0, axis 2). The maximum and minimum are taken over
    that whole array.

    Args:
        idxs: Positions along axis 1 of ``neural_data`` to average over.
        neural_data: 3-D array; the last axis is treated as the sample axis
            when converting the peak position to a time.
        data_cfg: Config object providing ``delta``, an offset in seconds that
            is added to the peak time.
        sampling_rate: Samples per second used to convert the peak's sample
            index into seconds. Defaults to 2048.

    Returns:
        Tuple ``(peak, peak_time, min_peak)`` where ``peak``/``min_peak`` are
        the max/min of the averaged signal and ``peak_time`` is
        ``sample_index / sampling_rate + data_cfg.delta``.
    """
    assert neural_data.ndim == 3
    data = neural_data[:, idxs]
    mean = data.mean(axis=1)
    delta = data_cfg.delta  # delta is in seconds
    peak = mean.max()
    min_peak = mean.min()
    # argmax() with no axis indexes the *flattened* array. Reduce it to the
    # position along the last axis so the time stays correct when axis 0 holds
    # more than one entry (for a single entry this is a no-op).
    sample_idx = mean.argmax() % mean.shape[-1]
    peak_time = sample_idx / sampling_rate + delta
    return peak, peak_time, min_peak

def _balanced_noun_verb_idxs(words_df, noun_idxs, verb_idxs, positions, extra_mask=None):
    """Collect, for each sentence position, equally many noun and verb rows.

    Args:
        words_df: Frame with an ``idx_in_sentence`` column.
        noun_idxs: Row labels of all nouns.
        verb_idxs: Row labels of all verbs.
        positions: Iterable of ``idx_in_sentence`` values to visit.
        extra_mask: Optional boolean Series further restricting the rows
            considered at every position.

    Returns:
        1-D integer array of row labels. For each position it holds the first
        ``k`` nouns followed by the first ``k`` verbs, with ``k`` the smaller
        of the two counts at that position.
    """
    balanced = []
    for position in positions:
        selector = words_df.idx_in_sentence == position
        if extra_mask is not None:
            selector = selector & extra_mask
        position_idxs = words_df[selector].index
        nouns_here = np.intersect1d(noun_idxs, position_idxs)
        verbs_here = np.intersect1d(verb_idxs, position_idxs)
        n_keep = min(len(nouns_here), len(verbs_here))
        balanced.append(np.array(nouns_here[:n_keep]))
        balanced.append(np.array(verbs_here[:n_keep]))
    if not balanced:
        # np.concatenate rejects an empty list; keep an integer dtype so the
        # result is still usable as an index array.
        return np.array([], dtype=int)
    return np.concatenate(balanced)


def get_all_stats(words_df, neural_data, data_cfg):
    """Compute peak and mean statistics for several groupings of words.

    Args:
        words_df: Frame with columns ``gpt2_surprisal``, ``pos``, ``is_onset``,
            ``is_offset`` and ``idx_in_sentence``. Its row labels are used as
            positions along axis 1 of ``neural_data``.
        neural_data: 3-D array indexed along axis 1 by the rows of ``words_df``.
        data_cfg: Config forwarded to ``get_peak_stats`` / ``get_mean_stats``.

    Returns:
        Tuple ``(results, mean_activity)``. ``results`` maps
        ``"<group>_peak"`` / ``"<group>_time"`` to the values returned by
        ``get_peak_stats`` for each group, plus ``"all_peak_min"`` and
        ``"<i>_mean"`` for sentence positions 0-7. ``mean_activity`` is
        ``neural_data[0]`` averaged over its first axis, as float32.
    """
    # Guard against a frame whose index was not reset (only the last label is
    # checked), since labels double as positions into neural_data.
    assert len(words_df) - 1 == words_df.index[-1]

    control = "gpt2_surprisal"
    high = np.percentile(words_df[control], 75)
    low = np.percentile(words_df[control], 25)
    assert high > low
    control_high_idxs = words_df[words_df[control] >= high].index
    control_low_idxs = words_df[words_df[control] <= low].index
    assert len(control_high_idxs) > 2 and len(control_low_idxs) > 2

    verb_idxs = words_df[words_df.pos == "VERB"].index
    noun_idxs = words_df[words_df.pos == "NOUN"].index
    onset_idxs = words_df[words_df.is_onset == 1].index

    # Balance nouns and verbs at *every* non-initial sentence position. This
    # must be a range: iterating the tuple (1, max_idx) would only visit the
    # two end points.
    max_idx = int(words_df.idx_in_sentence.max())
    positions = range(1, max_idx + 1)

    # NOTE(review): midset rows are not filtered on is_offset, so sentence-final
    # words can appear here too -- confirm that this is intended.
    midset_idxs = _balanced_noun_verb_idxs(words_df, noun_idxs, verb_idxs, positions)

    # Same balancing restricted to sentence-final words. The comparison is
    # built separately because `a & b == 1` parses as `(a & b) == 1`.
    is_offset = words_df.is_offset == 1
    offset_idxs = _balanced_noun_verb_idxs(
        words_df, noun_idxs, verb_idxs, positions, extra_mask=is_offset
    )

    # Insertion order fixes the key order of `results` (and hence the column
    # order of any table built from it).
    groups = {
        "gpt2_surprisal_high": control_high_idxs,
        "gpt2_surprisal_low": control_low_idxs,
        "verb": verb_idxs,
        "noun": noun_idxs,
        "onset": onset_idxs,
        "offset": offset_idxs,
        "midset": midset_idxs,
        "verb_onset": np.intersect1d(verb_idxs, onset_idxs),
        "noun_onset": np.intersect1d(noun_idxs, onset_idxs),
        "verb_offset": np.intersect1d(verb_idxs, offset_idxs),
        "noun_offset": np.intersect1d(noun_idxs, offset_idxs),
        "verb_midset": np.intersect1d(verb_idxs, midset_idxs),
        "noun_midset": np.intersect1d(noun_idxs, midset_idxs),
    }

    results = {}
    for name, group_idxs in groups.items():
        peak, peak_time, _ = get_peak_stats(group_idxs, neural_data, data_cfg)
        results[f"{name}_peak"] = peak
        results[f"{name}_time"] = peak_time

    every_idx = list(range(neural_data.shape[1]))
    results["all_peak"], results["all_time"], results["all_peak_min"] = get_peak_stats(
        every_idx, neural_data, data_cfg
    )

    mean_activity = neural_data[0].mean(axis=0).astype('float32')

    # Mean absolute activity for each of the first eight sentence positions.
    for i in range(8):
        idx_in_sentence_idxs = words_df[words_df.idx_in_sentence == i].index
        results[f'{i}_mean'] = get_mean_stats(idx_in_sentence_idxs, neural_data, data_cfg)
    return results, mean_activity

log = logging.getLogger(__name__)


@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    """Compute statistics for every clean electrode and write them to disk.

    Reads the JSON file at ``cfg.plot.split_path`` (subject -> brain runs),
    loads the data for each (subject, electrode) pair, and writes two files
    into ``cfg.exp.out_dir``:

    * ``all_stats.csv`` -- one row per electrode with the stats from
      ``get_all_stats`` plus an ``ID`` column of the form ``<subj>_<elec>``.
    * ``all_mean_activity.npy`` -- the stacked per-electrode mean activity.

    Args:
        cfg: Hydra config; ``cfg.data`` is mutated in place for each electrode
            before being handed to ``SubjectData``.
    """
    log.info("Write all mean stats")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))

    data_cfg = cfg.data

    split_path = cfg.plot.split_path
    with open(split_path, "r") as f:
        test_splits = json.load(f)

    all_stats, all_mean_activity = [], []
    for subj in test_splits:
        electrodes = get_clean_electrodes(subj)

        for elec in electrodes:
            log.info(f"Subject {subj}")
            # Point the shared data config at this subject/electrode pair.
            data_cfg.subject = subj
            data_cfg.electrodes = [elec]
            data_cfg.brain_runs = test_splits[subj]

            subj_data = SubjectData(data_cfg)
            # Reset the index so row labels line up with positions in
            # neural_data (get_all_stats asserts on this).
            words_df = subj_data.words.reset_index(drop=True)
            neural_data = subj_data.neural_data
            stats, mean_activity = get_all_stats(words_df, neural_data, data_cfg)
            stats["ID"] = f'{subj}_{elec}'
            all_stats.append(stats)
            all_mean_activity.append(mean_activity)

    out_dir = cfg.exp.out_dir
    Path(out_dir).mkdir(exist_ok=True, parents=True)

    out_path = os.path.join(out_dir, "all_stats.csv")
    df = pd.DataFrame(all_stats)
    df.to_csv(out_path)

    out_path = os.path.join(out_dir, "all_mean_activity.npy")
    all_means = np.stack(all_mean_activity)
    np.save(out_path, all_means)

# Run the Hydra-decorated entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
