from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

from itertools import chain

from core import metrics, dataloader, utils, samplers
from scripts.vae import models
from time import time
from core.logger import Logger

import numpy as np

# Command-line options. They are parsed at import time and exposed to the
# rest of the module through the global `flags` namespace.
parser = argparse.ArgumentParser()
# Pass --nocuda to run on the CPU; a `store_true` flag is False when omitted.
parser.add_argument('--nocuda', action='store_true')
# Selects the training criterion; main() accepts 'cce', 'ub' or 'mce'.
parser.add_argument('--loss', default='cce', type=str)
# Path handed to torch.load() to restore the pretrained decoder weights.
parser.add_argument('--checkpoint', default='', type=str)
# Seed for torch.manual_seed(); also becomes part of the log directory name.
parser.add_argument('--seed', default=322, type=int)
flags = parser.parse_args()


def main():
    """Train a discriminator on real images vs. samples from a pretrained decoder.

    Reads its configuration from the module-level ``flags``:

    * ``flags.loss`` selects the discriminator wrapper and criterion
      (``'cce'``, ``'ub'`` or ``'mce'``).
    * ``flags.checkpoint`` is loaded with ``torch.load`` and must unpack into
      a pair whose second element is the decoder ``state_dict``.
    * ``flags.seed`` seeds torch and is part of the logger name.
    * ``flags.nocuda`` forces the CPU device.

    Only the discriminator is optimised; decoder outputs are detached.
    Running averages of the loss and the clipped acceptance ratio are logged,
    and a checkpoint is saved, every ``len(trainloader) // 8`` batches
    (at least every batch when the loader has fewer than eight batches).

    Raises:
        NotImplementedError: if ``flags.loss`` is not a recognised value.
    """
    # Format specs the logger applies to each logged quantity.
    fmt = {'lr': '1.1e',
           'tr_loss': '.4f',
           'AR': '.4f',
           'time': '.3f'}

    torch.manual_seed(flags.seed)
    device = torch.device("cpu") if flags.nocuda else torch.device("cuda")

    # The base network is built first (right after seeding) and then wrapped
    # by the loss-specific discriminator.
    discriminator = models.DiscriminatorCeleba()
    if flags.loss == 'cce':
        path_name = 'cross-ent'
        discriminator = models.DiscriminatorCCE(discriminator).to(device)
        criterion = metrics.ConventionalCrossEntropy(discriminator)
    elif flags.loss == 'ub':
        path_name = 'upper-bound'
        discriminator = models.DiscriminatorUB(discriminator).to(device)
        criterion = metrics.UpperBound(discriminator)
    elif flags.loss == 'mce':
        path_name = 'markov-ent'
        discriminator = models.DiscriminatorMCE(discriminator).to(device)
        criterion = metrics.MarkovCrossEntropy(discriminator)
    else:
        raise NotImplementedError('unrecognized loss')

    logger = Logger(base='./logs/VAE-MH-CELEBA', name='-{}-{}'.format(path_name, flags.seed), fmt=fmt)

    # Only the decoder half of the VAE is needed; the encoder weights stored
    # in the checkpoint are ignored.
    _, decoder = models.get_vae_celeba(device)
    _, decoder_dict = torch.load(flags.checkpoint, map_location='cpu')
    decoder.load_state_dict(decoder_dict)

    dataroot = "../../data/celeba"
    image_size = 64
    trainset = datasets.ImageFolder(root=dataroot,
                                    transform=transforms.Compose([
                                       transforms.Resize(image_size),
                                       transforms.CenterCrop(image_size),
                                       transforms.ToTensor()
                                    ]))
    batch_size = 256
    lr_start = 1e-6
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    optimizer = optim.Adam(discriminator.parameters(), lr=lr_start, betas=(0.5, 0.999))

    num_batches = len(trainloader)
    # Clamp to 1: with fewer than 8 batches `num_batches // 8` is 0 and the
    # modulo below would raise ZeroDivisionError.
    log_every = max(1, num_batches // 8)

    # Snapshot of the untrained discriminator (step 0).
    torch.save([decoder.state_dict(), discriminator.state_dict()], logger.get_checkpoint(0))
    epochs = 5
    for epoch in range(epochs):
        # Accumulators cover the window since the last log and are reset
        # after every log as well as at the start of each epoch.
        t0 = time()
        train_loss = 0.0
        AR = 0.0
        num_ims = 0
        for i, (real_images, _) in enumerate(trainloader, 0):
            real_images = real_images.to(device)
            # One latent sample per real image; drawn on the CPU and then
            # moved, so the random stream for a given seed is unchanged.
            z = torch.randn([real_images.shape[0], decoder.h_dim]).to(device)
            # Detach so no gradient flows back into the decoder.
            fake_images = decoder(z).detach()
            loss = criterion(real_images, fake_images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_ar = discriminator.acceptance_ratio(fake_images, real_images).detach().cpu().numpy()
            # Ratios are clipped at 1 before being summed.
            AR += np.sum(np.minimum(batch_ar, 1))
            train_loss += loss.item()
            num_ims += real_images.shape[0]

            if (i + 1) % log_every != 0:
                continue

            # Global batch counter across epochs (1-based).
            step = epoch * num_batches + i + 1
            logger.add(step, tr_loss=train_loss / num_ims)
            logger.add(step, AR=AR / num_ims)
            logger.add(step, lr=optimizer.param_groups[0]['lr'])
            logger.add(step, time=time() - t0)
            logger.iter_info()
            logger.save(silent=True)
            torch.save([decoder.state_dict(), discriminator.state_dict()], logger.get_checkpoint(step))

            t0 = time()
            train_loss = 0.0
            AR = 0.0
            num_ims = 0


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
