from omegaconf import DictConfig, OmegaConf
import hydra
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import os
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm as tqdm
import argparse
import logging
import json
import sys
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import models

logging.basicConfig(level = logging.INFO)

log = logging.getLogger(__name__)
# The output root for predicted features is chosen from the Unix user running
# the script; any other user aborts at import time with ValueError.
USER = os.getenv('USER')
if USER == "user1":
    SAVE_ROOT_PATH = Path(f'/storage/user1/BrainBitsWIP/data/predicted_features/')
# NOTE(review): this condition is identical to the one above, so this branch is
# unreachable. A second user name was probably replaced by "user1"; restore the
# intended name or remove the branch.
elif USER == "user1":
    SAVE_ROOT_PATH = Path(f'/storage/user1/projects/brainbits/BrainBitsWIP/data/predicted_features/')
else:
    raise ValueError(f"Unknown user {USER}")
# Root of the brain-diffuser data tree; get_vdvae_targets and get_fmri_inputs
# resolve their relative data paths against it.
BD_ROOT_PATH = Path('/storage/user1/brain-diffuser')


class fMRI2latent(Dataset):
    """Map-style dataset that pairs fMRI rows with their VDVAE latent targets.

    Item ``idx`` is a dict whose ``"inputs"`` and ``"vdvae_targets"`` entries
    hold row ``idx`` of the fMRI array and of the latent array, both as
    float32 tensors.
    """

    def __init__(self, fmri_data, vdvae_embeds):
        # Both arrays share the same sample index. The dataset length comes
        # from the fMRI array alone.
        self.fmri_data, self.vdvae_embeds = fmri_data, vdvae_embeds

    def __len__(self):
        return len(self.fmri_data)

    def __getitem__(self, idx):
        fmri_row = torch.as_tensor(self.fmri_data[idx], dtype=torch.float32)
        latent_row = torch.as_tensor(self.vdvae_embeds[idx], dtype=torch.float32)
        return {"inputs": fmri_row, "vdvae_targets": latent_row}

class BottleneckLinear(nn.Module):
    """Two-stage linear map: fMRI -> ``bottleneck_size`` dims -> VDVAE latent.

    The first stage, ``fmri2embed``, is an ``nn.Sequential`` holding a single
    ``nn.Linear(input_size, bottleneck_size)``. With ``cfg.pca_preload`` set,
    that layer has no bias and its weight is initialised from ``embed_w``.
    Otherwise it is a freshly initialised layer with bias. The second stage,
    ``vdvae_embed``, is ``nn.Linear(bottleneck_size, d_vdvae)``. Parameter names
    (``fmri2embed.0.weight``, ``vdvae_embed.weight``, ...) are what the
    training loop matches on to freeze one stage at a time.

    Args:
        input_size: number of fMRI features per sample.
        bottleneck_size: width of the bottleneck layer.
        d_vdvae: dimensionality of the VDVAE latent target.
        cfg: config object; only ``cfg.pca_preload`` is read.
        embed_w: initial first-stage weight of shape
            ``(bottleneck_size, input_size)`` (e.g. PCA components). Required
            when ``cfg.pca_preload`` is set; it is copied, never aliased.
        multi_gpu: unused; kept for call-site compatibility.

    Raises:
        ValueError: if ``cfg.pca_preload`` is set and ``embed_w`` is missing or
            does not have shape ``(bottleneck_size, input_size)``.
    """

    def __init__(self, input_size, bottleneck_size, d_vdvae, cfg, embed_w=None, multi_gpu=False):
        super().__init__()

        if cfg.pca_preload:
            if embed_w is None:
                raise ValueError("cfg.pca_preload is set but no embed_w was given")
            # Always copy. Training updates the weight in place, and that must
            # never write back into the caller's array (torch.FloatTensor may
            # alias a float32 ndarray).
            weight = torch.as_tensor(embed_w, dtype=torch.float32).detach().clone()
            expected_shape = (bottleneck_size, input_size)
            if tuple(weight.shape) != expected_shape:
                raise ValueError(
                    f"embed_w has shape {tuple(weight.shape)}, expected {expected_shape}"
                )
            first_layer = nn.Linear(input_size, bottleneck_size, bias=False)
            first_layer.weight = nn.Parameter(weight)
        else:
            first_layer = nn.Linear(input_size, bottleneck_size, bias=True)
        self.fmri2embed = nn.Sequential(first_layer)

        self.vdvae_embed = nn.Linear(bottleneck_size, d_vdvae)

    def forward(self, fmri_inputs):
        """Map inputs of shape ``(..., input_size)`` to ``(..., d_vdvae)``."""
        bottleneck_mapping = self.fmri2embed(fmri_inputs)
        return self.vdvae_embed(bottleneck_mapping)

