from glob import glob
import h5py
import copy
from omegaconf import DictConfig, OmegaConf
import hydra
import torch
import numpy as np
import os
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm as tqdm
import argparse
import logging
import json
import sys
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

# Configure the root logger so INFO-and-above records are emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger shared by the functions in this file.
log = logging.getLogger(__name__)

def get_effective_dim(contexts, threshold=0.95):
    """Estimate how many principal components are needed to explain the data.

    Fits a full PCA on ``contexts`` and returns the smallest number of
    leading components whose cumulative explained-variance ratio is strictly
    greater than ``threshold``.

    Args:
        contexts: 2-D array-like accepted by ``sklearn.decomposition.PCA.fit``.
        threshold: Fraction of the total variance that must be exceeded.
            Defaults to 0.95, the value that used to be hard-coded.

    Returns:
        A ``(dim, dist)`` tuple. ``dim`` is the component count described
        above; it equals the number of fitted components when the threshold
        is never exceeded by a proper prefix. ``dist`` maps every component
        count ``k`` that was examined before stopping to the cumulative
        explained-variance ratio of the first ``k`` components; counts that
        were not examined keep their initial value of 0.
    """
    pca = PCA()
    # Only the fitted variance ratios are used, so there is no need to also
    # project the data (which fit_transform would compute and then discard).
    pca.fit(contexts)

    ratios = pca.explained_variance_ratio_
    dist = {i: 0 for i in range(len(ratios))}
    dim = 0
    while dim < len(ratios):
        # Variance explained by the first `dim` components (0.0 for dim == 0).
        percent = np.sum(ratios[:dim])
        if percent > threshold:
            break
        dist[dim] = percent.item()
        dim += 1
    return dim, dist

@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    """Compute the effective dimensionality of a subject's response features.

    For each of the four experiments, every file matching
    ``<root_path>/data_test/test_response/<subject>/<exp>/<bbits_dir>/*`` is
    opened as HDF5 and its ``"data"`` dataset is read (NaNs replaced via
    ``np.nan_to_num``). All arrays are concatenated, passed to
    ``get_effective_dim``, and the resulting ``dim`` is written as JSON to
    ``<cfg.exp.output_path>/<subject>/<bbits_dir>/results.json``.

    Args:
        cfg: Hydra config. Reads ``cfg.exp.bbits_dir``, ``cfg.exp.subject``
            and ``cfg.exp.output_path``; ``cfg.exp.out_dir`` is optional and
            only used for logging.

    Raises:
        ValueError: If no response files are found for the given subject and
            ``bbits_dir``.
    """
    log.info("Run decoding on bottleneck features")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    out_dir = os.getcwd()
    log.info(f'Working directory {os.getcwd()}')
    if "out_dir" in cfg.exp:
        out_dir = cfg.exp.out_dir
    log.info(f'Output directory {out_dir}')

    bbits_dir = cfg.exp.bbits_dir
    subject = cfg.exp.subject

    root_path = "/storage/user/semantic-decoding-brainbits"
    exps = ["imagined_speech",  "perceived_movie",  "perceived_multispeaker",  "perceived_speech"]
    all_data = []
    for exp in exps:
        pattern = os.path.join(root_path, "data_test", "test_response", subject, exp, bbits_dir, "*")
        for task_path in glob(pattern):
            # Close each HDF5 handle as soon as its data has been copied into
            # memory instead of leaking one open file per task.
            with h5py.File(task_path, "r") as hf:
                all_data.append(np.nan_to_num(hf["data"][:]))

    if not all_data:
        # np.concatenate would raise an opaque ValueError on an empty list;
        # keep the exception type but say what was actually missing.
        raise ValueError(
            f"No response files found for subject {subject!r} and "
            f"bbits_dir {bbits_dir!r} under {root_path}"
        )

    all_data = np.concatenate(all_data)
    # Only the dimension is persisted; the per-component distribution is unused.
    dim, _ = get_effective_dim(all_data)

    output_path = os.path.join(cfg.exp.output_path, subject, bbits_dir)
    Path(output_path).mkdir(exist_ok=True, parents=True)
    output_path = os.path.join(output_path, "results.json")
    results = {"dim": dim}
    with open(output_path, "w") as f:
        print(output_path)
        json.dump(results, f)

# Run the Hydra entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()

