from __future__ import print_function
from matplotlib.pyplot import axis
from numpy.lib.function_base import append
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
from datasets import ImagenetNoise

import torchvision
import torchvision.transforms as transforms

import os
import pickle
import time
import argparse
import datetime
import copy,sys
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torchvision.models as torchvision_models

from pytorch_pretrained_vit import ViT


from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import matplotlib.pyplot as plt


# from metrics import *

from utils import check_dir, prepare_dset, update_print, get_relative_maha_distance, maha, \
    get_pretrained_model, get_maha_distance, MahaDistNormalizer, ranking_loss
from networks import *
from torch.autograd import Variable
from datasets import CIFAR10Noise, CIFAR100Noise
# from metrics.label_metrics import target_mean
from torch.nn.functional import one_hot, softmax
import torchvision
from models import resnet_cifar
from networks.wide_resnet import Wide_ResNet

def setup_seed(seed):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    on the CPU and on all CUDA devices. cuDNN determinism is not enforced,
    so GPU runs may still differ slightly between executions.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True


setup_seed(20)

parser = argparse.ArgumentParser(description='Ensemble Training')
# pretrained models setting
parser.add_argument('--maha_file', default='./ssl/maha_dict.npy', type=str)
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet34')
parser.add_argument('--pretrained', default='', type=str,
                    help='path to moco pretrained checkpoint')
parser.add_argument('--pretrained_model', default='vit', type=str, help='SSL feature map type')
parser.add_argument('--comp_dis', action='store_true', default=False)

# loss / optimisation setting
parser.add_argument('--e_lambda', default=0.1, type=float, help='entropy weight')
parser.add_argument('--warmup', default=-1, type=int)
parser.add_argument('--rank_epo', default=1000, type=int)
parser.add_argument('--loss_type', default='rank0', type=str, help='rank0/margin_rank')
parser.add_argument('--method', default='ce', type=str, help='ce, ls, l1, focal, poly')
parser.add_argument('--rank_weight', default=1.0, type=float, help='ranking loss weight')
# NOTE(review): args.gpu does not appear to be consumed in this script
# (the CUDA_VISIBLE_DEVICES assignment is commented out) — confirm.
parser.add_argument('--gpu', default='0', type=str)
parser.add_argument('--lr', default=0.1, type=float, help='learning_rate')
parser.add_argument('--gamma', default=0.2, type=float, help='gamma')
parser.add_argument('--net_type', default='resnet', type=str, help='model')
parser.add_argument('--num_epochs', default=200, type=int)
parser.add_argument('--batch_size', default=512, type=int)
parser.add_argument('--dataset', default='cifar10', type=str, help='cifar10/cifar100')
parser.add_argument('--milestones', nargs='+', default=[60, 120, 160], type=int)
parser.add_argument('--num_classes', default=10, type=int)
parser.add_argument('--left', default=1.0, type=float)
parser.add_argument('--right', default=1.0, type=float)
parser.add_argument('--reverse', action='store_true', default=False)
parser.add_argument('--T', default=1.0, type=float)
parser.add_argument('--epsilon_p', default=2.0, type=float)
parser.add_argument('--epsilon',
                    default=1.0,
                    type=float,
                    help='Coefficient of Label Smoothing')

parser.add_argument('--alpha',
                    default=0.05,
                    type=float,
                    help='Coefficient of L1 Norm')

parser.add_argument('--fgamma',
                    default=1.0,
                    type=float,
                    help='Coefficient of Focal Loss')

# architecture setting
parser.add_argument('--depth', default=28, type=int, help='depth of model')
parser.add_argument('--widen_factor', default=10, type=int, help='width of model')
parser.add_argument('--dropout', default=0.3, type=float, help='dropout_rate')

# label / input noise setting
parser.add_argument('--ynoise_type', default='symmetric', type=str, help='symmetric/pairflip')
parser.add_argument('--ynoise_rate', default=0.0, type=float, help='label noise rate')
parser.add_argument('--xnoise_type', default='blur', type=str, help='gaussian/blur')
# argparse applies `type` only to command-line strings, never to the default,
# so the default must itself be a float to match values given on the CLI.
parser.add_argument('--xnoise_arg', default=1.0, type=float)
parser.add_argument('--xnoise_rate', default=0.0, type=float)
parser.add_argument('--trigger_size', type=int, default=3)
parser.add_argument('--trigger_ratio', type=float, default=0.)

