from pathlib import Path
import numpy as np
from omegaconf import DictConfig, OmegaConf
import cv2 as cv
import hydra
import logging
import os
from PIL import Image

# Configure root logging at import time at INFO level and create the
# module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(name=__name__)

def write_color_features(images, out_dir):
    """Compute each image's most frequent color and save them as ``features.npy``.

    Each image is flattened with ``reshape(-1, 3)`` into rows of 3 channel
    values. The row that occurs most often is selected; this is the mode
    color, not the mean. ``np.unique`` returns its rows sorted, so ties go to
    the lexicographically smallest color. The per-image colors are stacked
    into an array of shape ``(len(images), 3)`` and written with ``np.save``
    to ``<out_dir>/features.npy``.

    Args:
        images: Iterable of image arrays whose element count is divisible by 3.
            Presumably ``(H, W, 3)`` RGB arrays; confirm against callers.
        out_dir: Directory to write into. It is not created here, so it must
            already exist.

    Raises:
        ValueError: If ``images`` is empty.
    """
    dominant_colors = []
    for img in images:
        colors, counts = np.unique(img.reshape(-1, 3), axis=0, return_counts=True)
        dominant_colors.append(colors[np.argmax(counts)])

    # np.stack([]) would raise a cryptic error; say what is actually wrong.
    if not dominant_colors:
        raise ValueError("write_color_features needs at least one image")

    features = np.stack(dominant_colors)
    with open(os.path.join(out_dir, "features.npy"), "wb") as f:
        np.save(f, features)

def get_images(cfg):
    """Load the stimulus images described by ``cfg``.

    Exactly one of two sources must be configured:

    * ``images_file``: path to a ``.npy`` file, loaded as-is with ``np.load``.
      Presumably shaped ``(N, H, W, 3)`` like the directory branch; confirm.
    * ``images_dir``: directory containing ``0.png`` ... ``{n-1}.png``. Each
      file is read with OpenCV and converted from BGR to RGB. The images are
      stacked along a new leading axis, giving ``(n, H, W, 3)``. ``n`` comes
      from the optional ``num_images`` key and defaults to 982.

    Args:
        cfg: Mapping-like config, such as an OmegaConf ``DictConfig`` or a dict.

    Returns:
        The loaded image array.

    Raises:
        ValueError: If both or neither of ``images_file`` / ``images_dir`` are set.
        FileNotFoundError: If an expected PNG in ``images_dir`` cannot be read.
    """
    has_file = "images_file" in cfg
    has_dir = "images_dir" in cfg
    if has_file == has_dir:
        raise ValueError(
            "Specify exactly one of 'images_file' or 'images_dir' in the config"
        )

    if has_file:
        with open(cfg["images_file"], "rb") as f:
            images = np.load(f)
    else:
        images_dir = cfg["images_dir"]
        num_images = cfg.get("num_images", 982)
        frames = []
        for idx in range(num_images):
            path = os.path.join(images_dir, f"{idx}.png")
            img = cv.imread(path)
            # cv.imread reports a missing or unreadable file by returning None
            # rather than raising, so check explicitly.
            if img is None:
                raise FileNotFoundError(f"Could not read image: {path}")
            frames.append(cv.cvtColor(img, cv.COLOR_BGR2RGB))
        # Stack, do not concatenate: each image must remain its own entry along
        # axis 0. np.concatenate would merge all images into one tall image.
        images = np.stack(frames)

    logging.getLogger(__name__).info("Loaded images with shape %s", images.shape)
    return images

@hydra.main(config_path="conf")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: load images and write the requested features.

    Reads from ``cfg.exp``:

    * ``feat_name``: which feature to compute. Only ``"color"`` is supported.
    * ``out_dir`` (optional): base output directory. Defaults to the current
      working directory at call time.
    * the image-source keys consumed by ``get_images``.

    Features are written under ``<out_dir>/<feat_name>/``.

    Raises:
        ValueError: If ``feat_name`` is not a supported feature.
    """
    # Maps feature name -> function(images, out_dir) that writes it.
    writers = {"color": write_color_features}

    log.info("Write test features")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))

    feat_name = cfg.exp.feat_name
    # Fail fast, before the potentially slow image loading. Previously an
    # unknown name silently produced no output.
    if feat_name not in writers:
        raise ValueError(
            f"Unsupported feat_name {feat_name!r}; expected one of {sorted(writers)}"
        )

    cwd = os.getcwd()
    log.info(f'Working directory {cwd}')
    out_dir = cfg.exp.out_dir if "out_dir" in cfg.exp else cwd
    log.info(f'Output directory {out_dir}')

    images = get_images(cfg.exp)

    feat_dir = os.path.join(out_dir, feat_name)
    Path(feat_dir).mkdir(exist_ok=True, parents=True)
    writers[feat_name](images, feat_dir)

if __name__=="__main__":
    main()
