import time
import torch
from PIL import Image
import numpy as np
import time, os, sys, copy
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score
from torch.autograd import Variable
import torch.nn.functional as F
from datetime import datetime as dt
from datetime import timedelta as td
import pickle
import os
import argparse
import pandas as pd

from densenet import DenseNet121
from models import resnet18
from dataset_cifar import CIFAR100, CIFAR10, to_multi_label, ImageDataset, CIFAR_imratio
from losses import AUCM_MultiLabel_selectTasks, CrossEntropyBinaryLoss_MultiLabel, AUCLoss_multiLabel
from chexpert import CheXpert
from celeba import CelebaDataset

# CUDA float tensor type; the auxiliary training variables (alpha, a, b and
# their moving averages) are cast to it with ``.type(dtype)``.
dtype = torch.cuda.FloatTensor

def set_all_seeds(SEED):
    """Seed every random number generator this script can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    and switches cuDNN to deterministic mode with auto-tuning disabled.

    Args:
        SEED: Integer seed shared by all generators.
    """
    # Local import so the function is self-contained; seeding ``random`` covers
    # any dependency that draws from the Python generator rather than torch/numpy.
    import random

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    # Trade cuDNN speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def zero_grad(model):
    """Reset, in place, the gradient of every parameter of ``model`` that has one.

    Parameters whose ``grad`` is ``None`` are left untouched.
    """
    for _, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        grad.data.zero_()

def proj_sca(x, bound):
    """Project the scalar ``x`` onto the range between 0 and ``bound``.

    The upper bound is tested first: values above ``bound`` return ``bound``,
    negative values return ``0``, anything else is returned unchanged.
    """
    above_upper = x > bound
    if above_upper:
        return bound
    return 0 if x < 0 else x

def auc_score(true, pred):
    """Return ``roc_auc_score(true, pred)``, or ``0`` if the metric cannot be computed.

    Args:
        true: Ground-truth labels, passed straight to ``roc_auc_score``.
        pred: Predicted scores, passed straight to ``roc_auc_score``.

    Returns:
        The ROC AUC value, or ``0`` when ``roc_auc_score`` raises.
    """
    try:
        return roc_auc_score(true, pred)
    except Exception:
        # Best-effort metric: a failed computation is reported as 0 instead of
        # aborting training. ``Exception`` (not a bare ``except``) so that
        # KeyboardInterrupt / SystemExit still propagate.
        return 0

def evaluate(loader, model):
    """Compute the AUC of ``model`` over every batch yielded by ``loader``.

    Each batch is moved to the GPU, passed through ``model`` and a sigmoid,
    and the predictions of all batches are scored together with ``auc_score``.
    The caller is responsible for putting the model in eval mode.

    Args:
        loader: Iterable yielding ``(data, targets)`` pairs; ``targets`` must
            support ``.numpy()``.
        model: Callable mapping a CUDA batch to raw (pre-sigmoid) outputs.

    Returns:
        The value returned by ``auc_score`` for the concatenated batches.
    """
    pred = []
    true = []
    # Inference only: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for data, targets in loader:
            outputs = model(data.cuda())
            y_pred = torch.sigmoid(outputs)
            pred.append(y_pred.cpu().detach().numpy())
            true.append(targets.numpy())
    true = np.concatenate(true)
    pred = np.concatenate(pred)
    return auc_score(true, pred)

# Command-line interface: every training setting is a flag with a default.
parser = argparse.ArgumentParser(description='SMMMB_multiLabel')

