# %%
import argparse
import os
import sys

import accelerate
import lovely_tensors as lt
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf
from tqdm.auto import tqdm

from ntldm.data.latent_attractor import get_attractor_dataset
from ntldm.networks import AutoEncoder, CountWrapper
from ntldm.utils.plotting_utils import *
from ntldm.losses import (
    latent_regularizer,
    latent_regularizer_v2,
)

# Install the lovely_tensors patch (presumably changes how torch tensors are
# printed -- confirm against the lovely_tensors documentation).
lt.monkey_patch()
# Load matplotlib style settings from the file "matplotlibrc". The path is
# relative, so the script has to be started from the directory containing it.
matplotlib.rc_file("matplotlibrc")


# %%

# Default experiment configuration, kept inline as a YAML document with the
# sections `model`, `dataset` and `training` plus the experiment name.
# Individual entries can be overridden on the command line via --update_conf.
cfg_yaml = """
model:
  C_in: 128
  C: 256
  C_latent: 32
  kernel: s4
  num_blocks: 6
  num_blocks_decoder: 0
  num_lin_per_mlp: 2
  bidirectional: False # important!
dataset:
  system_name: phoneme
  datapath: data/phoneme/competitionData
  max_seqlen: 512
training:
  lr: 0.001
  num_epochs: 400
  num_warmup_epochs: 20
  batch_size: 256
  random_seed: 42
  precision: bf16
  latent_beta: 0.001
  latent_td_beta: 0.1
  tk_k: 5
  mask_prob: 0.20
  latent_reg_version: v2
exp_name: autoencoder-count_s4-phoneme_v2loss
"""


# Parse the YAML text into plain Python containers and wrap them in an
# OmegaConf config, which gives attribute-style access (cfg.model.C_in, ...).
cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))

# Imported here because it is only needed to parse the CLI overrides below.
import ast

parser = argparse.ArgumentParser()

parser.add_argument(
    "--update_conf",
    nargs="*",
    help="Updates to the configuration in the form of key=value pairs",
    default=[],
)
cli_args = parser.parse_args()

# Apply every `key=value` override from the command line to the config.
for update in cli_args.update_conf:
    # Split on the first "=" only, so the value itself may contain "=".
    key, separator, value = update.partition("=")
    if not separator or not key:
        raise ValueError(
            f"--update_conf entries must look like key=value, got {update!r}"
        )
    # Interpret the value as a Python literal (number, bool, None, list, ...)
    # when possible. ast.literal_eval never executes code, so a value such as
    # "max" or "id" stays a string instead of turning into a builtin object,
    # which is what the previous eval() call did.
    try:
        value = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        pass  # not a literal: keep the raw string
    OmegaConf.update(
        cfg, key, value, force_add=True
    )  # update config with new value, adding key if necessary

# Append every override to the experiment name so that sweep runs get distinct
# names. This happens after all overrides are applied, so an override of
# `exp_name` itself is the base the suffixes are added to.
cfg.exp_name = "_".join([cfg.exp_name, *cli_args.update_conf])


print(OmegaConf.to_yaml(cfg))

# Persist the resolved configuration next to the other sweep configs.
os.makedirs("conf/sweeps_new", exist_ok=True)
os.makedirs("exp/new", exist_ok=True)
with open("conf/sweeps_new/Phoneme_" + cfg.exp_name + ".yaml", "w") as f:
    f.write(OmegaConf.to_yaml(cfg))

save_path = "exp/new/" + cfg.exp_name

# %%
from ntldm.data.phoneme import get_phoneme_dataloaders

# Seed the global torch and numpy random number generators from the config.
torch.manual_seed(cfg.training.random_seed)
np.random.seed(cfg.training.random_seed)

