"""Modified https://github.com/bethgelab/slow_disentanglement/blob/master/scripts/solver.py
and removed functions not needed for training with a contrastive loss."""

import os
import shutil
import torch
import torch.optim as optim
from torch.autograd import Variable
from kitti_masks.model import BetaVAE_H as BetaVAE
import losses
from torch.utils.tensorboard import SummaryWriter
from kitti_masks.dataset import return_data
from kitti_masks.evaluate_disentanglement import eval_online


class Solver(object):
    """Train a ``BetaVAE`` network with a contrastive loss.

    The solver builds the network, the loss selected by ``args.loss`` and an
    Adam optimizer, runs the training loop, writes the running loss to
    ``<output_dir>/log.csv`` and TensorBoard, and stores checkpoints in
    ``ckpt_dir``.

    Args:
        args: Namespace-like configuration object. The attributes read here
            are ``ckpt_dir``, ``output_dir``, ``dataset``, ``gpu``, ``cuda``,
            ``max_iter``, ``z_dim``, ``num_channel``, ``lr``, ``beta1``,
            ``beta2``, ``box_norm``, ``ckpt_name``, ``log_step``,
            ``save_step``, ``loss``, ``p`` and (for every loss except
            ``'simclr'``) ``center``.
        data_loader: Iterable yielding ``(x, label)`` batches; the label is
            ignored during training.

    Raises:
        NotImplementedError: If ``args.loss`` is not one of ``'simclr'``,
            ``'ince'``, ``'nce'``, ``'nwj'`` or ``'scl'``.
    """

    # ``args.loss`` identifier -> name of the class in the ``losses`` module.
    # Names are resolved lazily with ``getattr`` inside ``__init__``.
    _DELTA_LOSS_CLASSES = {
        'ince': 'DeltaINCELoss',
        'nce': 'DeltaNCELoss',
        'nwj': 'DeltaNWJLoss',
        'scl': 'DeltaSCLLoss',
    }

    # A numbered checkpoint is kept every this many iterations.
    _MILESTONE_INTERVAL = 50000

    def __init__(self, args, data_loader=None):
        self.args = args
        self.ckpt_dir = args.ckpt_dir
        self.output_dir = args.output_dir
        self.data_loader = data_loader
        self.dataset = args.dataset
        self.device = torch.device(
            "cuda:" + str(args.gpu) if torch.cuda.is_available() and args.cuda else "cpu"
        )
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.nc = args.num_channel

        # Adam hyper-parameters.
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        self.net = BetaVAE(self.z_dim, self.nc, args.box_norm).to(self.device)

        # NOTE: restoring from ``ckpt_name`` was explicitly disabled in the
        # original code (guarded by ``if False``), so it stays disabled here.
        # To resume, call ``load_checkpoint(self.ckpt_name)`` after
        # construction, once the optimizer exists.
        self.ckpt_name = args.ckpt_name

        self.log_step = args.log_step
        self.save_step = args.save_step

        if args.loss == 'simclr':
            self.loss = losses.LpSimCLRLoss(
                p=args.p, tau=1.0, simclr_compatibility_mode=True
            )
        elif args.loss in self._DELTA_LOSS_CLASSES:
            loss_cls = getattr(losses, self._DELTA_LOSS_CLASSES[args.loss])
            self.loss = loss_cls(
                size=args.z_dim,
                p=args.p,
                tau=0.1,
                margin_mode='first',
                center=args.center,
                device=self.device,
            )
        else:
            raise NotImplementedError(
                "unsupported loss '{}'".format(args.loss)
            )

        adam_betas = (self.beta1, self.beta2)
        if args.loss == 'simclr':
            self.optim = optim.Adam(
                list(self.net.parameters()),
                lr=self.lr,
                betas=adam_betas,
            )
        else:
            # The Delta* losses expose a ``critic`` with its own parameters,
            # which are optimized jointly with the network. Every critic
            # parameter except ``c`` uses a 100x larger learning rate.
            critic = self.loss.critic
            self.optim = optim.Adam(
                [
                    {'params': self.net.parameters(), 'lr': self.lr},
                    {
                        'params': [
                            param
                            for name, param in critic.named_parameters()
                            if name != 'c'
                        ],
                        'lr': 100.0 * self.lr,
                    },
                    {'params': critic.c, 'lr': self.lr},
                ],
                lr=self.lr,
                betas=adam_betas,
            )

    def train(self):
        """Run the training loop until ``global_iter`` reaches ``max_iter``.

        Every ``log_step`` iterations the mean loss is appended to
        ``log.csv`` (and, for non-SimCLR losses, the critic parameters are
        written to TensorBoard). Every ``save_step`` iterations the ``last``
        checkpoint is overwritten, a numbered checkpoint is kept every 50000
        iterations, and an online evaluation runs every ``3 * save_step``
        iterations.

        Returns:
            bool: Failure flag. Nothing in the loop signals a failure, so
            this is always ``False``; it is kept for caller compatibility.

        Raises:
            ValueError: If ``data_loader`` yields no batches (the loop could
                otherwise never terminate).
        """
        self.net_mode(train=True)
        running_loss = 0.0
        writer = SummaryWriter(log_dir=self.output_dir)
        log_path = os.path.join(self.output_dir, "log.csv")
        try:
            # Line-buffered so the CSV stays readable while training runs.
            with open(log_path, "a", 1) as log:
                log.write("Total Loss\n")

                finished = False
                while not finished:
                    saw_batch = False
                    for x, _ in self.data_loader:  # labels are not used
                        saw_batch = True
                        running_loss += self._train_step(x)
                        self.global_iter += 1

                        if self.global_iter % self.log_step == 0:
                            mean_loss = running_loss / self.log_step
                            log.write("%.6f" % mean_loss + "\n")
                            if self.args.loss != 'simclr':
                                self._log_critic_params(writer)
                            running_loss = 0.0

                        if self.global_iter % self.save_step == 0:
                            self.save_checkpoint("last")

                        if self.global_iter % self._MILESTONE_INTERVAL == 0:
                            self.save_checkpoint(str(self.global_iter))

                        if self.global_iter % (3 * self.save_step) == 0:
                            self._evaluate(writer)

                        if self.global_iter >= self.max_iter:
                            finished = True
                            break

                    if not saw_batch:
                        # Without this guard an empty loader spins forever.
                        raise ValueError(
                            "data_loader yielded no batches; cannot train"
                        )
        finally:
            # Flush and release the TensorBoard event file even on error.
            writer.close()

        return False

    def _train_step(self, x):
        """Run one optimizer update on batch ``x`` and return the loss value."""
        x = x.to(self.device)

        mu = self.net(x)
        # Even and odd rows of the output are split into two views.
        # NOTE(review): presumably consecutive rows form the paired samples
        # produced by the dataset -- confirm against the data loader.
        z1_rec = mu[::2]
        z2_con_z1_rec = mu[1::2]
        # Third input: the first view shifted by one position along dim 0.
        z3_rec = torch.roll(z1_rec, 1, 0)
        vae_loss, _, _ = self.loss(
            None, None, None, z1_rec, z2_con_z1_rec, z3_rec
        )
        loss_value = vae_loss.item()

        self.optim.zero_grad()
        vae_loss.backward()
        self.optim.step()
        return loss_value

    def _log_critic_params(self, writer):
        """Print the critic parameters and write each entry to TensorBoard."""
        for key, val in self.loss.critic.get_param().items():
            # detach()/cpu() make the conversion safe for tensors that live
            # on the GPU or still require grad; both are no-ops otherwise.
            flat = val.detach().cpu().reshape(-1).numpy()
            print(key, flat)
            for i, entry in enumerate(flat):
                writer.add_scalar(f"{key}/{i}", entry, self.global_iter)

    def _evaluate(self, writer):
        """Run the online evaluation and log its score to TensorBoard."""
        # The flag is set on the shared ``args`` object and is not reset.
        self.args.evaluate = True
        data_loader, _ = return_data(self.args)
        score = eval_online(self.args, data_loader.dataset)
        writer.add_scalar('Perm. Disentanglement', score, self.global_iter)

    def save_checkpoint(self, filename, silent=True):
        """Write network and optimizer state to ``ckpt_dir/filename``.

        Args:
            filename: Name of the checkpoint file inside ``ckpt_dir``.
            silent: If ``False``, print a confirmation message.
        """
        states = {
            "iter": self.global_iter,
            "model_states": {"net": self.net.state_dict()},
            "optim_states": {"optim": self.optim.state_dict()},
        }

        # Create the target directory on demand so the first save cannot fail
        # just because the folder is missing.
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)

        file_path = os.path.join(self.ckpt_dir, filename)
        with open(file_path, mode="wb") as f:
            torch.save(states, f)
        if not silent:
            print(
                "=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter)
            )

    def load_checkpoint(self, filename):
        """Restore iteration counter, network and optimizer state.

        A message is printed if ``ckpt_dir/filename`` does not exist; in that
        case nothing is modified.

        Args:
            filename: Name of the checkpoint file inside ``ckpt_dir``.
        """
        file_path = os.path.join(self.ckpt_dir, filename)
        if not os.path.isfile(file_path):
            print("=> no checkpoint found at '{}'".format(file_path))
            return

        # map_location lets a checkpoint saved on one device be loaded on
        # another (for example a GPU checkpoint on a CPU-only host).
        checkpoint = torch.load(file_path, map_location=self.device)
        self.global_iter = checkpoint["iter"]
        self.net.load_state_dict(checkpoint["model_states"]["net"])
        self.optim.load_state_dict(checkpoint["optim_states"]["optim"])
        print(
            "=> loaded checkpoint '{} (iter {})'".format(
                file_path, self.global_iter
            )
        )

    def net_mode(self, train):
        """Switch the network between training and evaluation mode.

        Args:
            train: ``True`` for training mode, ``False`` for evaluation mode.

        Raises:
            ValueError: If ``train`` is not a ``bool``.
        """
        if not isinstance(train, bool):
            raise ValueError("Only bool type is supported. True or False")

        if train:
            self.net.train()
        else:
            self.net.eval()