def get_loss(criterion, vdvae_preds, vdvae_targets, batch, reg_cfg, n_batch):
    """Return ``(total_loss, vdvae_loss)`` for one batch.

    The total loss currently consists only of the VDVAE term
    ``criterion(vdvae_preds, vdvae_targets)``, so both tuple elements are the
    same object. ``batch``, ``reg_cfg`` and ``n_batch`` are accepted for
    call-site compatibility but are not used.
    """
    vdvae_term = criterion(vdvae_preds, vdvae_targets)
    return vdvae_term, vdvae_term

def get_eval_loss(criterion, model, val_loader, reg_cfg):
    """Return the mean validation loss of ``model`` over ``val_loader``.

    Each batch's loss is weighted by the batch's sample count, so a shorter
    final batch does not skew the average. For a mean-reduced criterion such
    as ``nn.MSELoss()`` this is exactly the per-element mean over the whole
    split. The model runs in eval mode without autograd, and its previous
    train/eval mode is restored afterwards. This matters because the training
    loop calls this every epoch.

    Args:
        criterion: loss callable passed through to ``get_loss``.
        model: module mapping ``batch["inputs"]`` to predictions.
        val_loader: iterable of dicts with ``"inputs"`` and ``"vdvae_targets"``.
        reg_cfg: config; only ``reg_cfg.device`` is read here.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss, n_seen = 0.0, 0
    try:
        with torch.no_grad():
            for batch in tqdm(val_loader):
                inputs = batch["inputs"].to(reg_cfg.device)
                n_batch = inputs.shape[0]
                vdvae_preds = model(inputs)
                vdvae_targets = batch["vdvae_targets"].to(reg_cfg.device)
                loss, _ = get_loss(criterion, vdvae_preds, vdvae_targets, batch, reg_cfg, n_batch)
                total_loss += loss.item() * n_batch
                n_seen += n_batch
    finally:
        model.train(was_training)
    if n_seen == 0:
        raise ValueError("val_loader yielded no samples; cannot compute a validation loss")
    return total_loss / n_seen

def scale_preds(vdvae_preds, train_stats, epsilon=0.0001):
    """Rescale a batch of predictions onto the training-target statistics.

    Each column of ``vdvae_preds`` is reduced over dim 0 and z-scored with
    the batch's own mean and population (ddof=0) standard deviation. The
    result is then mapped onto ``train_mean`` / ``train_std``. Using ddof=0
    matches the ``np.std`` used to build ``train_std`` and ``scale_latents``.
    The batch statistics are detached, so gradients flow only through the
    centred predictions.

    Args:
        vdvae_preds: prediction tensor; statistics are taken over dim 0.
        train_stats: ``(train_mean, train_std)`` array-likes, one value per column.
        epsilon: lower bound for the batch std. It keeps constant columns
            (including every column of a single-sample batch) from dividing
            by zero; such columns map to ``train_mean``.

    Returns:
        Tensor with the shape and device of ``vdvae_preds``.
    """
    train_mean, train_std = train_stats
    batch_source = vdvae_preds.detach()
    batch_mean = torch.mean(batch_source, dim=0)
    batch_std = torch.std(batch_source, dim=0, unbiased=False)
    # nan_to_num covers an empty batch. clamp_min covers zero-variance
    # columns, where the original nan-only guard still divided by 0.
    batch_std = torch.nan_to_num(batch_std, nan=epsilon).clamp_min(epsilon)
    std_norm_test_latent = (vdvae_preds - batch_mean) / batch_std
    target_std = torch.as_tensor(train_std, dtype=vdvae_preds.dtype, device=vdvae_preds.device)
    target_mean = torch.as_tensor(train_mean, dtype=vdvae_preds.dtype, device=vdvae_preds.device)
    return std_norm_test_latent * target_std + target_mean

def _make_optimizer(reg_cfg, params, lr):
    """Build the optimizer named by ``reg_cfg.optim`` ("SGD" or "Adam") over ``params``."""
    if reg_cfg.optim == "SGD":
        return optim.SGD(params, lr=lr, momentum=0.0, weight_decay=0.001)
    if reg_cfg.optim == "Adam":
        return optim.AdamW(params, lr=lr, weight_decay=0.001)
    raise ValueError(f"Unknown optimizer {reg_cfg.optim!r}; expected 'SGD' or 'Adam'")


def _set_trainable_stage(model, freeze_embed):
    """Freeze ``fmri2embed`` and train ``vdvae_embed`` if ``freeze_embed``, else the reverse."""
    for name, param in model.named_parameters():
        if 'fmri2embed' in name:
            param.requires_grad = not freeze_embed
        elif 'vdvae_embed' in name:
            param.requires_grad = freeze_embed


def train_linear_mapping(model, train_loader, val_loader, reg_cfg, train_stats):
    """Train ``model`` by alternating which of its two stages is trainable.

    Every ``switch_every`` epochs, training flips between two phases,
    starting with the first:

    - ``vdvae_embed`` trains while ``fmri2embed`` is frozen;
    - ``fmri2embed`` trains while ``vdvae_embed`` is frozen.

    At each switch, a fresh optimizer of type ``reg_cfg.optim`` and a fresh
    ``ReduceLROnPlateau`` scheduler are built over the currently trainable
    parameters only. Each phase resumes from the learning rate it had when it
    was last switched off.

    Args:
        model: module whose parameter names contain ``fmri2embed`` / ``vdvae_embed``.
        train_loader: loader of dicts with ``"inputs"`` and ``"vdvae_targets"``.
        val_loader: same format, used for model selection via ``get_eval_loss``.
        reg_cfg: config providing ``optim`` ("SGD"/"Adam"), ``lr``,
            ``n_epochs`` and ``device``.
        train_stats: ``(mean, std)`` of the training targets; not used here
            (kept for call-site compatibility).

    Returns:
        ``model`` with the parameters from the epoch with the lowest
        validation loss loaded, and all parameters trainable again. If
        ``n_epochs`` is 0, the model is returned untouched.

    Raises:
        ValueError: on an unknown ``reg_cfg.optim`` or an empty ``train_loader``.
    """
    if reg_cfg.optim not in ("SGD", "Adam"):
        raise ValueError(f"Unknown optimizer {reg_cfg.optim!r}; expected 'SGD' or 'Adam'")
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty")

    criterion = nn.MSELoss()
    switch_every = 20  # epochs per alternating phase
    min_eval_loss = float("inf")
    # Snapshot of the best weights. Keeping a reference to `model` itself
    # would just return the final weights.
    best_state = None

    # lr_1 / lr_2: learning rate to resume with in the fmri2embed / vdvae_embed
    # phase, carried over from wherever the scheduler left it last time.
    lr_1, lr_2 = reg_cfg.lr, reg_cfg.lr
    freeze_embed = False
    optimizer = scheduler = None
    for epoch in range(reg_cfg.n_epochs):
        if epoch % switch_every == 0:
            if optimizer is not None:
                # Remember the (possibly decayed) LR of the phase that is ending.
                if freeze_embed:
                    lr_2 = optimizer.param_groups[0]['lr']
                else:
                    lr_1 = optimizer.param_groups[0]['lr']
            freeze_embed = not freeze_embed
            _set_trainable_stage(model, freeze_embed)
            # Hand only trainable parameters to the optimizer, so frozen ones
            # can never be updated (e.g. by weight decay), whatever zero_grad
            # does with their .grad.
            trainable = [p for p in model.parameters() if p.requires_grad]
            optimizer = _make_optimizer(reg_cfg, trainable, lr_2 if freeze_embed else lr_1)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)

        # Switch back to training mode each epoch, in case the previous
        # validation pass left the model in eval mode.
        model.train()
        with tqdm(total=len(train_loader)) as bar:
            bar.set_description(f"Epoch {epoch}")
            train_loss, train_vdvae_loss = 0.0, 0.0
            for batch in train_loader:
                inputs = batch["inputs"].to(reg_cfg.device)
                vdvae_targets = batch["vdvae_targets"].to(reg_cfg.device)
                optimizer.zero_grad()
                vdvae_preds = model(inputs)
                loss, vdvae_loss = get_loss(criterion, vdvae_preds, vdvae_targets, batch, reg_cfg, inputs.shape[0])
                loss.backward()
                optimizer.step()
                bar.set_postfix({"v": float(vdvae_loss)})
                bar.update()
                train_loss += float(loss)
                train_vdvae_loss += float(vdvae_loss)

            avg_loss = train_loss / len(train_loader)
            avg_vdvae_loss = train_vdvae_loss / len(train_loader)

            eval_loss = get_eval_loss(criterion, model, val_loader, reg_cfg)
            bar.set_postfix({"eval": eval_loss, "mse": avg_loss, "v": avg_vdvae_loss})
        if eval_loss < min_eval_loss:
            min_eval_loss = eval_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        # NOTE: the plateau scheduler tracks the training loss (as before), not eval_loss.
        scheduler.step(avg_loss)

    # Do not hand back a model left frozen from the last phase.
    for param in model.parameters():
        param.requires_grad = True
    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def eval_model(model, test_loader, device):
    """Run ``model`` over every batch of ``test_loader`` and return a NumPy array.

    The model is switched to eval mode (and left there) and inference runs
    under ``torch.no_grad()``. Each batch's ``"inputs"`` is moved to
    ``device``. The outputs are concatenated along dim 0 in loader order,
    then moved to CPU.
    """
    model.eval()
    with torch.no_grad():
        batch_outputs = [model(batch["inputs"].to(device)) for batch in tqdm(test_loader)]
        stacked = torch.cat(batch_outputs)
    return stacked.cpu().detach().numpy()

def scale_latents(pred_test_latent, train_latents):
    """Map predicted latents onto the per-column distribution of the training latents.

    Each column of ``pred_test_latent`` is z-scored with its own mean and
    population std (ddof=0). It is then rescaled with the mean and std of the
    matching column of ``train_latents``. A column whose predictions are
    constant (std 0) has a zero z-score and therefore maps to the training
    mean, instead of the NaNs a bare 0/0 division would produce.

    Returns:
        Array with the shape of ``pred_test_latent``.
    """
    pred_mean = np.mean(pred_test_latent, axis=0)
    pred_std = np.std(pred_test_latent, axis=0)
    # Wherever pred_std == 0 the centred values are 0 too, so any non-zero
    # divisor gives a zero z-score there.
    safe_std = np.where(pred_std == 0, 1.0, pred_std)
    std_norm_test_latent = (pred_test_latent - pred_mean) / safe_std
    return std_norm_test_latent * np.std(train_latents, axis=0) + np.mean(train_latents, axis=0)

def get_vdvae_targets(sub):
    """Load the VDVAE latent targets for subject ``sub``.

    Reads the ``train_latents`` and ``test_latents`` arrays from
    ``data/extracted_features/subj<sub:02d>/nsd_vdvae_features_31l.npz`` under
    ``BD_ROOT_PATH``.

    Returns:
        ``(train_latents, test_latents)``.
    """
    log.info("Getting VDVAE targets")

    nsd_path = 'data/extracted_features/subj{:02d}/nsd_vdvae_features_31l.npz'.format(sub)
    # np.load on an .npz returns an NpzFile that holds the archive open. Using
    # it as a context manager releases the file handle once both arrays are in memory.
    with np.load(BD_ROOT_PATH / nsd_path) as nsd_features:
        train_latents = nsd_features['train_latents']
        test_latents = nsd_features['test_latents']

    return train_latents, test_latents

def get_fmri_inputs(sub):
    """Load and standardise the averaged nsdgeneral fMRI responses for subject ``sub``.

    Both splits are z-scored per column (presumably per voxel) using the
    training split's mean and sample std (ddof=1). No test-split statistics
    enter the normalisation.

    Returns:
        ``(train_fmri, test_fmri)`` as float arrays.
    """
    log.info("Getting fMRI inputs")
    train_path = 'data/processed_data/subj{:02d}/nsd_train_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
    train_fmri = np.load(BD_ROOT_PATH / train_path)
    test_path = 'data/processed_data/subj{:02d}/nsd_test_fmriavg_nsdgeneral_sub{}.npy'.format(sub,sub)
    test_fmri = np.load(BD_ROOT_PATH / test_path)

    # Constant rescale; it cancels out in the z-scoring below (up to rounding).
    train_fmri = train_fmri/300
    test_fmri = test_fmri/300

    norm_mean_train = np.mean(train_fmri, axis=0)
    norm_scale_train = np.std(train_fmri, axis=0, ddof=1)
    # A column that is constant in training has std 0 (NaN with a single
    # trial), which would produce NaN/inf inputs. Divide such columns by 1.
    norm_scale_train = np.where(norm_scale_train > 0, norm_scale_train, 1.0)
    train_fmri = (train_fmri - norm_mean_train) / norm_scale_train
    test_fmri = (test_fmri - norm_mean_train) / norm_scale_train
    return train_fmri, test_fmri

def save_preds(arr, sub, bottleneck_size, out_name):
    """Write ``arr`` to ``SAVE_ROOT_PATH/subj_<sub>/bbits_<bottleneck_size>/<out_name>.npy``.

    Missing parent directories are created first.
    """
    target_dir = SAVE_ROOT_PATH / f'subj_{sub}/bbits_{bottleneck_size}/'
    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir / f"{out_name}.npy"
    np.save(target_file, arr)

def train_all(sub, bottleneck_size, train_fmri, test_fmri, reg_cfg, val_split=0.15):
    """Fit one fMRI -> VDVAE bottleneck model for subject ``sub`` and save test predictions.

    Steps:

    1. Hold out ``val_split`` of the training trials for validation.
    2. When ``reg_cfg.pca_preload`` is set, fit a ``bottleneck_size``-component
       PCA on the remaining training fMRI to initialise the first layer.
    3. Train the model.
    4. Predict the test set and save two files: the raw predictions
       ("unscaled_vdvae_preds") and the predictions rescaled to the
       training-latent statistics ("vdvae_preds").

    Args:
        sub: subject id.
        bottleneck_size: width of the model's bottleneck layer.
        train_fmri, test_fmri: standardised fMRI arrays, rows aligned with the
            VDVAE targets of the same split.
        reg_cfg: training config (``batch_size``, ``device``, ``pca_preload``, ...).
        val_split: fraction of training trials held out for validation.
    """
    vdvae_embeds_train, vdvae_embeds_test = get_vdvae_targets(sub)
    _, d_vdvae = vdvae_embeds_train.shape

    all_train_data = fMRI2latent(train_fmri, vdvae_embeds_train)
    train_idx, val_idx = train_test_split(list(range(len(all_train_data))), test_size=val_split)

    embed_w = None
    if reg_cfg.pca_preload:
        # Fit PCA on the training split only, so validation trials stay unseen.
        pca = PCA(n_components=bottleneck_size)
        pca.fit(train_fmri[train_idx])
        embed_w = pca.components_

    # Per-column statistics of the training-split targets, forwarded to the trainer.
    train_latent_arr = vdvae_embeds_train[train_idx]
    train_stats = (np.mean(train_latent_arr, axis=0), np.std(train_latent_arr, axis=0))

    train_loader = DataLoader(Subset(all_train_data, train_idx), batch_size=reg_cfg.batch_size, shuffle=True)
    # Evaluation needs no shuffling; a fixed order keeps the validation loss reproducible.
    val_loader = DataLoader(Subset(all_train_data, val_idx), batch_size=reg_cfg.batch_size, shuffle=False)

    model = BottleneckLinear(train_fmri.shape[-1], bottleneck_size, d_vdvae, reg_cfg, embed_w=embed_w)
    model = model.to(reg_cfg.device)
    log.info("Training fMRI2latent mapping")
    model = train_linear_mapping(model, train_loader, val_loader, reg_cfg, train_stats)

    log.info("fMRI2latent test evaluation")
    test_data = fMRI2latent(test_fmri, vdvae_embeds_test)
    test_loader = DataLoader(test_data, batch_size=reg_cfg.batch_size, shuffle=False)
    vdvae_preds = eval_model(model, test_loader, reg_cfg.device)

    # Rescale the raw predictions to the statistics of all training latents.
    scaled_vdvae_preds = scale_latents(vdvae_preds, vdvae_embeds_train)

    save_preds(scaled_vdvae_preds, sub, bottleneck_size, "vdvae_preds")
    save_preds(vdvae_preds, sub, bottleneck_size, "unscaled_vdvae_preds")

@hydra.main(config_path="conf")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: train and save one fMRI -> VDVAE mapping per bottleneck size.

    Reads ``cfg.exp["sub"]`` (subject id), ``cfg.exp["bottlenecks"]`` (iterable
    of bottleneck sizes) and ``cfg.exp.reg`` (training config). The fMRI
    inputs are loaded once and reused for every bottleneck size.
    """
    # The previous message ("Run testing for all electrodes ...") was a
    # copy-paste leftover and did not describe this run.
    log.info("Training fMRI-to-VDVAE latent mappings")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    log.info(f'Working directory {os.getcwd()}')
    # NOTE: out_dir is only logged. train_all writes predictions via
    # save_preds, which uses SAVE_ROOT_PATH rather than out_dir.
    out_dir = cfg.exp.out_dir if "out_dir" in cfg.exp else os.getcwd()
    log.info(f'Output directory {out_dir}')

    sub = cfg.exp["sub"]
    train_fmri, test_fmri = get_fmri_inputs(sub)

    reg_cfg = cfg.exp.reg
    for bottleneck_size in cfg.exp["bottlenecks"]:
        train_all(sub, bottleneck_size, train_fmri, test_fmri, reg_cfg)

if __name__=="__main__":
    # Debug helper: uncomment the two lines below to replace the real command
    # line with Hydra overrides for a tiny CPU run (one bottleneck of size 5,
    # one epoch, SGD).
    # _debug = '''train.py +exp=latent_reg ++exp.bottlenecks=[5] ++exp.reg.batch_size=128 ++exp.reg.n_epochs=1 ++exp.reg.optim="SGD" ++exp.reg.device="cpu"'''
    # sys.argv = _debug.split(" ")
    main()