# Build the three data loaders (unpacked here as train / validation / test)
# from the dataset path, batch size and maximum sequence length in the config.
train_dataloader, val_dataloader, test_dataloader = get_phoneme_dataloaders(
    cfg.dataset.datapath,
    batch_size=cfg.training.batch_size,
    max_seqlen=cfg.dataset.max_seqlen,
)
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Build the autoencoder from the `model` section of the config. Optional keys
# fall back to the defaults given here (the decoder depth defaults to the
# encoder depth, and bidirectional defaults to False for phoneme).
model_cfg = cfg.model
ae = AutoEncoder(
    C_in=model_cfg.C_in,
    C=model_cfg.C,
    C_latent=model_cfg.C_latent,
    L=cfg.dataset.max_seqlen,
    kernel=model_cfg.kernel,
    num_blocks=model_cfg.num_blocks,
    num_blocks_decoder=model_cfg.get("num_blocks_decoder", model_cfg.num_blocks),
    num_lin_per_mlp=model_cfg.get("num_lin_per_mlp", 2),
    bidirectional=model_cfg.get("bidirectional", False),
)

# Report the number of trainable parameters of the bare autoencoder, in millions.
num_trainable = sum(p.numel() for p in ae.parameters() if p.requires_grad)
print("Number of params", num_trainable / 1e6, "M")

# Wrap the autoencoder in CountWrapper, print the resulting module, then move
# it to the selected device.
ae = CountWrapper(ae, use_sin_enc=model_cfg.get("use_sin_enc", False))
print(ae)
ae = ae.to(device)

# AdamW with its default weight decay (0.01) for now.
optimizer = torch.optim.AdamW(params=ae.parameters(), lr=cfg.training.lr)

# Number of batches per epoch. The training loop steps the scheduler once per
# batch, so this is also the number of scheduler steps per epoch.
# NOTE(review): measured before accelerator.prepare(); confirm the loader
# length is the same afterwards when running on several processes.
num_batches = len(train_dataloader)
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=num_batches
    * cfg.training.num_warmup_epochs,  # warmup lasts num_warmup_epochs epochs
    num_training_steps=num_batches
    * cfg.training.num_epochs
    * 1.5,  # schedule horizon: 1.5x the steps actually run (a float value)
    # num_training_steps=cfg.training.num_epochs # AS changed
)

# Mixed precision is only kept when the signal length is a power of two:
# torch.fft doesnt support half if L!=2^x, so any other length falls back to
# full precision ("no").
seq_len = cfg.dataset.max_seqlen
if (seq_len & (seq_len - 1)) != 0:
    cfg.training.precision = "no"


# Accelerate setup

accelerator = accelerate.Accelerator(
    log_with="wandb",
    mixed_precision=cfg.training.precision,
)


# Hand the model, optimizer, scheduler and all three loaders to accelerate and
# rebind the names to the prepared objects it returns.
prepared = accelerator.prepare(
    ae,
    optimizer,
    lr_scheduler,
    train_dataloader,
    val_dataloader,
    test_dataloader,
)
ae, optimizer, lr_scheduler = prepared[0], prepared[1], prepared[2]
train_dataloader, val_dataloader, test_dataloader = (
    prepared[3],
    prepared[4],
    prepared[5],
)


# %%

# Poisson negative log-likelihood between predicted rates and spike counts.
# log_input=False: the prediction is passed as a rate, not a log-rate.
# reduction="none": one loss value per element is returned, so callers can
# apply their masks before averaging.
criterion_poisson = nn.PoissonNLLLoss(log_input=False, full=True, reduction="none")


