import pdb
from collections import OrderedDict

import torch
import torch.nn.functional as F
import util.gan_util as gan_util
from ignite.utils import convert_tensor
from kornia.augmentation import Resize
from util.meta_util import bidirectional_gradient_updated_parameters


class MPSClassifierUpdater:
    """Per-iteration update step that trains a finder and a classifier together.

    Instances are callables with the signature ``(engine, batch)``; the only
    thing read from ``engine`` is ``engine.state.epoch`` (presumably this is
    used as an ignite ``Engine`` process function -- confirm at the call site).

    Each call performs two optimisation steps:

    1. The finder is updated on an approximated "MPS" loss computed from a
       validation batch and generator samples, plus an optional latent
       regulariser.
    2. The classifier is updated on real samples plus generator samples
       ("pseudo" samples) produced from the freshly updated finder.

    All configuration is passed as keyword arguments; positional arguments
    are accepted but ignored.
    """

    def __init__(self, *args, **kwargs):
        """Collect models, optimizers and hyper-parameters from ``kwargs``.

        Required keys: ``classifier``, ``generator``, ``finder``,
        ``optimizer_c``, ``optimizer_f``, ``device``, ``ema_model``,
        ``lambda_p``, ``ubatch_ratio``, ``batchsize_p``, ``warmup_epoch``,
        ``resolution``, ``meta_learning_freq``, ``val_loader``.

        Optional keys (default): ``lambda_latent`` (0.0),
        ``lambda_classifier`` (0.0), ``fixed_inner_lr`` (None),
        ``lambda_inner_lr`` (1.0), ``latent_reg`` ("norm"), ``r`` (1e-2),
        ``n_gen_samples`` (``batchsize_p``).

        Raises:
            KeyError: if a required key is missing.
        """
        self.classifier = kwargs.pop("classifier")
        self.generator = kwargs.pop("generator")
        self.finder = kwargs.pop("finder")
        self.optimizer_c = kwargs.pop("optimizer_c")
        self.optimizer_f = kwargs.pop("optimizer_f")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.lambda_latent = kwargs.pop("lambda_latent", 0.0)
        # Stored for external readers; not used by any method of this class.
        self.lambda_classifier = kwargs.pop("lambda_classifier", 0.0)
        self.fixed_inner_lr = kwargs.pop("fixed_inner_lr", None)
        self.lambda_inner_lr = kwargs.pop("lambda_inner_lr", 1.0)
        self.latent_reg = kwargs.pop("latent_reg", "norm")
        self.r = kwargs.pop("r", 1e-2)
        # Stored for external readers; not used by any method of this class.
        self.u_accum_count = kwargs.pop("ubatch_ratio")
        self.batchsize_p = kwargs.pop("batchsize_p")
        # Default sample count for _sample_noize_and_label(). It was previously
        # never initialised, so calling that method without an argument raised
        # AttributeError. batchsize_p is the only size __call__ ever requests.
        self.n_gen_samples = kwargs.pop("n_gen_samples", self.batchsize_p)
        self.warmup_epoch = kwargs.pop("warmup_epoch")
        self.resolution = kwargs.pop("resolution")
        self.resizer = Resize(size=self.resolution)
        # Stored for external readers; not used by any method of this class.
        self.meta_learning_freq = kwargs.pop("meta_learning_freq")
        self.val_loader = kwargs.pop("val_loader")
        self.val_loader_iter = iter(self.val_loader)
        self.loss = F.cross_entropy
        self.last_loss_mps = 0

    def latent_regularization(self, fz, eps=1e-7):
        """Return a scalar regularisation term for the finder output ``fz``.

        Args:
            fz: tensor with at least two dimensions; statistics are taken
                along ``dim=1``.
            eps: added to the variance before taking the logarithm
                (only used for ``latent_reg == "kl"``).

        Returns:
            For ``"norm"``: the mean over rows of the 2-norm along ``dim=1``.
            For ``"kl"``: ``-0.5 * (1 + log(var + eps) - mean**2 - var)`` where
            ``mean`` and ``var`` are the per-row mean / variance along
            ``dim=1`` averaged into scalars. Up to ``eps`` this is the
            closed-form KL divergence between N(mean, var) and N(0, 1).

        Raises:
            NotImplementedError: for any other ``latent_reg`` value.
        """
        if self.latent_reg == "norm":
            return torch.norm(fz, dim=1).mean()
        if self.latent_reg == "kl":
            approx_mean = torch.mean(fz, dim=1).mean()
            approx_var = torch.var(fz, dim=1).mean()
            return -0.5 * (1 + torch.log(approx_var + eps) - approx_mean.pow(2) - approx_var)
        raise NotImplementedError

    def sample_val_batch(self):
        """Return the next ``(x, y)`` validation batch on ``self.device``.

        The validation iterator is restarted transparently once exhausted.
        """
        try:
            batch = next(self.val_loader_iter)
        except StopIteration:
            # One pass over the validation loader is finished: start over.
            self.val_loader_iter = iter(self.val_loader)
            batch = next(self.val_loader_iter)
        return self.get_batch(batch)

    def get_batch(self, batch):
        """Unpack ``batch`` into ``(x, y)`` and move both to ``self.device``."""
        x, y = batch
        return (
            convert_tensor(x, device=self.device, non_blocking=True),
            convert_tensor(y, device=self.device, non_blocking=True),
        )

    def _unwrapped_generator(self):
        """Return the generator itself, stripped of a data-parallel wrapper."""
        gen = self.generator
        parallel_wrappers = (
            torch.nn.DataParallel,
            torch.nn.parallel.DistributedDataParallel,
        )
        if isinstance(gen, parallel_wrappers):
            return gen.module
        # Keep the original multi-GPU behaviour for other wrappers exposing
        # ``.module``, but no longer crash when the generator is not wrapped.
        if torch.cuda.device_count() >= 2 and hasattr(gen, "module"):
            return gen.module
        return gen

    def _sample_noize_and_label(self, n_gen_samples=None):
        """Sample generator inputs.

        Args:
            n_gen_samples: number of samples to draw; defaults to
                ``self.n_gen_samples``.

        Returns:
            ``(z, y)`` as produced by ``gan_util.sample_z`` and
            ``gan_util.sample_categorical_labels`` (the latter is called with
            the generator's ``num_classes``).
        """
        if n_gen_samples is None:
            n_gen_samples = self.n_gen_samples
        gen = self._unwrapped_generator()
        z = gan_util.sample_z(gen, n_gen_samples, self.device)
        y = gan_util.sample_categorical_labels(gen.num_classes, n_gen_samples, self.device)
        return z, y

    def _update_finder(self, report):
        """Run one finder optimisation step; return the sampled ``(z_p, y_p)``."""
        # The classifier is only evaluated here, so it is switched to eval mode
        # for the duration of this step and restored afterwards.
        self.classifier.eval()
        x_v, y_v = self.sample_val_batch()
        logit_v = self.classifier(x_v)
        z_p, y_p = self._sample_noize_and_label(n_gen_samples=self.batchsize_p)
        fz_p = self.finder(z_p)
        x_p = self.resizer(self.generator(fz_p, y_p))

        # Approximated MPS loss.
        # NOTE(review): bidirectional_gradient_updated_parameters is defined in
        # util.meta_util; presumably theta_plus / theta_minus are classifier
        # parameters shifted by +/- epsilon along the gradient of loss_val --
        # confirm against that helper.
        loss_val = F.cross_entropy(logit_v, y_v)
        theta_plus, theta_minus, epsilon = bidirectional_gradient_updated_parameters(
            self.classifier, loss_val, first_order=True, r=self.r
        )
        loss_plus = self.loss(self.classifier(x_p, params=theta_plus), y_p)
        loss_minus = self.loss(self.classifier(x_p, params=theta_minus), y_p)

        # Scale by the classifier learning rate unless a fixed value was given.
        if self.fixed_inner_lr is not None:
            inner_lr = self.fixed_inner_lr
        else:
            inner_lr = self.optimizer_c.param_groups[0]['lr']
        inner_lr = self.lambda_inner_lr * inner_lr
        # Central difference of the pseudo-sample loss between the two
        # parameter sets, scaled by the inner learning rate.
        loss_mps = (loss_minus - loss_plus).div(2 * epsilon).mul(inner_lr)

        # Latent regularisation on the finder output.
        loss_latent_reg = self.latent_regularization(fz_p)
        loss_all = loss_mps + self.lambda_latent * loss_latent_reg
        self.optimizer_f.zero_grad()
        loss_all.backward()
        self.optimizer_f.step()

        self.last_loss_mps = loss_mps.detach().item()
        report.update({"loss_mps": self.last_loss_mps})
        report.update({"latent_norm": loss_latent_reg.detach().item()})
        self.classifier.train()
        # Locals (logits, perturbed parameters, losses) are released on return,
        # before the classifier step allocates its own activations.
        return z_p, y_p

    def _update_classifier(self, engine, x, y, z_p, y_p, report):
        """Run one classifier optimisation step on real plus pseudo samples."""
        batchsize = x.shape[0]

        # Regenerate the pseudo samples with the just-updated finder; no
        # gradient is needed through finder / generator for this step.
        with torch.no_grad():
            x_p = self.resizer(self.generator(self.finder(z_p), y_p))
        images = torch.cat([x, x_p], dim=0)
        logit_all = self.classifier(images)
        logit_real, logit_p = torch.split(logit_all, [batchsize, self.batchsize_p], dim=0)

        # Supervised loss on the real samples.
        loss_supervised = self.loss(logit_real, y)
        report.update({"y_pred": logit_real.detach()})
        report.update({"y": y.detach()})
        report.update({"loss": loss_supervised.detach()})

        # Pseudo-supervised loss; its weight is zero until the epoch counter
        # exceeds warmup_epoch.
        lambda_p = self.lambda_p if self.warmup_epoch < engine.state.epoch else 0.0
        loss_pseudo = lambda_p * self.loss(logit_p, y_p)
        report.update({"loss_pseudo": loss_pseudo.detach().item()})

        # Combined objective and classifier update.
        loss_target = loss_supervised + loss_pseudo
        self.optimizer_c.zero_grad()
        loss_target.backward()
        self.optimizer_c.step()

    def __call__(self, engine, batch):
        """Run one training iteration and return a report dict.

        Args:
            engine: object exposing ``state.epoch`` (used for the pseudo-loss
                warm-up).
            batch: an ``(x, y)`` pair of training inputs and labels.

        Returns:
            dict with keys ``loss_mps``, ``latent_norm``, ``y_pred``, ``y``,
            ``loss`` and ``loss_pseudo``.
        """
        report = {}
        self.classifier.train()
        self.finder.train()

        # Get train samples
        x, y = self.get_batch(batch)

        # 1. Meta-train finder to generate useful samples for classifier
        z_p, y_p = self._update_finder(report)

        # 2. Train classifier with pseudo semi-supervised learning
        self._update_classifier(engine, x, y, z_p, y_p, report)

        if self.ema_model is not None:
            self.ema_model.update(self.classifier)

        return report