# ensemble / MC-dropout setting
parser.add_argument('--ensemble_num', default=1, type=int, help="number of model to ensemble")
parser.add_argument('--mc_num', default=0, type=int, help="number of sample MC dropout")

# bookkeeping
parser.add_argument('--random_state', type=int, default=0)
parser.add_argument('--save_model', action='store_true', default=False)
parser.add_argument('--save_data', action='store_true', default=False)
parser.add_argument('--save_period', type=int, default=1)
parser.add_argument('--exp_name', default='ensemble', type=str)
args = parser.parse_args()

print(list(args.milestones), args)
# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

# Hyper-parameter settings shared by the training / test functions below.
use_cuda = torch.cuda.is_available()
best_acc = 0  # best test accuracy so far; updated by the test functions
num_epochs = args.num_epochs
batch_size = args.batch_size
optim_type = cf.optim_type  # reported in the start-up banner
# Custom_Dataset class
class Custom_Dataset(Dataset):
    """Map-style dataset over in-memory image arrays and their labels.

    Each item is ``(image, label, index)``. The index lets callers associate
    per-sample statistics with the sample across epochs.

    Args:
        x: indexable collection of image arrays. For ``'cifar'`` each entry
            is passed directly to ``PIL.Image.fromarray``. For ``'svhn'``,
            axis 0 is first moved to the end (CHW -> HWC).
        y: indexable collection of labels aligned with ``x``.
        data_set: ``'cifar'`` or ``'svhn'``; selects the array layout.
        transform: optional callable applied to the PIL image. When ``None``
            the PIL image itself is returned.

    Raises:
        ValueError: if ``data_set`` is not a supported layout. Previously
            this surfaced later as an UnboundLocalError in ``__getitem__``.
    """

    _SUPPORTED_DATA = ('cifar', 'svhn')

    def __init__(self, x, y, data_set, transform=None):
        if data_set not in self._SUPPORTED_DATA:
            raise ValueError(
                f"Unsupported data_set {data_set!r}; expected one of {self._SUPPORTED_DATA}")
        self.x_data = x
        self.y_data = y
        self.data = data_set
        self.transform = transform

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return ``(image, label, idx)`` for sample ``idx``."""
        raw = self.x_data[idx]
        if self.data == 'svhn':
            # Stored channel-first; PIL expects channel-last.
            raw = np.transpose(raw, (1, 2, 0))
        img = Image.fromarray(raw)
        x = img if self.transform is None else self.transform(img)
        return x, self.y_data[idx], idx
# Data Upload
print('\n[Phase 1] : Data Preparation')

if args.dataset != 'imagenet':
    # prepare_dset (utils) builds the train / test / train-val splits from args.
    trainset, testset, trainvalset = prepare_dset(args)
    num_classes = trainset.nb_classes
else:
    # ImageNet channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    trainset = ImagenetNoise(
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        xnoise_rate=args.xnoise_rate,
        xnoise_arg=args.xnoise_arg,
        xnoise_type=args.xnoise_type,
        ynoise_type=args.ynoise_type,
        ynoise_rate=args.ynoise_rate,
        random_state=args.random_state,
        num_classes=args.num_classes
    )
    num_classes = args.num_classes
    testset = ImagenetNoise(
        train=False,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
        num_classes=args.num_classes
    )

# Both loaders use worker processes; evaluation was previously loaded in the
# main process only, which bottlenecks every test pass.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)




print(f'| Building net type [{args.net_type}] * {args.ensemble_num}')

def getNetwork(args):
    """Select the classifier constructor and checkpoint name for ``args.net_type``.

    Args:
        args: parsed command-line namespace. Reads ``net_type``, ``dataset``,
            ``mc_num``, ``ensemble_num``, ``arch`` and ``loss_type``; for
            ``'wide_resnet'`` it also reads ``depth``, ``widen_factor`` and
            ``dropout``. The class count comes from the module-level
            ``num_classes``.

    Returns:
        ``(net_fn, file_name, model_args)``. ``net_fn(*model_args)`` builds
        one model, and ``file_name`` is the checkpoint stem.

    Exits:
        Calls ``sys.exit(1)`` (a failure status) for an unsupported ``net_type``.
    """
    if args.net_type == 'resnet':
        if args.dataset != 'imagenet':
            net = resnet_cifar.ResNet34
            print('dataset:', args.dataset)
        else:
            net = torchvision.models.resnet34
        model_args = [num_classes]
        prefix = 'resnet34'
    elif args.net_type == 'wide_resnet':
        net = Wide_ResNet
        model_args = [args.depth, args.widen_factor, args.dropout, num_classes]
        prefix = 'wide-resnet-' + str(args.depth) + 'x' + str(args.widen_factor)
    else:
        print('Error : Network should be either [ResNet / Wide_ResNet]', file=sys.stderr)
        # Non-zero status so shell scripts / schedulers see the failure.
        sys.exit(1)

    # The MC-dropout tag is only part of the name when MC sampling is enabled.
    mc_tag = '_mc_dropout' + str(args.mc_num) if args.mc_num > 1 else ''
    file_name = (prefix + '_deep_ens' + str(args.ensemble_num) + mc_tag
                 + '_' + args.dataset + '_' + args.arch + '_' + args.loss_type)
    return net, file_name, model_args

net_fn, file_name, model_args = getNetwork(args)

# Resizes non-ImageNet batches to 224x224 before they are fed to the
# pretrained feature extractor. align_corners=False is what bilinear mode
# already uses by default; stating it silences PyTorch's UserWarning.
up_sample = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False)

pretrain_model = get_pretrained_model(args)

if args.ensemble_num > 1:
    net = Ensemble(args.ensemble_num, net_fn, model_args)
else:
    net = net_fn(*model_args)

# The pretrained model is only used as a frozen feature extractor (eval mode).
pretrain_model.cuda()
if not args.arch.startswith('clip'):
    # CLIP models are queried via .get_image_features(), which a
    # DataParallel wrapper would not expose, so they stay unwrapped.
    pretrain_model = torch.nn.DataParallel(pretrain_model)
pretrain_model.eval()
net.cuda()
net = torch.nn.DataParallel(net)
cudnn.benchmark = True

if args.loss_type.startswith('rank0'):
    rank_criterion = nn.MarginRankingLoss().cuda()
else:
    rank_criterion = ranking_loss().cuda()

maha_normalizer = MahaDistNormalizer()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(args.milestones), gamma=args.gamma)

# Load the precomputed Gaussian statistics (global and per-class means and
# inverse covariances) used for the Mahalanobis distances during training.
# NOTE: allow_pickle=True unpickles the file contents — only load trusted files.
maha_intermediate_dict = np.load(args.maha_file, allow_pickle=True)
maha_stats = maha_intermediate_dict.item()
class_cov_invs = maha_stats['class_cov_invs']
class_means = maha_stats['class_means']
cov_invs = maha_stats['cov_inv']
means = maha_stats['mean']

# Global iteration counter, incremented by train().
n_iter = 0

class weighted_entropy_ce(nn.Module):
    """Cross-entropy minus a per-sample weighted prediction-entropy term.

        loss = mean_i[-log p_i(y_i)] - mean_i[e_lambda * w_i * H(p_i)]

    Here ``p_i = softmax(x_i)`` over the class dimension and ``H`` is the
    Shannon entropy. A larger weight ``w_i`` rewards a less confident
    prediction on sample ``i``.
    """

    def __init__(self):
        super(weighted_entropy_ce, self).__init__()

    def forward(self, x_input, y_target, weight, e_lambda):
        """Compute the loss.

        Args:
            x_input: logits of shape ``(N, C)``.
            y_target: integer class indices of shape ``(N,)``.
            weight: ``N`` per-sample weights (any shape; flattened to ``(N, 1)``).
            e_lambda: scalar multiplier of the entropy term.

        Returns:
            A 0-dim loss tensor.
        """
        # Explicit class dim (the implicit-dim form is deprecated).
        log_p = F.log_softmax(x_input, dim=1)
        p = log_p.exp()
        entropy = -torch.sum(p * log_p, dim=1).reshape(-1, 1)
        weighted_entropy = (e_lambda * weight.reshape(-1, 1)) * entropy

        # Class count comes from the logits, not from a module-level global.
        one_hot_target = F.one_hot(y_target, num_classes=x_input.size(1))
        nll = -torch.sum(log_p * one_hot_target, 1)
        return torch.mean(nll) - torch.mean(weighted_entropy)

def ECELoss(logits, labels, n_bins=15):
    """Expected Calibration Error (ECE) of a classifier's predictions.

    ``logits`` are raw model outputs of shape ``(N, C)``, not probabilities;
    softmax is applied here. The confidence range [0, 1] is split into
    ``n_bins`` equal-width bins, each half-open as ``(lower, upper]``. For
    every non-empty bin, |mean confidence - accuracy| is weighted by the
    fraction of samples in it, and the weighted gaps are summed.

    Reference: Naeini, Cooper and Hauskrecht, "Obtaining Well Calibrated
    Probabilities Using Bayesian Binning", AAAI 2015.

    Returns:
        A 1-element tensor on ``logits.device``.
    """
    probs = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(probs, 1)
    hits = predictions.eq(labels)

    edges = torch.linspace(0, 1, n_bins + 1)
    lowers = edges[:-1].tolist()
    uppers = edges[1:].tolist()

    ece = torch.zeros(1, device=logits.device)
    for lo, hi in zip(lowers, uppers):
        mask = (confidences > lo) & (confidences <= hi)
        share = mask.float().mean()
        if share.item() > 0:
            gap = confidences[mask].mean() - hits[mask].float().mean()
            ece += gap.abs() * share
    return ece

def my_softmax(X):
    """Exponentiate ``X`` relative to its global maximum, scaled into (0, 1).

    Returns ``exp(X - max(X)) / (max(exp(X - max(X))) + 0.001)``. This is
    not a normalised softmax: entries do not sum to 1, and the largest
    entry maps to roughly ``1 / 1.001``. The small additive constant keeps
    every output strictly below 1.

    ``X`` is not modified. The previous ``X -= X.max()`` silently shifted
    the caller's tensor in place.
    """
    shifted = X - X.max()
    x_exp = shifted.exp()
    denom = x_exp.max() + 0.001
    return x_exp / denom
def _pretrained_features(inputs):
    """Embed a training batch with the frozen pretrained model, without gradients.

    Non-ImageNet batches are resized with ``up_sample`` first. CLIP backbones
    are queried through ``get_image_features``.
    """
    with torch.no_grad():
        if args.dataset != 'imagenet':
            pretrain_inputs = up_sample(inputs)
        else:
            pretrain_inputs = inputs
        pretrain_inputs = pretrain_inputs.cuda()
        if args.arch.startswith('clip'):
            return pretrain_model.get_image_features(pretrain_inputs)
        return pretrain_model(pretrain_inputs)


def _maha_targets(inputs, targets):
    """Compute Mahalanobis-based per-sample loss weights and pairwise rank targets.

    Returns:
        maha_weight: CUDA tensor of weights. It is ``my_softmax`` of the
            normalised relative Mahalanobis distance at temperature ``args.T``,
            replaced by ``2 - w`` when ``--reverse`` is set.
        target_rank: CUDA float tensor of +/-1 with one entry per distance.
            Entry ``i`` is +1 when sample ``i`` has a smaller distance than
            sample ``i + 1`` (cyclically), otherwise -1.
    """
    with torch.no_grad():
        pre_feature = _pretrained_features(inputs)
        maha_distance = get_relative_maha_distance(
            pre_feature.detach().cpu().numpy(), cov_invs, class_cov_invs,
            means, class_means, targets.cpu().numpy())
        maha_distance = torch.from_numpy(maha_distance)
        normalized = maha_normalizer.run(maha_distance, -1., 1.)
        maha_weight = my_softmax(normalized / args.T)
        if args.reverse:
            maha_weight = 2. - maha_weight
        # Vectorised form of the former per-element Python loop.
        closer = (maha_distance < torch.roll(maha_distance, -1)).reshape(-1)
        target_rank = closer.float() * 2. - 1.
    return maha_weight.cuda(), target_rank.cuda()


def _base_loss(logits, targets, maha_weight):
    """Classification loss for one model's ``(N, C)`` logits.

    After warm-up (``maha_weight`` given) this is the Mahalanobis-weighted
    entropy-regularised cross-entropy. During warm-up the loss is chosen by
    ``--method``.

    Raises:
        ValueError: for an unknown ``--method``. Previously this surfaced as an
            UnboundLocalError on ``ce_loss``.
    """
    if maha_weight is not None:
        return weighted_entropy_ce()(logits, targets, maha_weight, args.e_lambda)

    method = args.method
    if method == 'ce':
        return F.cross_entropy(logits, targets)
    if method == 'ls':  # label smoothing
        eps = args.epsilon
        n_cls = logits.size(1)
        smooth = (1.0 - eps) * F.one_hot(targets, n_cls).float() + eps / n_cls
        return -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * smooth, dim=1))
    if method == 'l1':  # CE + L1 norm of the logits
        l1_norm = torch.mean(torch.norm(logits, p=1, dim=1))
        return F.cross_entropy(logits, targets) + args.alpha * l1_norm
    if method == 'focal':
        logpt = F.log_softmax(logits, dim=1).gather(1, targets.view(-1, 1)).view(-1)
        pt = logpt.exp().detach()  # the focal modulation is not differentiated
        return -torch.mean((1 - pt) ** args.fgamma * logpt)
    if method == 'poly':  # Poly-1: CE + epsilon_p * (1 - p_target)
        one_hot_t = F.one_hot(targets, num_classes=logits.size(1))
        pt = torch.sum(F.softmax(logits, dim=1) * one_hot_t, dim=1)
        nll = -torch.sum(F.log_softmax(logits, dim=1) * one_hot_t, 1)
        return torch.mean(nll + args.epsilon_p * (1. - pt))
    raise ValueError(f'Unknown --method {method!r}; expected ce, ls, l1, focal or poly')


def train(epoch):
    """Run one training epoch over ``trainloader``, then step the LR scheduler.

    Epochs ``<= args.warmup`` use the ``--method`` loss. Later epochs use the
    Mahalanobis-weighted entropy cross-entropy. For an ensemble
    (``args.ensemble_num > 1``), each member is trained with the same base
    loss. After warm-up, ``args.rank_weight`` times a ranking loss on the
    member's max-softmax confidence is added. The member losses are summed
    and back-propagated once; the logged loss is that sum.
    """
    global n_iter
    net.train()
    net.training = True
    correct = 0
    total = 0
    print('\n=> Training Epoch #%d' % (epoch))
    for batch_idx, (_, (inputs, _xnoisy), (targets, _true_tar)) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
        optimizer.zero_grad()
        outputs = net(inputs)

        maha_weight, target_rank = None, None
        if epoch > args.warmup:
            maha_weight, target_rank = _maha_targets(inputs, targets)

        if args.ensemble_num > 1:
            # `outputs` holds one logits tensor per member.
            member_losses = []
            for net_idx in range(args.ensemble_num):
                member_logits = outputs[net_idx]
                member_loss = _base_loss(member_logits, targets, maha_weight)
                if target_rank is not None:
                    conf, _ = torch.max(F.softmax(member_logits, dim=1), 1)
                    member_loss = member_loss + args.rank_weight * rank_criterion(
                        conf, torch.roll(conf, -1), target_rank)
                member_losses.append(member_loss)
            # One backward on the sum accumulates the same gradients as one
            # backward() per member.
            loss = sum(member_losses)
            scores = F.softmax(torch.stack(outputs), dim=-1).mean(dim=0)
        else:
            loss = _base_loss(outputs, targets, maha_weight)
            scores = outputs
        loss.backward()
        optimizer.step()  # Optimizer update

        _, predicted = torch.max(scores.detach(), 1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        update_print('| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%'
                     % (epoch, num_epochs, batch_idx + 1,
                        len(trainloader), loss.item(), 100. * correct / total))
        n_iter += 1
    scheduler.step()


def test_ensemble(epoch):
    """Evaluate ``net`` on ``testloader``; print top-1 accuracy and ECE.

    For an ensemble (``args.ensemble_num > 1``) the prediction is the mean of
    the members' softmax probabilities, and ECE is computed on that averaged
    distribution. When accuracy improves (and ``epoch <= 200``), ``best_acc``
    is updated. With ``--save_model``, the whole wrapped network is then
    saved to ``checkpoint/deep_ens/<file_name>.pkl``.
    """
    global best_acc
    net.eval()
    net.training = False
    correct = 0
    total = 0
    print(f'\nTest Epoch {epoch}')
    with torch.no_grad():
        logits_list = []
        labels_list = []
        for _, inputs, targets in testloader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
            outputs = net(inputs)
            if args.ensemble_num > 1:
                # `outputs` is a per-member list and cannot be concatenated
                # directly. Keep the log of the averaged probabilities:
                # ECELoss re-applies softmax, which recovers the averaged
                # distribution (up to rounding), and argmax is unchanged.
                outputs = torch.log(F.softmax(torch.stack(outputs), dim=-1).mean(dim=0))
            logits_list.append(outputs)
            labels_list.append(targets)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        logits = torch.cat(logits_list, 0)
        labels = torch.cat(labels_list, 0)
        ece = ECELoss(logits, labels).item()
        acc = correct / total
        print(f'Acc1={round(acc, 4)}', f'ECE={round(ece, 4)}')
        if acc > best_acc and epoch <= 200:
            best_acc = acc
            print('New Best Model')
            if not args.save_model:
                return
            state = {
                'net': net,
                'acc': acc,
                'epoch': epoch
            }
            save_point = 'checkpoint'
            check_dir(save_point)
            base_dir = os.path.join(save_point, "deep_ens")
            check_dir(base_dir)
            save_path = os.path.join(base_dir, file_name + '.pkl')
            print('Save Model to', save_path)
            torch.save(state, save_path)

# def save_args():

def test_mcdropput(epoch):
    """Evaluate ``net`` with Monte-Carlo dropout and track the best accuracy.

    After ``net.eval()`` the dropout layers are switched back to training
    mode, and the network is run ``args.mc_num`` times per batch. The
    prediction is the argmax of the softmax probabilities averaged over all
    passes and ensemble members, the same probability averaging
    ``test_ensemble`` uses. When accuracy improves, ``best_acc`` is updated.
    With ``--save_model``, the wrapped network is then saved to
    ``checkpoint/deep_ens/<file_name>.pkl``.
    """
    global best_acc
    net.eval()
    net.training = False
    enable_dropout(net)
    passes = args.mc_num
    correct = 0
    total = 0
    print(f'\nTest Epoch {epoch}')
    with torch.no_grad():
        # Same (index, image, label) unpacking as test_ensemble uses for this
        # loader; the former 2-tuple unpacking raised ValueError.
        for _, inputs, targets in testloader:
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
            prob_sum = 0.
            n_samples = 0
            for _ in range(passes):
                outputs = net(inputs)
                if args.ensemble_num > 1:
                    members = [outputs[k] for k in range(args.ensemble_num)]
                else:
                    members = [outputs]
                for member_logits in members:
                    prob_sum = prob_sum + F.softmax(member_logits, dim=1)
                    n_samples += 1
            _, predicted = torch.max(prob_sum / n_samples, dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        acc = correct / total
        print(f'Acc1={round(acc, 4)}')
        if acc > best_acc:
            best_acc = acc
            print('New Best Model')
            if not args.save_model:
                return
            state = {
                'net': net,
                'acc': acc,
                'epoch': epoch
            }
            save_point = 'checkpoint'
            check_dir(save_point)
            base_dir = os.path.join(save_point, "deep_ens")
            check_dir(base_dir)
            save_path = os.path.join(base_dir, file_name + '.pkl')
            print('Save Model to', save_path)
            torch.save(state, save_path)

def enable_dropout(model):
    """Put every dropout layer of ``model`` back into training mode.

    Intended for MC-dropout evaluation: call it after ``model.eval()``. Only
    modules whose class name starts with ``Dropout`` (e.g. ``Dropout`` or
    ``Dropout2d``) resume sampling masks; all other modules stay in eval mode.
    """
    dropout_layers = (
        module for module in model.modules()
        if type(module).__name__.startswith('Dropout')
    )
    for layer in dropout_layers:
        layer.train()





if __name__ == '__main__':
    print('\n[Phase 3] : Training model')
    print('| Training Epochs = ' + str(num_epochs))
    print('| Initial Learning Rate = ' + str(args.lr))
    print('| Optimizer = ' + str(optim_type))

    if args.save_data:
        exp_base = '/data/cuipeng/exps'
        check_dir(exp_base)
        exp_name = args.exp_name + '_' + datetime.datetime.now().strftime("%Y%m%d_%H%M")
        exp_root = os.path.join(exp_base, exp_name)
        # The name has minute resolution, so a rerun within the same minute
        # reuses the directory. The former os.rmdir + os.mkdir failed
        # whenever the directory already contained files.
        os.makedirs(exp_root, exist_ok=True)
        with open(os.path.join(exp_root, 'args.pkl'), 'wb') as f:
            pickle.dump(args, f)
        print('Experiment Name:', exp_name)

    elapsed_time = 0
    for epoch in range(num_epochs):
        start_time = time.time()
        train(epoch)
        if args.mc_num > 1:
            test_mcdropput(epoch)
        else:
            test_ensemble(epoch)

        epoch_time = time.time() - start_time
        elapsed_time += epoch_time
        print('| Elapsed time : %d:%02d:%02d' % (cf.get_hms(elapsed_time)))
    print('Best Acc:', best_acc)

    # The checkpoint directory is otherwise only created when a model is
    # saved, so create it here before dumping the args.
    args_dir = os.path.join('./checkpoint/', 'deep_ens')
    os.makedirs(args_dir, exist_ok=True)
    args_root = os.path.join(args_dir, file_name + '_args.pkl')
    print('Saving args to ' + args_root)
    with open(args_root, 'wb') as f:
        pickle.dump(args, f)
