import gzip
import copy
import pickle
import os.path
import sys

import numpy as np

from tqdm import tqdm
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F

from options import args_parser
from model import vae_mini, vie

from torchvision.models import resnet18
from utils.dataset_util import random_inputs_sampler

from utils.cifar10_dataset import CIFAR10SubSet
from utils.cifar100_dataset import CIFAR100SubSet



""" TESTING on CIFAR-10"""

def get_schedulers(scheduler, optimizer, milestones=(30, 80), gamma=0.5, T_max=10,
                   lr_mul=0.001, d_model=10, n_warmup_steps=5, step_size=30):
    """Build a learning-rate scheduler for ``optimizer`` by name.

    Args:
        scheduler: one of ``"step"``, ``"multistep"``, ``"cosine"`` or ``"exponential"``.
        optimizer: the ``torch.optim.Optimizer`` whose learning rate is scheduled.
        milestones: epochs at which ``"multistep"`` multiplies the LR by ``gamma``.
            A tuple default avoids sharing one mutable list across calls.
        gamma: multiplicative decay for ``"step"``, ``"multistep"`` and ``"exponential"``.
        T_max: ``T_max`` (maximum number of iterations) for ``"cosine"``.
        lr_mul, d_model, n_warmup_steps: accepted for backward compatibility; unused.
        step_size: decay period for ``"step"``. It defaults to 30, the value that
            was previously hard-coded.

    Returns:
        The constructed ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: if ``scheduler`` is not a recognised name. Previously this
            returned ``None``, which only failed later at the first ``.step()``.
    """
    lr_scheduler = torch.optim.lr_scheduler
    if scheduler == "step":
        return lr_scheduler.StepLR(optimizer, step_size, gamma=gamma)
    if scheduler == "multistep":
        return lr_scheduler.MultiStepLR(optimizer, list(milestones), gamma=gamma)
    if scheduler == "cosine":
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max)
    if scheduler == "exponential":
        return lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    raise ValueError(
        f"Unknown scheduler {scheduler!r}; expected 'step', 'multistep', 'cosine' or 'exponential'"
    )


def _build_split_models(num_class, num_head_layer):
    """Cut a freshly initialised ResNet-18 into the split-learning sub-networks.

    Args:
        num_class: number of output classes of ``clf``, ``clf_adv`` and ``clf_adv_sim``.
        num_head_layer: how many leading ResNet-18 children form ``head``.

    Returns:
        ``(head, encoder, clf, decoder, clf_adv, clf_adv_sim)``:

        - ``head`` is the first ``num_head_layer`` children of ResNet-18.
        - ``encoder`` is the remaining children up to and including ``layer3``.
        - ``clf`` is ``layer4 -> avgpool -> flatten -> fc``.
        - ``decoder`` maps the head output to 3 channels in ``[0, 1]``.
        - ``clf_adv`` / ``clf_adv_sim`` are an MLP / a linear classifier on the
          flattened encoder output.

        ``head``, ``encoder`` and ``clf`` share module objects with one ResNet-18.
    """
    net = resnet18()
    net.fc = nn.Linear(net.fc.in_features, num_class)
    children = list(net.children())

    # NOTE(review): ConvTranspose2d(64, 3, 8, 2, 3) doubles the spatial size and
    # expects 64 input channels. For 32x32 inputs the head output must therefore be
    # 64 x 16 x 16, which appears to hold only for num_head_layer in 1..3 -- confirm.
    decoder = nn.Sequential(
        nn.ConvTranspose2d(64, 3, 8, 2, 3),
        nn.Sigmoid(),
    )

    head = nn.Sequential(*children[:num_head_layer])
    # children[:-3] drops layer4/avgpool/fc; the layers already in `head` are skipped.
    encoder = nn.Sequential(*children[:-3][num_head_layer:])
    clf = nn.Sequential(*children[-3:-1], nn.Flatten(), children[-1])

    # NOTE(review): 1024 is the flattened encoder output for 32x32 inputs
    # (layer3 -> 256 x 2 x 2); it must change with the input resolution.
    clf_adv = nn.Sequential(
        nn.Flatten(),
        nn.Linear(1024, 1024), nn.ReLU(),
        nn.Linear(1024, 256), nn.ReLU(),
        nn.Linear(256, num_class),
    )
    clf_adv_sim = nn.Sequential(nn.Flatten(), nn.Linear(1024, num_class))
    return head, encoder, clf, decoder, clf_adv, clf_adv_sim


