import hydra
from omegaconf import OmegaConf, DictConfig
import pandas as pd
import torch
import os
import pickle
from ntldm.networks import VAE  # Import your model definition here
from ntldm.train_vae import init_model, init_dataloaders
from ntldm.utils.dataset_utils import load_LDS_dataset

def load_model_deprecated(run_dir, device, ModelClass):
    """Instantiate ``ModelClass`` and restore weights from a checkpoint file.

    Kept for backward compatibility; ``load_model`` builds the network from a
    run config instead of a class.

    Args:
        run_dir: path of the checkpoint file itself (it is passed straight to
            ``torch.load``), not a directory.
        device: device used both for ``map_location`` and for the model.
        ModelClass: zero-argument callable returning a ``torch.nn.Module``.

    Returns:
        The model with the loaded state dict, switched to eval mode.
    """
    net = ModelClass().to(device)
    weights = torch.load(run_dir, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
    
    
def load_config(config_path):
    """Read the YAML file at ``config_path`` with OmegaConf and return the config object."""
    return OmegaConf.load(config_path)

def join(x, y):
    """Return ``os.path.join(x, y)`` (two-argument path shorthand)."""
    joined = os.path.join(x, y)
    return joined


def load_cfg_model_data(run_dir, device='cpu', model_name="model_end.pth", cfg_name="cfg.yaml"):
    """Load the config, model, dataset and dataloaders stored in a run directory.

    The config is read from ``run_dir/cfg_name``, and the weights from
    ``run_dir/model_name``. The dataset and dataloaders are rebuilt from that
    config.

    Returns:
        dict with keys "cfg", "model", "dataset" (``None`` when
        ``load_dataset`` does not handle the configured dataset),
        "dataloader", "dataloader_valid" and "dataloader_test".
    """
    cfg = load_config(join(run_dir, cfg_name))
    model = load_model(cfg, run_dir, device, model_name)
    dataset = load_dataset(cfg)
    train_loader, valid_loader, test_loader = load_dataloaders(cfg)
    return {
        "cfg": cfg,
        "model": model,
        "dataset": dataset,
        "dataloader": train_loader,
        "dataloader_valid": valid_loader,
        "dataloader_test": test_loader,
    }


def load_model(cfg, run_dir, device='cpu', model_name="model_end.pth"):
    """Rebuild the network described by ``cfg`` and restore its saved weights.

    Args:
        cfg: run configuration passed to ``init_model`` to construct the network.
        run_dir: directory containing the checkpoint file.
        device: device the checkpoint is mapped to and the model is moved to.
        model_name: checkpoint filename inside ``run_dir``.

    Returns:
        The model with the checkpoint's state dict loaded, placed on ``device``.
        Its train/eval mode is left as ``init_model`` created it.
    """
    model = init_model(cfg)
    state_dict = torch.load(join(run_dir, model_name), map_location=device)
    model.load_state_dict(state_dict)
    # map_location only decides where the checkpoint tensors are materialised;
    # load_state_dict copies them into the model's existing parameters, which
    # stay on whatever device init_model used. Move explicitly to honour `device`.
    return model.to(device)


def load_train_test_datasets(file_path, indices=False):
    """Unpickle a saved dataset bundle and return the dataset, optionally with split indices.

    The pickle must be a mapping with a "dataset" entry. When ``indices`` is
    true it must also have "train_indices", "valid_indices" and "test_indices".

    Security: ``pickle.load`` can execute arbitrary code, so only call this on
    files you trust.

    Returns:
        ``dataset`` alone, or the tuple
        ``(dataset, train_indices, valid_indices, test_indices)`` when
        ``indices`` is true.
    """
    with open(file_path, "rb") as fh:
        bundle = pickle.load(fh)
    if not indices:
        return bundle["dataset"]
    return (
        bundle["dataset"],
        bundle["train_indices"],
        bundle["valid_indices"],
        bundle["test_indices"],
    )


def load_dataset(cfg, indices=True):
    """Load the dataset selected by ``cfg.dataset``, or return ``None`` if unsupported.

    Only ``cfg.dataset.name == "lds"`` is handled. For it, the file at
    ``cfg.dataset.filepath`` / ``cfg.dataset.filename`` is read with
    ``load_LDS_dataset``. The ``indices`` argument is currently unused.
    """
    if cfg.dataset.name != "lds":
        return None
    lds_path = os.path.join(cfg.dataset.filepath, cfg.dataset.filename)
    return load_LDS_dataset(lds_path)

def load_dataloaders(cfg):
    """Build the train, validation and test dataloaders for ``cfg`` via ``init_dataloaders``.

    Returns:
        tuple ``(dataloader, dataloader_valid, dataloader_test)``.
    """
    train_loader, valid_loader, test_loader = init_dataloaders(cfg)
    return train_loader, valid_loader, test_loader
    
def evaluate_autoencoder(return_dict, device='cpu', file_path=None):
    """Run the model on the training split and collect its outputs and the true latents.

    Args:
        return_dict: mapping as produced by ``load_cfg_model_data``. Only its
            "model" and "cfg" entries are used.
        device: device the model and the input samples are moved to.
        file_path: pickle holding the dataset and split indices, in the format
            read by ``load_train_test_datasets``. Defaults to
            ``cfg.dataset.filepath`` / ``cfg.dataset.filename``.
            NOTE(review): that default assumes the configured dataset file is
            that pickled bundle; confirm, or pass ``file_path`` explicitly.

    Returns:
        dict with "reconstruction", "mu" and "logvar" (the outputs of
        ``model.decode_mu``) and "true_latents" (stacked
        ``dataset.latents`` for the training indices).
    """
    model = return_dict["model"]
    cfg = return_dict["cfg"]

    # The original code passed the config object itself as the path, which
    # can never be opened; derive an actual file path instead.
    if file_path is None:
        file_path = os.path.join(cfg.dataset.filepath, cfg.dataset.filename)
    dataset, train_indices, _valid_indices, _test_indices = load_train_test_datasets(
        file_path, indices=True
    )

    # Keep the model on the same device as the inputs, and switch to eval mode.
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        data = dataset.samples[train_indices].to(device)
        # NOTE(review): decode_mu is assumed to return (reconstruction, mu, logvar),
        # matching the unpacking used by the original code.
        reconstruction, mu, logvar = model.decode_mu(data)
        true_latents = torch.stack([dataset.latents[i] for i in train_indices]).to(device)

    return {
        "reconstruction": reconstruction,
        "mu": mu,
        "logvar": logvar,
        "true_latents": true_latents,
    }


def eval(model, dataloader, device='cpu'):
    """Placeholder for dataloader-based evaluation; currently a no-op returning ``None``.

    Planned behaviour, per the original notes: check whether n_sample,
    n_seqlen and n_neurons match, and if not, reshape and raise a warning.

    NOTE(review): this name shadows the built-in ``eval`` within this module.
    """
    return None
    

@hydra.main(config_path='../conf/vae/', config_name='eval_config')
def main(cfg: DictConfig):
    """Load and print the models for the runs selected in ``cfg.evaluate``.

    ``cfg.evaluate.csvpath`` points to a CSV with a "run_dir" column. If
    ``cfg.evaluate.run_dir`` is set, only the first matching row is loaded.
    Otherwise every row that passes ``cfg.evaluate.filter`` is loaded. The
    filter has the form ``{subgroup: {key: value}}``, and each entry is
    matched against the CSV column named ``"subgroup.key"``.

    Raises:
        ValueError: if ``cfg.evaluate.run_dir`` is set but not present in the CSV.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    df = pd.read_csv(cfg.evaluate.csvpath)

    if cfg.evaluate.run_dir:
        # Load only the run whose run_dir matches the configured one.
        matches = df[df['run_dir'] == cfg.evaluate.run_dir]
        if matches.empty:
            # .iloc[0] on an empty frame would raise an opaque IndexError.
            raise ValueError(
                f"run_dir {cfg.evaluate.run_dir!r} not found in {cfg.evaluate.csvpath}"
            )
        run_row = matches.iloc[0]
        return_dict = load_cfg_model_data(run_row['run_dir'], device)
        print(return_dict["model"])
        return

    # Narrow the table with every configured filter; a null/missing filter
    # section means "no filtering" instead of crashing on None.items().
    filtered_df = df
    filter_cfg = cfg.evaluate.get("filter") or {}
    for subgroup, filters in filter_cfg.items():
        for key, value in (filters or {}).items():
            column_name = f"{subgroup}.{key}"
            filtered_df = filtered_df[filtered_df[column_name] == value]

    for _, run_row in filtered_df.iterrows():
        print(f"Evaluating model from {run_row['run_dir']}")
        return_dict = load_cfg_model_data(run_row['run_dir'], device)
        print(return_dict["model"])


if __name__ == "__main__":
    main()