# (flag, add_argument keyword options), registered in this order.
_CLI_OPTIONS = (
    ('--SEED', dict(default=123, type=int)),
    ('--BATCH_SIZE_perTask', dict(default=32, type=int)),
    ('--weight_decay', dict(default=1e-4, type=float, help='weight decay (default: 1e-4)')),
    ('--margin', dict(default=1.0, type=float)),
    ('--dataset', dict(default='celeba', type=str, help='cifar100, celeba, chexpert')),
    ('--inner_update_steps', dict(default=1, type=int)),
    ('--alpha_proj_bd', dict(default=1000, type=int)),
    ('--lr', dict(default=0.1, type=float, help='initial learning rate')),
    ('--beta', dict(default=0.9, type=float)),
    ('--task_BATCH_SIZE', dict(default=1, type=int)),
    ('--decay_point', dict(default=10000, type=int)),
    ('--decay_rate', dict(default=0.1, type=float)),
    ('--methodOption', dict(default='auc', type=str)),
    ('--num_workers', dict(default=0, type=int)),
    ('--total_epoch', dict(default=40, type=int)),
    ('--beta_ct', dict(default=0.9, type=float)),
    ('--gpu_id', dict(default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')),
)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

def main():
    """Parse the command line into the global ``args`` and run training."""
    global args
    args = parser.parse_args()

    # Restrict the visible GPUs to the ones requested with --gpu_id.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # The returned history dict is not used here.
    SMMMB_multiLabel_noReg()


def _next_batch(batch_iter, loader):
    """Return the next batch of ``batch_iter``, restarting ``loader`` when exhausted.

    Returns a ``(batch, batch_iter)`` pair; ``batch_iter`` is the (possibly
    re-created) iterator the caller must keep using.
    """
    try:
        batch = next(batch_iter)
    except StopIteration:
        # One full pass is finished: begin a new pass and fetch from it, so the
        # caller never proceeds with a stale or missing batch.
        batch_iter = iter(loader)
        batch = next(batch_iter)
    return batch, batch_iter


def SMMMB_multiLabel_noReg():
    """Train a multi-label classifier with the AUC loss updates and track AUC.

    All settings are read from the module-level ``args`` namespace. The
    function builds the train/validation/test loaders and the model for
    ``args.dataset``, then iterates until the (fractional) epoch counter
    exceeds ``args.total_epoch``. Each iteration:

    * samples ``args.task_BATCH_SIZE`` labels ("tasks") at random;
    * when ``args.methodOption == 'ct'``, first performs
      ``args.inner_update_steps`` cross-entropy updates on batches drawn from
      a second loader over the training set;
    * updates the auxiliary variables ``a``, ``b`` (descent with a moving
      average of their gradients) and ``alpha`` (ascent, clamped), then the
      model weights with a moving average of the AUC-loss gradient.

    Every ``val_period`` iterations the training, validation and test AUC are
    computed and appended to the history lists.

    Returns:
        dict with the keys ``train_auc_list``, ``val_auc_list``,
        ``best_val_auc_list``, ``test_auc_list``, ``best_test_auc_list``,
        ``iter_list`` and ``epoch_list``.

    Raises:
        ValueError: if ``args.dataset`` or ``args.methodOption`` is not supported.
    """
    # Fail fast on unsupported settings, before any data is loaded.
    if args.dataset not in ('cifar100', 'celeba', 'chexpert'):
        raise ValueError('dataset error! unsupported dataset: {!r}'.format(args.dataset))

    # Number of loader batches consumed per outer iteration.
    if args.methodOption == 'auc':
        batch_perIte = 1
    elif args.methodOption == 'ct':
        batch_perIte = args.inner_update_steps + 1
    else:
        raise ValueError(
            "Error in batch_perIte definition: methodOption must be 'auc' or 'ct', "
            "got {!r}".format(args.methodOption))

    # Seed before the model and the loaders are created so that weight
    # initialisation and shuffling are covered by --SEED as well.
    set_all_seeds(args.SEED)

    # ----------------------
    ###### dataloader
    BATCH_SIZE = args.BATCH_SIZE_perTask * args.task_BATCH_SIZE

    NUM_CLASSES = {'cifar100': 100, 'celeba': 40, 'chexpert': 13}[args.dataset]
    num_train_samples = {'cifar100': 45000, 'celeba': 162770, 'chexpert': 191028}[args.dataset]

    ### data file path
    if args.dataset == 'celeba':
        root = '../celeba/'
    elif args.dataset == 'chexpert':
        root = '../CheXpert-v1.0-small/'

    if args.dataset == 'chexpert':
        ###### Chexpert Dataloader ######
        trainSet = CheXpert(csv_path=root + 'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                             image_size=224, mode='train', class_index=-1, data_split='train')
        valSet = CheXpert(csv_path=root + 'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                          image_size=224, mode='valid', class_index=-1, data_split='valid')
        testSet = CheXpert(csv_path=root + 'valid.csv', image_root_path=root, use_upsampling=False, use_frontal=True,
                           image_size=224, mode='valid', class_index=-1, data_split='test')
        trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, num_workers=args.num_workers, shuffle=True)
        if args.methodOption != 'auc':
            # Independent second stream over the training set for the CE inner updates.
            trainloader2 = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, num_workers=args.num_workers, shuffle=True)
        valloader = torch.utils.data.DataLoader(valSet, batch_size=BATCH_SIZE, num_workers=args.num_workers,
                                                 shuffle=False)
        testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, num_workers=args.num_workers, shuffle=False)

        imratio = trainSet.imratio_list

    elif args.dataset == 'celeba':
        ###### Celeba Dataloader ######
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = CelebaDataset(
            root + 'celeba_attr_train.csv',
            root + 'img_align_celeba/',
            transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = CelebaDataset(root + 'celeba_attr_val.csv', root + 'img_align_celeba/',
                                    transforms.Compose([
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))

        test_dataset = CelebaDataset(root + 'celeba_attr_test.csv', root + 'img_align_celeba/',
                                     transforms.Compose([
                                         transforms.ToTensor(),
                                         normalize,
                                     ]))

        train_sampler = None
        trainloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=BATCH_SIZE, shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        if args.methodOption != "auc":
            # Independent second stream over the training set for the CE inner updates.
            trainloader2 = torch.utils.data.DataLoader(
                train_dataset, batch_size=BATCH_SIZE, shuffle=(train_sampler is None),
                num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

        valloader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)

        testloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=BATCH_SIZE, shuffle=False,
            num_workers=args.num_workers, pin_memory=True)

        imratio = train_dataset.imratio_list

    elif args.dataset == 'cifar100':
        ###### CIFAR100 dataloader #####
        IMG_SIZE = 32
        imratio = CIFAR_imratio(dataset=args.dataset)

        (train_data, train_label), (val_data, val_label), (test_data, test_label) = CIFAR100()

        train_label_multi_label = to_multi_label(train_label, num_label=NUM_CLASSES)
        val_label_multi_label = to_multi_label(val_label, num_label=NUM_CLASSES)
        test_label_multi_label = to_multi_label(test_label, num_label=NUM_CLASSES)

        trainSet = ImageDataset(train_data, train_label_multi_label, image_size=IMG_SIZE, crop_size=IMG_SIZE - 2)
        valSet = ImageDataset(val_data, val_label_multi_label, image_size=IMG_SIZE, crop_size=IMG_SIZE - 2,
                               mode='test')
        testSet = ImageDataset(test_data, test_label_multi_label, image_size=IMG_SIZE, crop_size=IMG_SIZE - 2,
                               mode='test')

        trainloader = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.num_workers,
                                                  pin_memory=False, drop_last=True)
        if args.methodOption != 'auc':
            # Independent second stream over the training set for the CE inner updates.
            trainloader2 = torch.utils.data.DataLoader(trainSet, batch_size=BATCH_SIZE, shuffle=True, num_workers=args.num_workers,
                                                       pin_memory=False, drop_last=True)
        valloader = torch.utils.data.DataLoader(valSet, batch_size=BATCH_SIZE, shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=False)
        testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, shuffle=False, num_workers=args.num_workers,
                                                 pin_memory=False)

    ###### Model, losses setting ######
    if args.dataset == 'chexpert':
        model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=13)
    elif args.dataset == 'celeba':
        model = resnet18(num_classes=40)
    elif args.dataset == 'cifar100':
        model = resnet18(num_classes=100)
    model = model.cuda()

    Loss_auc = AUCLoss_multiLabel(imratio=imratio, m=args.margin)
    Loss_ce = CrossEntropyBinaryLoss_MultiLabel(num_classes=NUM_CLASSES)

    ###### Initials ######
    train_auc_list, val_auc_list, iter_list, best_val_auc_list = [], [], [], []
    test_auc_list, best_test_auc_list = [], []

    # Per-class auxiliary variables of the AUC loss and the moving averages
    # (z_a, z_b) of the gradients with respect to a and b.
    alpha = Variable(torch.zeros(NUM_CLASSES).type(dtype), requires_grad=False)
    a = Variable(torch.zeros(NUM_CLASSES).type(dtype), requires_grad=False)
    b = Variable(torch.zeros(NUM_CLASSES).type(dtype), requires_grad=False)
    z_a = Variable(torch.zeros(NUM_CLASSES).type(dtype), requires_grad=False)
    z_b = Variable(torch.zeros(NUM_CLASSES).type(dtype), requires_grad=False)

    # u_weights are the model's own parameter objects; w_weights is a separate copy.
    u_weights = list(model.parameters())
    w_weights = copy.deepcopy(u_weights)

    # One moving-average gradient buffer per model parameter.
    z_w_list = [torch.zeros_like(w) for _, w in model.named_parameters()]

    lr = args.lr
    beta = args.beta
    beta_ct = args.beta_ct
    decay_flag = 1
    epoch_list = []

    label_set = np.linspace(0, NUM_CLASSES - 1, NUM_CLASSES).astype(int)
    train_pred, train_true = [], []
    best_val_auc, best_test_auc = 0, 0
    epoch_count, iter_count = 0, 0

    # Evaluation period in iterations (loop-invariant).
    if (args.dataset == 'chexpert') or (args.dataset == 'celeba'):
        val_period = 500
    elif args.dataset == 'cifar100':
        val_period = 2500

    trainloader_copy = iter(trainloader)
    if args.methodOption != 'auc':
        trainloader2_copy = iter(trainloader2)

    while epoch_count <= args.total_epoch:
        iter_count += 1
        # Batches drawn so far: batch_perIte loader batches per outer iteration.
        batch_count = iter_count * batch_perIte
        # NOTE(review): progress is measured with BATCH_SIZE_perTask, not the
        # full loader batch size BATCH_SIZE - confirm this is intended when
        # task_BATCH_SIZE > 1.
        epoch_count = (batch_count * args.BATCH_SIZE_perTask) / num_train_samples

        model.train()

        # One-off step-size decay.
        # NOTE(review): decay_point is compared with the epoch counter; with the
        # defaults (decay_point=10000, total_epoch=40) it never triggers.
        if (epoch_count > args.decay_point) and (decay_flag == 1):

            decay_flag = 0

            print('Weight decay applying')
            print('From lr={0}, beta={1}'.format(lr, beta))
            lr = lr * args.decay_rate
            beta = beta * args.decay_rate
            beta_ct = beta_ct * args.decay_rate
            print('To lr={0}, beta={1}'.format(lr, beta))

        # Randomly pick the labels (tasks) handled in this iteration.
        np.random.shuffle(label_set)
        selectTasks = label_set[:args.task_BATCH_SIZE]

        if args.methodOption == 'ct':
            for k in range(args.inner_update_steps):
                (data_in, targets_in), trainloader2_copy = _next_batch(trainloader2_copy, trainloader2)
                data_in, targets_in = data_in.cuda(), targets_in.cuda()

                # Each selected task gets its own slice of the batch.
                loss_ce = 0
                for idx_label in range(args.task_BATCH_SIZE):
                    data_in_perTask = data_in[idx_label * args.BATCH_SIZE_perTask : (idx_label+1) * args.BATCH_SIZE_perTask]
                    targets_in_perTask = targets_in[idx_label * args.BATCH_SIZE_perTask: (idx_label + 1) * args.BATCH_SIZE_perTask]

                    outputs_in = model(data_in_perTask)
                    loss_ce += Loss_ce(outputs_in, targets_in_perTask, selectTasks=[selectTasks[idx_label]])

                loss_ce = loss_ce / args.task_BATCH_SIZE
                zero_grad(model)
                grads_ce = torch.autograd.grad(loss_ce, model.parameters(), retain_graph=False)
                # Blend the current weights with a CE gradient step taken from w,
                # then keep u, w and the model parameters in sync.
                for u, g, (name, w_model), w in zip(u_weights, grads_ce, model.named_parameters(),
                                                        w_weights):
                    u.data = (1 - lr) * u.data + lr * (w.data - beta_ct * g.data)
                    w.data = u.data
                    w_model.data = u.data

        grads_a, grads_b, grads_alp, loss_auc = 0, 0, 0, 0

        (data, targets), trainloader_copy = _next_batch(trainloader_copy, trainloader)
        data, targets = data.cuda(), targets.cuda()

        for idx_label in range(args.task_BATCH_SIZE):
            selectTask = selectTasks[idx_label]
            data_perTask = data[idx_label * args.BATCH_SIZE_perTask: (idx_label + 1) * args.BATCH_SIZE_perTask]
            targets_perTask = targets[idx_label * args.BATCH_SIZE_perTask: (idx_label + 1) * args.BATCH_SIZE_perTask]
            outputs_i = model(data_perTask)
            y_pred_i = torch.sigmoid(outputs_i)

            grads_a += Loss_auc.g1_grad_a(y_pred_i, a, targets_perTask, task=selectTask)
            grads_b += Loss_auc.g1_grad_b(y_pred_i, b, targets_perTask, task=selectTask)
            grads_alp += Loss_auc.g2(y_pred_i, targets_perTask, task=selectTask) - Loss_auc.g3_grad(alpha,
                                                                                                  task=selectTask)
            loss_auc += Loss_auc.g1(y_pred_i, a, b, targets_perTask, task=selectTask) \
                       + alpha[selectTask] * Loss_auc.g2(y_pred_i, targets_perTask, task=selectTask) \
                       - Loss_auc.g3(alpha, task=selectTask)

            # Collect the predictions of all task slices for the training AUC.
            if idx_label == 0:
                y_pred = torch.clone(y_pred_i)
            else:
                y_pred = torch.cat((y_pred, y_pred_i), 0)

        grads_a = grads_a / args.task_BATCH_SIZE
        grads_b = grads_b / args.task_BATCH_SIZE
        grads_alp = grads_alp / args.task_BATCH_SIZE
        loss_auc = loss_auc / args.task_BATCH_SIZE

        # a updates
        z_a.data = (1 - beta) * z_a + beta * grads_a
        a.data = a - lr * z_a

        # b updates
        z_b.data = (1 - beta) * z_b + beta * grads_b
        b.data = b - lr * z_b

        # alpha updates (ascent step, then clamp)
        # NOTE(review): the upper bound 999 is hard-coded; --alpha_proj_bd
        # (default 1000) is parsed but not used here - confirm which is intended.
        alpha.data = alpha.data + lr * grads_alp
        alpha.data = torch.clamp(alpha.data, 0, 999)

        # w updates
        zero_grad(model)
        grads_auc = torch.autograd.grad(loss_auc, model.parameters(), retain_graph=False)

        for g, w, z_w, (name, w_model), u in zip(grads_auc, w_weights, z_w_list,
                                                               model.named_parameters(),
                                                               u_weights):
            z_w.data = (1 - beta) * z_w + beta * g.data
            w.data = w.data - lr * z_w
            u.data = w.data
            w_model.data = w.data

        train_pred.append(y_pred.cpu().detach().numpy())
        train_true.append(targets.cpu().detach().numpy())

        if iter_count % val_period == 0:

            # Training AUC over the batches seen since the last evaluation.
            train_true = np.concatenate(train_true)
            train_pred = np.concatenate(train_pred)
            train_auc = auc_score(train_true, train_pred)
            train_pred = []
            train_true = []

            # evaluations
            model.eval()

            # Validation
            val_auc = evaluate(valloader, model)
            if val_auc > best_val_auc:
                best_val_auc = val_auc

            # Testing
            test_auc = evaluate(testloader, model)
            if test_auc > best_test_auc:
                best_test_auc = test_auc

            # save to log
            train_auc_list.append(train_auc)
            val_auc_list.append(val_auc)
            test_auc_list.append(test_auc)
            best_val_auc_list.append(best_val_auc)
            best_test_auc_list.append(best_test_auc)
            iter_list.append(iter_count)
            epoch_list.append(epoch_count)

            # print results (each value under its own label)
            print("Epoch: {}, train_auc: {:4f}, val_auc: {:4f}, best_val_auc: {:4f}, "
                  "test_auc: {:4f}, best_test_auc: {:4f}".format(
                      epoch_count, train_auc, val_auc, best_val_auc, test_auc, best_test_auc))

    output_dict = {'train_auc_list': train_auc_list,
                   'val_auc_list': val_auc_list,
                   'best_val_auc_list': best_val_auc_list,
                   'test_auc_list': test_auc_list,
                   'best_test_auc_list': best_test_auc_list,
                   'iter_list': iter_list,
                   'epoch_list': epoch_list
                   }

    return output_dict

# Run the training entry point only when the file is executed as a script.
if __name__ == '__main__':
    main()