def _build_datasets(dataset, num_mc_sample, transform_train, transform_valid):
    """Return ``(train, mc_train, mc_train_med, valid)`` datasets.

    ``dataset == "cifar100"`` selects CIFAR-100; any other value falls back to CIFAR-10.
    ``mc_train`` and ``mc_train_med`` are the subset datasets used by the
    model-completion attacks. They are built with ``num_sample=num_mc_sample`` and
    ``num_sample=10 * num_mc_sample`` respectively.
    """
    if dataset == "cifar100":
        root, full_cls, subset_cls = "../data/cifar100", datasets.CIFAR100, CIFAR100SubSet
    else:
        root, full_cls, subset_cls = "../data/cifar10", datasets.CIFAR10, CIFAR10SubSet

    train_dataset = full_cls(root=root, train=True, download=True, transform=transform_train)
    mc_train_dataset = subset_cls(
        root=root,
        train=True,
        download=True,
        transform=transform_train,
        returns="all",
        num_sample=num_mc_sample,
    )
    mc_train_dataset_med = subset_cls(
        root=root,
        train=True,
        download=True,
        transform=transform_train,
        returns="all",
        num_sample=10 * num_mc_sample,
    )
    valid_dataset = full_cls(root=root, train=False, download=True, transform=transform_valid)
    return train_dataset, mc_train_dataset, mc_train_dataset_med, valid_dataset