def compute_val_loss(net, dataloader):
    """Compute the average Poisson NLL of `net` on `dataloader` and log a plot.

    Each batch is a mapping with the entries "signal" and "mask". The
    elementwise Poisson NLL between the first output of `net(signal)` and
    `signal` is multiplied by the mask (zeroing padded positions) and
    averaged over all elements of the batch; the per-batch means are then
    averaged over batches.

    The result is multiplied by the training `mask_prob` so that it is on the
    same scale as the training loss, which averages over all elements but is
    non-zero only on the held-out fraction. When `mask_prob` is 0 the training
    loss is not reduced in this way, so no scaling is applied here either.

    A figure comparing predicted rates and spikes for the first sample of the
    last batch is logged to wandb under "val_rates".

    Args:
        net: model whose call returns a sequence with the rates first.
        dataloader: iterable of batches as described above.

    Returns:
        The scaled average validation loss as a Python float.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    net.eval()
    poisson_loss_total = 0.0
    batch_count = 0

    for batch in dataloader:
        signal = batch["signal"]
        signal_mask = batch["mask"].cpu()
        with torch.no_grad():
            output_rates = net(signal)[0].cpu()

        signal = signal.cpu()  # move signal to cpu

        # pointwise Poisson NLL, zeroed on padded positions
        poisson_loss = criterion_poisson(output_rates, signal) * signal_mask

        poisson_loss_total += poisson_loss.mean().item()
        batch_count += 1

    if batch_count == 0:
        raise ValueError("compute_val_loss received an empty dataloader")

    # Same lookup and default as the training loop, so both agree on the value.
    mask_prob = cfg.training.get("mask_prob", 0.25)
    loss_scale = mask_prob if mask_prob > 0 else 1.0

    # compute average loss over all batches
    avg_poisson_loss = poisson_loss_total / batch_count * loss_scale
    print(f"Validation loss: {avg_poisson_loss:.4f}, mask_prob {mask_prob}")

    # Plot two channels of the first sample of the last batch: channel 0 and
    # channel 92 (clamped to the last channel if the model has fewer).
    last_channel = output_rates.shape[1] - 1
    plot_channels = (0, min(92, last_channel))

    fig, ax = plt.subplots(2, 1, figsize=(10, 2), dpi=300)
    for row, channel in enumerate(plot_channels):
        ax[row].plot(output_rates[0, channel].clip(0, 3).numpy(), label="pred")
        ax[row].plot(
            signal[0, channel].clip(0, 3).numpy(),
            label="spikes",
            alpha=0.5,
            color="grey",
        )
        # Per-axes legend: plt.legend() would only ever label the last panel.
        ax[row].legend()
    wandb.log({"val_rates": wandb.Image(fig)})
    plt.close(fig)

    return avg_poisson_loss


# %%

from ntldm.utils.plotting_utils import cm2inch
from einops import rearrange


def plot_rate_traces_real(model, dataloader, figsize=(12, 5), idx=0):
    """Plot predicted rate traces over the spikes for the two busiest channels.

    Only the first batch of `dataloader` is used. Tensors are indexed as
    (sample, channel, time). The two channels with the highest mean spike
    count, averaged over samples and time, are selected. For sample `idx` the
    spikes are drawn as vertical lines whose opacity grows with the spike
    count (0.1 per spike, capped at 1) and the predicted rate is drawn in red,
    both restricted to the unpadded part of the sequence given by the mask.

    Args:
        model: model whose call returns a sequence with the rates first.
        dataloader: iterable of batches with the entries "signal" and "mask".
        figsize: figure size in centimetres, converted with `cm2inch`.
        idx: index of the sample within the batch to plot.

    Returns:
        The matplotlib figure.
    """
    model.eval()
    for batch in dataloader:
        signal = batch["signal"]
        signal_mask = batch["mask"].cpu()
        with torch.no_grad():
            output_rates = model(signal)[0].cpu()

        signal = signal.cpu()  # move signal to cpu
        break  # only the first batch is plotted

    # Mean spike count per channel: average over samples (dim 0) and time
    # (dim 2). Averaging over dims 0 and 1 instead would rank time steps, and
    # the resulting indices would then be misused as channel indices.
    mean_firing_rates = signal.mean(dim=(0, 2))
    channels = torch.argsort(mean_firing_rates, descending=True)[:2]

    # squeeze=False keeps `ax` indexable even if only one channel exists.
    fig, ax = plt.subplots(
        1, len(channels), figsize=cm2inch(figsize), dpi=150, squeeze=False
    )
    ax = ax[0]

    for i, channel in enumerate(channels):
        channel = int(channel)  # plain int, so the title reads "channel 5"
        # number of unpadded time steps of this sample
        L_actual = int(signal_mask[idx, channel].sum().item())
        spikes = signal[idx, channel, :L_actual].numpy()
        rates = output_rates[idx, channel, :L_actual]

        ax[i].vlines(
            torch.arange(L_actual),
            torch.zeros(L_actual),
            torch.ones(L_actual) * rates.max().item(),
            color="black",
            # opacity 0.1 per spike, capped at fully opaque
            alpha=np.minimum(1.0, spikes * 0.1),
        )
        ax[i].plot(
            rates.numpy(),
            label="pred",
            color="red",
        )
        ax[i].set_title(f"channel {channel}")

    ax[-1].legend()

    fig.suptitle("rate traces for channels")
    fig.tight_layout()
    return fig
    # plt.show()


def imshow_rates_real(model, dataloader, figsize=(12, 5), idx=0):
    """Show inferred rates and ground-truth spikes of one sample side by side.

    Only the first batch of `dataloader` is used. For sample `idx` the model's
    first output (left panel) and the input spikes (right panel) are drawn as
    grey-scale images, cropped along the last axis to the number of valid
    steps given by the mask of channel 0. Each panel gets its own colorbar.

    Args:
        model: model whose call returns a sequence with the rates first.
        dataloader: iterable of batches with the entries "signal" and "mask".
        figsize: figure size in centimetres, converted with `cm2inch`.
        idx: index of the sample within the batch to show.

    Returns:
        The matplotlib figure.
    """
    model.eval()
    for batch in dataloader:
        spikes = batch["signal"]
        mask_cpu = batch["mask"].cpu()
        with torch.no_grad():
            rates = model(spikes)[0].cpu()
        spikes = spikes.cpu()
        break  # a single batch is enough for the picture

    # unpadded length of the chosen sample, read from the mask of channel 0
    valid_len = int(mask_cpu[idx, 0].sum().item())

    fig, ax = plt.subplots(1, 2, figsize=cm2inch(figsize), dpi=150)

    panels = (rates, spikes)
    images = [
        axis.imshow(
            data[idx, :, :valid_len].numpy(),
            label="rates",
            aspect="auto",
            cmap="Greys",
        )
        for axis, data in zip(ax, panels)
    ]
    for image, axis in zip(images, ax):
        plt.colorbar(image, ax=axis)

    ax[-1].legend()
    for axis, title in zip(ax, ("Inferred rates", "GT Spikes")):
        axis.set_title(title)

    fig.suptitle(f"infeered rates, idx {idx}")
    fig.tight_layout()
    return fig
    # plt.show()


def compute_latents(model, dataloader):
    """Collect the latents of `model` for every batch of `dataloader`.

    The model is put in eval mode and called without gradient tracking on the
    "signal" entry of each batch; the second element of its output is taken as
    the latent.

    Args:
        model: model whose call returns a pair `(rates, latents)`.
        dataloader: iterable of batches with the entries "signal" and "mask".

    Returns:
        A dict with "latents" and "signal_masks", each the CPU tensors of all
        batches concatenated along dimension 0.
    """
    model.eval()
    z_chunks, mask_chunks = [], []
    with torch.no_grad():
        for batch in dataloader:
            _, z = model(batch["signal"])
            z_chunks.append(z.cpu())
            mask_chunks.append(batch["mask"].cpu())

    gathered = {"latents": z_chunks, "signal_masks": mask_chunks}
    return {key: torch.cat(chunks, 0) for key, chunks in gathered.items()}


def reconstruct_spikes(model, dataloader):
    """Sample reconstructed spikes from the rates `model` predicts.

    For every batch the model is called without gradient tracking on the
    "signal" entry. Spike counts are drawn from a Poisson distribution with
    the predicted rates and multiplied by the mask, which zeroes the padded
    positions.

    Args:
        model: model whose call returns a pair `(rates, latents)`.
        dataloader: iterable of batches with the entries "signal" and "mask".

    Returns:
        A dict with "latents", "spikes" (the inputs), "rec_spikes" (the
        sampled reconstruction) and "signal_masks", each the CPU tensors of
        all batches concatenated along dimension 0.
    """
    model.eval()
    collected = {"latents": [], "spikes": [], "rec_spikes": [], "signal_masks": []}

    for batch in dataloader:
        signal, signal_mask = batch["signal"], batch["mask"]
        with torch.no_grad():
            output_rates, z = model(signal)
            z = z.cpu()
        mask_cpu = signal_mask.cpu()

        collected["latents"].append(z)
        collected["spikes"].append(signal.cpu())
        collected["rec_spikes"].append(torch.poisson(output_rates.cpu()) * mask_cpu)
        collected["signal_masks"].append(mask_cpu)

    return {key: torch.cat(chunks, 0) for key, chunks in collected.items()}


def plot_corrcoef(rec_dict, figsize=cm2inch(12, 4)):
    """Compare neuron-by-neuron correlations of real and reconstructed spikes.

    For every trial the padded tail is cut off using the mask of channel 0,
    the trials are concatenated along time, and the correlation matrix across
    channels is computed for the real and for the reconstructed spikes. The
    diagonals are overwritten with 0.01 before plotting. Three panels are
    drawn: the two correlation matrices on a shared, fixed [-1, 1] colour
    scale, and their absolute difference.

    Args:
        rec_dict: dict with the tensors "spikes", "rec_spikes" and
            "signal_masks", indexed as (trial, channel, time).
        figsize: figure size in inches; note the default is evaluated once,
            when the function is defined.

    Returns:
        The matplotlib figure.
    """
    # number of unpadded time steps per trial
    valid_lengths = [int(mask[0].sum().item()) for mask in rec_dict["signal_masks"]]

    real_spikes = torch.cat(
        [trial[:, :n] for trial, n in zip(rec_dict["spikes"], valid_lengths)], 1
    ).numpy()
    rec_spikes = torch.cat(
        [trial[:, :n] for trial, n in zip(rec_dict["rec_spikes"], valid_lengths)], 1
    ).numpy()

    print(f"rec_spikes shape {rec_spikes.shape}, real_spikes shape {real_spikes.shape}")

    # rows are channels, so correlations are computed between channels
    corrcoef_real = np.corrcoef(real_spikes, rowvar=True)
    corrcoef_rec = np.corrcoef(rec_spikes, rowvar=True)
    print(corrcoef_real.shape, corrcoef_rec.shape)

    np.fill_diagonal(corrcoef_real, 0.01)
    np.fill_diagonal(corrcoef_rec, 0.01)

    fig, axs = plt.subplots(1, 3, figsize=figsize)

    # Correlation panels. Each image is drawn exactly once and its handle is
    # passed to the colorbar; drawing it a second time without vmin/vmax would
    # cover the first image and drop the fixed [-1, 1] colour scale.
    corr_panels = (
        (axs[0], corrcoef_real, "neuron correlations gt"),
        (axs[1], corrcoef_rec, "neuron correlations ae"),
    )
    for axis, corr, title in corr_panels:
        image = axis.imshow(corr, cmap="coolwarm", vmin=-1, vmax=1)
        axis.set_title(title)
        axis.axis("off")
        plt.colorbar(
            image,
            ax=axis,
            orientation="vertical",
            fraction=0.046,
            pad=0.04,
            boundaries=np.linspace(-1.01, 1.01, 50),
        )

    # Absolute difference between the two correlation matrices.
    abs_diff = np.abs(corrcoef_rec - corrcoef_real)
    diff_image = axs[2].imshow(abs_diff, cmap="magma")
    axs[2].set_title("neuron |corr_real-corr_recon|")
    axs[2].axis("off")
    plt.colorbar(
        diff_image,
        ax=axs[2],
        orientation="vertical",
        fraction=0.046,
        pad=0.04,
        boundaries=np.linspace(0, 1.01, 50),
    )

    fig.tight_layout()
    return fig
    # plt.show()


# ------------- train loop --------------


# Pick the latent regularizer named in the config. An unsupported name fails
# right here with a clear message, instead of surfacing as a NameError on
# `latent_regularizer_fn` at the first training step.
latent_reg_version = cfg.training.latent_reg_version
if latent_reg_version == "v1":
    latent_regularizer_fn = latent_regularizer
elif latent_reg_version == "v2":
    latent_regularizer_fn = latent_regularizer_v2
elif latent_reg_version == "v3":
    # The GP-based regularizer is not available in this script. The code that
    # used to follow this raise was unreachable and has been removed.
    raise NotImplementedError("DONT USE v3 yet!")
else:
    raise ValueError(
        f"Unknown training.latent_reg_version {latent_reg_version!r}; "
        "expected 'v1' or 'v2'"
    )

# Elementwise Poisson NLL on rates (log_input=False); reduction="none" keeps
# one value per element so the masks can be applied before averaging.
criterion_poisson = nn.PoissonNLLLoss(log_input=False, full=True, reduction="none")

# Per-step training histories and the most recent validation results.
rec_losses, latent_losses, total_losses, lrs, val_rate_losses = [], [], [], [], []
avg_poisson_loss, avg_rate_loss = 0, 0

# Start exactly one wandb run, named after the experiment. A second
# wandb.init() with a different project used to follow; depending on the
# environment it was either ignored or replaced this run with an unnamed one
# that had no hyperparameters attached.
# os.environ["WANDB_MODE"] = "online"
wandb.init(project="ntldm-phoneme", entity="anon-project", name=cfg.exp_name)

# log hparams
cfg_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
wandb.config.update(cfg_dict)
# Training loop.
#
# Every step applies coordinated dropout: a random fraction `mask_prob` of the
# input entries is zeroed, the remaining entries are scaled by 1/(1-mask_prob),
# and the Poisson loss is evaluated only on the zeroed (held-out) entries that
# are not padding. The latent regularizer is added with weight latent_beta.
# Validation runs every 10 epochs, the diagnostic plots and a checkpoint every
# 20 epochs, and both also run on the last epoch.
mask_prob = cfg.training.get("mask_prob", 0.25)  # constant over the run

with tqdm(range(0, cfg.training.num_epochs)) as pbar:
    for epoch in pbar:
        ae.train()

        for data in train_dataloader:
            optimizer.zero_grad()

            signal = data["signal"]
            signal_mask = data["mask"]

            # applying mask (coordinated dropout)
            mask = (
                torch.rand_like(signal) > mask_prob
            ).float()  # if mask_prob=0.2, 80% will be 1 and rest 0
            input_signal = signal * (
                mask / (1 - mask_prob)
            )  # mask and scale unmasked by 1/(1-p)

            output_rates, z = ae(input_signal)

            numel = signal.shape[0] * signal.shape[1] * signal.shape[2]

            # Loss is computed on the entries hidden from the input
            # (1 - mask); with mask_prob == 0 nothing is hidden, so every
            # entry counts.
            unmasked = (1 - mask) if mask_prob > 0 else torch.ones_like(mask)
            poisson_loss = criterion_poisson(output_rates, signal) * unmasked

            poisson_loss = poisson_loss * signal_mask  # also mask out padding

            poisson_loss = poisson_loss.mean()

            rec_loss = poisson_loss

            latent_loss = latent_regularizer_fn(z, cfg) / numel
            loss = rec_loss + cfg.training.latent_beta * latent_loss

            accelerator.backward(loss)
            accelerator.clip_grad_norm_(ae.parameters(), 2.0)

            optimizer.step()
            lr_scheduler.step()

            # Read each scalar once: every .item() forces a device sync.
            step_metrics = {
                "rec_loss": rec_loss.item(),
                "latent_loss": latent_loss.item(),
                "total_loss": loss.item(),
                "lr": optimizer.param_groups[0]["lr"],
            }

            pbar.set_postfix(**step_metrics, val_poisson_loss=avg_poisson_loss)
            rec_losses.append(step_metrics["rec_loss"])
            latent_losses.append(step_metrics["latent_loss"])
            total_losses.append(step_metrics["total_loss"])
            lrs.append(step_metrics["lr"])
            wandb.log({**step_metrics, "epoch": epoch})

        # eval
        is_last_epoch = epoch == cfg.training.num_epochs - 1

        if accelerator.is_main_process and (epoch % 10 == 0 or is_last_epoch):
            avg_poisson_loss = compute_val_loss(ae, val_dataloader)
            wandb.log({"val_poisson_loss": avg_poisson_loss})

        if accelerator.is_main_process and (epoch % 20 == 0 or is_last_epoch):
            ae.eval()

            # Each figure is closed right after logging; otherwise four
            # figures per evaluation stay open for the rest of the run.

            # plot_rate_traces_real(ae, val_dataloader, figsize=(12, 5), idx=1)
            fig = imshow_rates_real(ae, val_dataloader, figsize=(12, 5), idx=1)
            wandb.log({"val_rates and spikes": wandb.Image(fig)})
            plt.close(fig)

            rec_dict = reconstruct_spikes(ae, val_dataloader)

            # distribution of spike counts summed over channels (dim 1)
            hist_fig = plt.figure(figsize=cm2inch((6, 4)))
            # bins = np.linspace(0, 20, 20)
            plt.hist(
                rec_dict["spikes"].sum(1).flatten(),
                density=True,
                color="grey",
                bins=100,
                alpha=0.5,
            )
            plt.hist(
                rec_dict["rec_spikes"].sum(1).flatten(),
                density=True,
                color="darkblue",
                bins=100,
                alpha=0.5,
            )

            plt.xlim(0, 1000)

            plt.legend(["gt", "ae"])

            plt.title("spike count distribution (val set)")

            wandb.log({"spike_count_dist": wandb.Image(hist_fig)})
            plt.close(hist_fig)

            # avg neuron firing rate scatterplot (mean over trials and time)
            gt_mean_rate = rec_dict["spikes"].mean((0, 2))
            ae_mean_rate = rec_dict["rec_spikes"].mean((0, 2))
            max_gt_rate = gt_mean_rate.max().item()

            fig, ax = plt.subplots(1, 1, figsize=cm2inch((6, 4)))
            ax.scatter(gt_mean_rate.numpy(), ae_mean_rate.numpy(), alpha=0.5)
            ax.set_xlabel("gt mean spike rate")
            ax.set_ylabel("ae mean spike rate")
            # identity line as reference
            ax.plot([0, max_gt_rate], [0, max_gt_rate], color="black")
            fig.tight_layout()

            wandb.log({"firing rate_scatter": wandb.Image(fig)})
            plt.close(fig)

            fig = plot_corrcoef(rec_dict)
            wandb.log({"corrcoef": wandb.Image(fig)})
            plt.close(fig)

            # The label rounds (epoch + 20) down to a multiple of 20: epoch 0
            # is saved as epoch_20, and with num_epochs a multiple of 20 the
            # final epoch overwrites the checkpoint written 19 epochs earlier.
            # NOTE(review): this writes below exp/<exp_name>, not below the
            # `save_path` ("exp/new/<exp_name>") set up with the config --
            # confirm which location is intended.
            accelerator.save_state(f"exp/{cfg.exp_name}/epoch_{(epoch+20)//20*20}")


###---- end of training loop ----
