from __future__ import annotations

import logging
import math
from typing import Any, Dict, Optional, Tuple
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist

from .data import DictMemmapWriter
from .train import Trainer
from .torch_util import get_world_size, move_to_device

import os

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Applied unconditionally at import time: any PYTORCH_CUDA_ALLOC_CONF value that
# is already present in the environment is replaced.
# NOTE(review): presumably intended to switch PyTorch's CUDA caching allocator
# to expandable segments -- confirm it is set early enough to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


class DataSelector(Trainer):

    def add_reference_model(self, reference_trainer: Trainer):
        assert not hasattr(self, "reference_model")
        self.reference_model = reference_trainer
        self.reference_model.fsdp_model.train()

    def add_learner_model(self, learner_trainer: Trainer):
        assert not hasattr(self, "learner_model")
        self.learner_model = learner_trainer
        self.learner_model.fsdp_model.train()

    def score(self, reference_loss: torch.Tensor, learner_loss: torch.Tensor):
        # Select tokens/seqs
        if self.cfg.score is None or self.cfg.score == "rho":
            score = (learner_loss - reference_loss).detach()
        elif self.cfg.score == "ref":
            score = -reference_loss.detach()
        return score

    def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Split into micro-batches.
        micro_batches = self.split_batch(batch)

        # In case this helps with memory utilization.
        del batch

        # For logging
        learn_batch_loss = torch.tensor(0.0, device=self.device)
        ref_batch_loss = torch.tensor(0.0, device=self.device)
        full_learn_batch_loss = torch.tensor(0.0, device=self.device)
        full_ref_batch_loss = torch.tensor(0.0, device=self.device)
        online_batch_loss = torch.tensor(0.0, device=self.device)
        full_online_batch_loss = torch.tensor(0.0, device=self.device)
        online_z_batch_loss = (
            None if not self.cfg.softmax_auxiliary_loss else torch.tensor(0.0, device=self.device)
        )

        # To be saved
        batch_learn_losses = []
        batch_ref_losses = []
        for micro_batch in micro_batches:
            with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
                mbs = micro_batch["input_ids"].shape[0]
                if self.cfg.granularity == "token":
                    tokens = True
                else:
                    assert self.cfg.granularity == "sequence"
                    tokens = False

                # Run forward passes
                learn_loss, learn_z_loss = self.learner_model.model_forward(
                    micro_batch,
                    compute_z_loss=self.cfg.softmax_auxiliary_loss,
                    loss_reduction="none",
                    return_logits=False,
                )
                ref_loss, ref_z_loss = self.reference_model.model_forward(
                    micro_batch,
                    compute_z_loss=True,
                    loss_reduction="none",
                    return_logits=False,
                )
                # Store full losses on cpu
                batch_learn_losses.append(learn_loss.detach().cpu())
                batch_ref_losses.append(ref_loss.detach().cpu())

                # Reshape losses
                if tokens:
                    learn_loss = learn_loss.flatten()
                    ref_loss = ref_loss.flatten()
                    if self.cfg.softmax_auxiliary_loss:
                        learn_z_loss = learn_z_loss.flatten()
                        ref_z_loss = ref_z_loss.flatten()
                else:
                    learn_loss = learn_loss.mean(dim=-1)
                    ref_loss = ref_loss.mean(dim=-1)
                    if self.cfg.softmax_auxiliary_loss:
                        learn_z_loss = learn_z_loss.mean(dim=-1)
                        ref_z_loss = ref_z_loss.mean(dim=-1)

                # Select data
                score = self.score(ref_loss, learn_loss)
                k = int(self.cfg.select_frac * len(ref_loss))
                select_idx = torch.topk(score, k, largest=True).indices

                # Forward online model
                # NOTE: not actually saving compute here, forwarding whole batch
                # To save compute, load the stored index from disk.
                online_loss, online_z_loss = self.model_forward(
                    micro_batch,
                    compute_z_loss=self.cfg.softmax_auxiliary_loss,
                    loss_reduction="none",
                    return_logits=False,
                )
                if tokens:
                    online_loss = online_loss.flatten()
                    if self.cfg.softmax_auxiliary_loss:
                        online_z_loss = online_z_loss.flatten()
                else:
                    online_loss = online_loss.mean(dim=-1)
                    if self.cfg.softmax_auxiliary_loss:
                        online_z_loss = online_z_loss.mean(dim=-1)

                # Log full losses
                full_learn_batch_loss += learn_loss.mean().detach() / len(micro_batches)
                full_ref_batch_loss += ref_loss.mean().detach() / len(micro_batches)
                full_online_batch_loss += online_loss.mean().detach() / len(micro_batches)

                # Select by select_idx
                learn_loss = learn_loss[select_idx].mean() / len(micro_batches)
                ref_loss = ref_loss[select_idx].mean() / len(micro_batches)
                online_loss = online_loss[select_idx].mean() / len(micro_batches)
                if self.cfg.softmax_auxiliary_loss:
                    learn_z_loss = learn_z_loss[select_idx].mean() / len(micro_batches)
                    ref_z_loss = ref_z_loss[select_idx].mean() / len(micro_batches)
                    online_z_loss = online_z_loss[select_idx].mean() / len(micro_batches)

                # Log selected losses
                learn_batch_loss += learn_loss.detach()
                ref_batch_loss += ref_loss.detach()
                online_batch_loss += online_loss.detach()

                # In case this helps with memory utilization.
                del micro_batch

                # Get final losses to optimize for.
                if self.cfg.softmax_auxiliary_loss:
                    learn_loss = learn_loss + learn_z_loss
                    ref_loss = ref_loss + ref_z_loss
                    online_loss = online_loss + online_z_loss

                    # Update overall Z batch loss.
                    online_z_batch_loss += online_z_loss.detach()

            # Run backward pass.
            learn_loss.backward()
            if not self.cfg.fix_reference:
                ref_loss.backward()
            online_loss.backward()

        # Cat for logging purposes
        # TODO: check on this
        if batch_learn_losses:
            batch_learn_losses = torch.concatenate(batch_learn_losses, dim=0)
        if batch_ref_losses:
            batch_ref_losses = torch.concatenate(batch_ref_losses, dim=0)
        return (
            online_batch_loss,
            online_z_batch_loss,
            full_online_batch_loss,
            learn_batch_loss,
            full_learn_batch_loss,
            ref_batch_loss,
            full_ref_batch_loss,
            batch_learn_losses,
            batch_ref_losses,
        )

    def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        batch_data: Dict[str, torch.Tensor] = {}

        # Write data-indices to file.
        if self.indices_file is not None and "index" in batch:
            indices = "\t".join(str(int(i)) for i in batch["index"])
            self.indices_file.write(f"{self.global_step}\t{indices}\n")

        # Zero-gradients.
        self.optim.zero_grad(set_to_none=True)
        self.learner_model.optim.zero_grad(set_to_none=True)
        self.reference_model.optim.zero_grad(set_to_none=True)

        # Move tensors to the right device.
        batch = move_to_device(batch, self.device)

        # Run forward-backward pass.
        (
            online_batch_loss,
            online_z_batch_loss,
            full_online_batch_loss,
            learn_batch_loss,
            full_learn_batch_loss,
            ref_batch_loss,
            full_ref_batch_loss,
            batch_learn_losses,
            batch_ref_losses,
        ) = self.train_batch(batch)
        if not self.cfg.sft:
            batch_data["learn_losses"] = batch_learn_losses
            batch_data["ref_losses"] = batch_ref_losses
            batch_data["index"] = batch["index"].detach().cpu()

        # Collect loss, potentially reducing over all ranks.
        if reduce_global_loss:
            dist.reduce(online_batch_loss, 0)
            online_batch_loss.div_(get_world_size())
            dist.reduce(full_online_batch_loss, 0)
            full_online_batch_loss.div_(get_world_size())
            dist.reduce(learn_batch_loss, 0)
            learn_batch_loss.div_(get_world_size())
            dist.reduce(full_learn_batch_loss, 0)
            full_learn_batch_loss.div_(get_world_size())
            dist.reduce(ref_batch_loss, 0)
            ref_batch_loss.div_(get_world_size())
            dist.reduce(full_ref_batch_loss, 0)
            full_ref_batch_loss.div_(get_world_size())
            if online_z_batch_loss is not None:
                dist.reduce(online_z_batch_loss, 0)
                online_z_batch_loss.div_(get_world_size())

        # Clip gradient norms and collect param/gradient/optim metrics.
        should_log_optim_metrics_this_step = self.should_log_optim_metrics_this_step()
        optim_metrics = self.optim_step(should_log_metrics=should_log_optim_metrics_this_step)

        # Optim step for references
        # TODO: log these metrics?
        learn_optim_metrics = self.learner_model.optim_step(should_log_metrics=should_log_optim_metrics_this_step)
        if not self.cfg.fix_reference:
            ref_optim_metrics = self.reference_model.optim_step(
                should_log_metrics=should_log_optim_metrics_this_step
            )

        # For scheduler
        self.learner_model.global_step += 1
        self.reference_model.global_step += 1

        # Collect metrics and check for NaN loss.
        ce_batch_loss = online_batch_loss
        full_batch_loss = full_online_batch_loss
        z_batch_loss = online_z_batch_loss
        # NOTE: this involves a bunch of host-device syncs so we wait until the last moment to do this.
        if torch.isnan(ce_batch_loss):
            raise ValueError("nan loss encountered")
        if z_batch_loss is not None and torch.isnan(z_batch_loss):
            raise ValueError("nan loss encountered")
        for key, value in optim_metrics.items():
            metrics[f"optim/{key}"] = value.item()
        self.cur_train_loss = ce_batch_loss.item()
        self.min_train_loss = min(self.min_train_loss, self.cur_train_loss)
        metrics["train/CrossEntropyLoss"] = self.cur_train_loss
        metrics["train/Perplexity"] = math.exp(self.cur_train_loss)
        metrics["train/FullLoss"] = full_batch_loss.item()
        if z_batch_loss is not None:
            metrics["train/ZLoss"] = z_batch_loss.item()
        # Add logging for learner and reference
        metrics["train/ReferenceLoss"] = ref_batch_loss.item()
        metrics["train/LearnerLoss"] = learn_batch_loss.item()
        metrics["train/FullReferenceLoss"] = full_ref_batch_loss.item()
        metrics["train/FullLearnerLoss"] = full_learn_batch_loss.item()

        # Maybe collect post-step optimizer-specific metrics.
        if should_log_optim_metrics_this_step:
            optim_metrics = self.optim.get_post_step_metrics(
                self.fsdp_model, process_group=self.fsdp_model.process_group
            )
            for key, value in optim_metrics.items():
                metrics[f"optim/{key}"] = value.item()

        # Write learner scores to disk using dict writer
        if self.cfg.collect_learner_score:
            assert not self.cfg.sft
            self.learner_score_writer.write(batch_data["index"], batch_data["learn_losses"])
        if self.cfg.collect_reference_score:
            assert not self.cfg.sft
            self.reference_score_writer.write(batch_data["index"], batch_data["ref_losses"])
        return metrics

    def fit(self):

        if self.cfg.collect_learner_score:
            self.learner_score_writer = DictMemmapWriter(
                Path(self.cfg.save_folder) / "learn_score",
                memmap_dtype=np.float32,
                seq_len=self.cfg.model.max_sequence_length - 1,  # Losses are one token shorter that seq_len
            )
        if self.cfg.collect_reference_score:
            self.reference_score_writer = DictMemmapWriter(
                Path(self.cfg.save_folder) / "ref_score",
                memmap_dtype=np.float32,
                seq_len=self.cfg.model.max_sequence_length - 1,  # Losses are one token shorter that seq_len
            )

        super().fit()

        if self.cfg.collect_learner_score:
            self.learner_score_writer.close()
        if self.cfg.collect_reference_score:
            self.reference_score_writer.close()