def main():
    """Train a split-learning ResNet-18 with adversarial obfuscation and measure label leakage.

    The network is cut into ``head -> encoder -> clf``. Only the detached tensors
    ``rep`` (head output) and ``feature`` (encoder output) cross each cut. Every
    batch runs two steps:

    1. ``clf``, an input ``decoder`` on ``rep`` and an adversarial classifier
       ``clf_adv`` on ``feature`` are each trained on their own loss.
    2. ``head``/``encoder`` are trained on ``(1 - 2 * l_adv)`` times the task
       cross-entropy plus ``l_adv``-weighted terms. These push the decoder's
       reconstruction away from the input and towards an image from
       ``rand_image_sampler``, and push ``clf_adv`` away from the true labels
       and towards uniformly random labels.

    Every 10 epochs, model-completion attacks fine-tune fresh copies of the untrained
    ``clf_adv`` / ``clf_adv_sim`` on frozen encoder features and report their
    accuracy. Whenever validation accuracy improves, head/encoder/clf are saved to
    ``args.save_models_filename``.
    """
    args = args_parser()
    lr = 1e-2
    epoch_max = 500
    epoch_finetune = 5  # fine-tuning epochs per model-completion attack
    bs = 32
    num_class = 100 if args.dataset == "cifar100" else 10

    criterion = torch.nn.CrossEntropyLoss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Training on device", device)

    head, encoder, clf, decoder, clf_adv, clf_adv_sim = _build_split_models(
        num_class, args.num_head_layer
    )
    for module in (encoder, clf, head, decoder, clf_adv):
        module.to(device)

    # Untrained snapshots: every attack restarts from these weights.
    # clf_adv_sim stays on the CPU here; mc_attack moves its copy to `device`.
    clf_adv_init = copy.deepcopy(clf_adv)
    clf_adv_sim_init = copy.deepcopy(clf_adv_sim)

    optim_head = optim.SGD(head.parameters(), lr=lr, momentum=0.9)
    if args.mc == "passive":
        optim_encoder = optim.SGD(encoder.parameters(), lr=lr, momentum=0.9)
    else:
        optim_encoder = optim.SGD(encoder.parameters(), lr=lr, momentum=0.9, dampening=0.1)
    optim_clf = optim.SGD(clf.parameters(), lr=lr, momentum=0.9)
    optim_decoder = optim.SGD(decoder.parameters(), lr=lr, momentum=0.9)
    optim_clfadv = optim.SGD(clf_adv.parameters(), lr=lr, momentum=0.9)

    # Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixels to [-1, 1].
    # NOTE(review): the decoder ends in Sigmoid ([0, 1]), so it can never reproduce
    # negative normalised pixels -- confirm this mismatch is intended.
    transform_train = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    transform_valid = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    train_dataset, mc_train_dataset, mc_train_dataset_med, valid_dataset = _build_datasets(
        args.dataset, args.num_mc_sample, transform_train, transform_valid
    )
    rand_image_sampler = random_inputs_sampler(train_dataset)

    def make_loader(dataset, shuffle):
        """Wrap `dataset` in a DataLoader with the script's batch size and workers."""
        return torch.utils.data.DataLoader(
            dataset, batch_size=bs, shuffle=shuffle, num_workers=2
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    valid_loader = make_loader(valid_dataset, shuffle=False)
    mc_train_loader = make_loader(mc_train_dataset, shuffle=True)
    mc_train_loader_med = make_loader(mc_train_dataset_med, shuffle=True)

    def valid(head, encoder, clf, data_loader, device):
        """Return the top-1 accuracy of ``clf(encoder(head(x)))`` over ``data_loader``."""
        # All three stages must be in eval mode. `head` may contain BatchNorm
        # (ResNet-18's bn1), and mc_attack used to call this while head was training.
        head.eval()
        encoder.eval()
        clf.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = clf(encoder(head(inputs)))
                _, pred_label = outputs.max(dim=1)
                total += labels.size(0)
                correct += (pred_label == labels).sum().item()
        return correct / float(total)

    def mc_attack(clf_adv_init, mc_train_loader, valid_loader, device):
        """Model-completion attack on the current bottom model.

        Fine-tunes a fresh copy of ``clf_adv_init`` for ``epoch_finetune`` epochs on
        frozen ``encoder(head(x))`` features of ``mc_train_loader``. Returns its
        accuracy on ``(mc_train_loader, valid_loader)``.
        """
        clf_adv_ft = copy.deepcopy(clf_adv_init).to(device)
        optim_clfadvft = optim.SGD(clf_adv_ft.parameters(), lr=lr, momentum=0.9)
        # The attacker only reads features. Eval mode keeps head/encoder BatchNorm
        # running stats untouched and matches the features valid() scores on;
        # no_grad avoids building an unused graph through the bottom model.
        head.eval()
        encoder.eval()
        clf_adv_ft.train()
        for _ in range(epoch_finetune):
            for inputs, targets in mc_train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                with torch.no_grad():
                    feature = encoder(head(inputs))
                optim_clfadvft.zero_grad()
                loss_adv = criterion(clf_adv_ft(feature), targets)
                loss_adv.backward()
                optim_clfadvft.step()

        adv_train_acc = valid(head, encoder, clf_adv_ft, mc_train_loader, device)
        adv_val_acc = valid(head, encoder, clf_adv_ft, valid_loader, device)
        return adv_train_acc, adv_val_acc

    best_acc = 0
    for e in range(epoch_max):
        clf.train()
        head.train()
        encoder.train()
        decoder.train()
        clf_adv.train()
        loss_list = []
        loss_decoder_list = []
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Step 1: train clf, decoder and clf_adv on the detached cut tensors.
            optim_head.zero_grad()
            optim_clf.zero_grad()
            optim_encoder.zero_grad()
            optim_clfadv.zero_grad()
            optim_decoder.zero_grad()

            rep = head(inputs)
            rep_sent = rep.detach().requires_grad_()
            feature = encoder(rep_sent)
            feature_sent = feature.detach().requires_grad_()

            pred = clf(feature_sent)
            pred_adv = clf_adv(feature_sent)
            rec_inputs = decoder(rep_sent)

            loss_ce = criterion(pred, targets)
            loss_decoder = F.mse_loss(rec_inputs, inputs)
            loss_ce_adv = criterion(pred_adv, targets)

            # The three graphs start from detached leaves and share no nodes,
            # so none of them needs retain_graph.
            loss_ce.backward()
            loss_decoder.backward()
            loss_ce_adv.backward()
            optim_clf.step()
            optim_decoder.step()
            optim_clfadv.step()

            optim_decoder.zero_grad()
            optim_encoder.zero_grad()
            optim_head.zero_grad()
            optim_clfadv.zero_grad()
            optim_clf.zero_grad()
            rep_sent.grad.zero_()
            feature_sent.grad.zero_()

            # Step 2: train head and encoder against the freshly updated top models.
            pred = clf(feature_sent)
            pred_adv = clf_adv(feature_sent)
            rec_inputs = decoder(rep_sent)
            random_inputs = rand_image_sampler.sample(targets).to(device)
            targets_random = torch.LongTensor(
                np.random.choice(range(num_class), size=len(targets))
            ).to(device)
            l_adv = args.l_adv
            loss = (
                (1 - 2 * l_adv) * criterion(pred, targets)
                - l_adv * F.mse_loss(rec_inputs, inputs)
                + l_adv * F.mse_loss(rec_inputs, random_inputs)
                - l_adv * criterion(pred_adv, targets)
                + l_adv * criterion(pred_adv, targets_random)
            )
            # loss.backward() fills feature_sent.grad and (via the decoder)
            # rep_sent.grad. Both are then pushed manually through the encoder
            # and the head; the encoder pass adds to rep_sent.grad first.
            loss.backward()
            feature.backward(gradient=feature_sent.grad)
            rep.backward(gradient=rep_sent.grad)
            optim_head.step()
            optim_encoder.step()

            loss_list.append(loss.item())
            loss_decoder_list.append(loss_decoder.item())

        if e % 10 == 0:
            # Fine-tune fresh adversarial heads to measure label leakage.
            adv_train_acc, adv_val_acc = mc_attack(clf_adv_init, mc_train_loader, valid_loader, device)
            adv_train_acc_med, adv_val_acc_med = mc_attack(clf_adv_init, mc_train_loader_med, valid_loader, device)
            adv_train_acc_all, adv_val_acc_all = mc_attack(clf_adv_init, train_loader, valid_loader, device)
            adv_train_acc_sim, adv_val_acc_sim = mc_attack(clf_adv_sim_init, mc_train_loader, valid_loader, device)
            adv_train_acc_med_sim, adv_val_acc_med_sim = mc_attack(clf_adv_sim_init, mc_train_loader_med, valid_loader, device)

            train_acc = valid(head, encoder, clf, train_loader, device)
            val_acc = valid(head, encoder, clf, valid_loader, device)

            print(
                f"Epoch: {e}. loss: {sum(loss_list)/len(loss_list):.4f}, loss_decoder: {sum(loss_decoder_list)/len(loss_decoder_list):.4f}, train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}"
            )
            # Label with the actual subset sizes instead of per-dataset hard-coded counts.
            n_small, n_med = args.num_mc_sample, 10 * args.num_mc_sample
            print(f"Adversarial {n_small} samples train acc: {adv_train_acc:.4f}, val acc: {adv_val_acc:.4f}")
            print(f"Adversarial {n_med} samples train acc: {adv_train_acc_med:.4f}, val acc: {adv_val_acc_med:.4f}")
            print(f"Adversarial all samples train acc: {adv_train_acc_all:.4f}, val acc: {adv_val_acc_all:.4f}")
            print(f"Adversarial {n_small} samples sim train acc: {adv_train_acc_sim:.4f}, val acc: {adv_val_acc_sim:.4f}")
            print(f"Adversarial {n_med} samples sim train acc: {adv_train_acc_med_sim:.4f}, val acc: {adv_val_acc_med_sim:.4f}")

            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(
                    {"head": head.state_dict(), "encoder": encoder.state_dict(), "clf": clf.state_dict()},
                    args.save_models_filename,
                )


if __name__ == "__main__":
    main()
