import pdb

import torch
import torch.nn.functional as F
import util.gan_util as gan_util
from data.transform import DiffRandAug
from ignite.utils import convert_tensor
from kornia.augmentation import Resize


class ClassifierUpdater:
    """Training-step callable that updates a classifier on real and generated data.

    Each call runs one step of ``optimizer_c`` on
    ``loss_supervised + lambda_p * loss_pseudo`` where

    * ``loss_supervised`` is cross-entropy on the real labelled batch, and
    * ``loss_pseudo`` is the mean-squared error between the classifier features
      of resized generator samples and of a ``DiffRandAug``-augmented copy of
      the same samples.

    All configuration is popped from ``kwargs`` (each key below is required);
    positional ``args`` and any extra keywords are ignored.

    Keyword Args:
        classifier: model called as ``classifier(images)``; its output is
            unpacked as ``(logits, features)``.
        generator: conditional generator called as ``generator(z, y)``; may be
            wrapped in ``DataParallel`` / ``DistributedDataParallel``.
        finder: stored on the instance but not used by the methods here.
        optimizer_c: optimizer stepped on the combined classifier loss.
        optimizer_f: stored on the instance but not used by the methods here.
        device: device the real batch and sampled noise/labels are placed on.
        ema_model: object whose ``update(classifier)`` is called after every
            step, or ``None`` to skip that.
        lambda_p: weight of the pseudo (feature-consistency) loss.
        batchsize_p: number of generator samples drawn per step.
        resolution: target size for ``kornia.augmentation.Resize`` applied to
            generator outputs.
    """

    def __init__(self, *args, **kwargs):
        self.classifier = kwargs.pop("classifier")
        self.generator = kwargs.pop("generator")
        self.finder = kwargs.pop("finder")
        self.optimizer_c = kwargs.pop("optimizer_c")
        self.optimizer_f = kwargs.pop("optimizer_f")
        self.device = kwargs.pop("device")
        self.ema_model = kwargs.pop("ema_model")
        self.lambda_p = kwargs.pop("lambda_p")
        self.batchsize_p = kwargs.pop("batchsize_p")
        self.resolution = kwargs.pop("resolution")
        self.resizer = Resize(size=self.resolution)
        self.loss = F.cross_entropy
        self.augment = DiffRandAug(num_ops=2, normalized=True)

    def get_batch(self, batch):
        """Unpack an ``(x, y)`` batch and move both parts to ``self.device``."""
        x, y = batch
        return (
            convert_tensor(x, device=self.device, non_blocking=True),
            convert_tensor(y, device=self.device, non_blocking=True),
        )

    @staticmethod
    def _unwrap(module):
        """Return the wrapped model of a (Distributed)DataParallel, else ``module``."""
        parallel_types = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        return module.module if isinstance(module, parallel_types) else module

    def _sample_noize_and_label(self, n_gen_samples=None):
        """Sample generator inputs: latent noise ``z`` and class labels ``y``.

        Attributes such as ``num_classes`` live on the underlying model, so the
        generator is unwrapped by type. Guessing from the visible GPU count
        fails for wrapped models on one GPU or unwrapped models on several.
        """
        gen = self._unwrap(self.generator)
        z = gan_util.sample_z(gen, n_gen_samples, self.device)
        y = gan_util.sample_categorical_labels(gen.num_classes, n_gen_samples, self.device)
        return z, y

    def __call__(self, engine, batch):
        """Run one classifier training step and return a metrics report.

        Args:
            engine: not used here (ignite process-function signature).
            batch: ``(x, y)`` pair of real inputs and labels.

        Returns:
            dict with ``y_pred`` (detached real-batch logits), ``y`` (labels),
            ``loss`` (supervised loss, float), ``entropy`` (mean predictive
            entropy on the real batch, float) and ``loss_pseudo`` (float).
        """
        report = {}
        self.classifier.train()

        x, y = self.get_batch(batch)
        batchsize = x.shape[0]
        sizes = [batchsize, self.batchsize_p, self.batchsize_p]

        # Pseudo samples are classifier inputs only; no gradient ever flows
        # back into the generator. Building them without autograd avoids
        # recording (and keeping alive) the generator/augmentation graph.
        with torch.no_grad():
            z_p, y_p = self._sample_noize_and_label(n_gen_samples=self.batchsize_p)
            x_p_w = self.resizer(self.generator(z_p, y_p))  # weak view: resize only
            x_p_s = self.augment(x_p_w)  # strong view: DiffRandAug on top

        # Single classifier forward over [real | weak pseudo | strong pseudo].
        images = torch.cat([x, x_p_w, x_p_s], dim=0)
        logit_all, feat_all = self.classifier(images)
        logit_real, _, _ = torch.split(logit_all, sizes, dim=0)
        _, feat_p_w, feat_p_s = torch.split(feat_all, sizes, dim=0)

        # Supervised loss on the real batch, plus detached monitoring stats.
        loss_supervised = self.loss(logit_real, y)
        logit_real_d = logit_real.detach()
        entropy = (-F.softmax(logit_real_d, dim=1) * F.log_softmax(logit_real_d, dim=1)).sum(dim=1).mean()
        report["y_pred"] = logit_real_d
        report["y"] = y.detach()
        report["loss"] = loss_supervised.item()
        report["entropy"] = entropy.item()

        # Consistency loss between features of the weak and strong views.
        loss_pseudo = F.mse_loss(feat_p_w, feat_p_s)
        report["loss_pseudo"] = loss_pseudo.item()

        loss_target = loss_supervised + self.lambda_p * loss_pseudo
        self.optimizer_c.zero_grad()
        loss_target.backward()
        self.optimizer_c.step()

        # Drop local references to the large batch tensors before the EMA update.
        del images, x_p_w, x_p_s, logit_all, feat_all

        if self.ema_model is not None:
            self.ema_model.update(self.classifier)
        return report
