import sys
import os
import numpy as np
import scipy
import torch
import torchvision.models
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from utils_data import *
from utils_model import *
from utils_algo import *
import utils_data
import utils_algo
import utils_model
import warnings
import argparse

# Silence every Python warning for the remainder of the run.
warnings.filterwarnings("ignore")


# Seed each random-number source the script relies on with 0, in this order:
# NumPy's global RNG, torch's CPU generator, then the generators of all CUDA devices.
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(0)

# Command-line interface; run with -h/--help to list the options defined below.
parser = argparse.ArgumentParser(
    prog='Demo file for consistent and calibrated L2D with softmax parameterization',
    usage='Demo file',
    description='A simple demo file with CIFAR-100 dataset.',
    epilog='end',
    add_help=True,
)

# Optimisation hyper-parameters.
parser.add_argument(
    '-lr', '--learning_rate',
    type=float, default=1e-1,
    help="optimizer's learning rate",
)
parser.add_argument(
    '-bs', '--batch_size',
    type=int, default=128,
    help='batch_size',
)
parser.add_argument(
    '-e', '--epochs',
    type=int, default=200,
    help='number of epochs',
)
parser.add_argument(
    '-wd', '--weight_decay',
    type=float, default=5e-4,
    help='weight decay',
)

# Expert configuration (read as `T` further down in the script).
parser.add_argument(
    '-ex', '--expert',
    type=int, default=60,
    help='expert classes',
)




args = parser.parse_args()
# Value of --expert; used below to size the synthetic expert.
T = args.expert
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# torchrun exports LOCAL_RANK to each worker process. Parse it once and fail with a
# clear message when it is absent. The previous `os.getenv('LOCAL_RANK', -1)` fallback
# only surfaced later as an unrelated "negative device index" CUDA error.
_local_rank_env = os.getenv('LOCAL_RANK')
if _local_rank_env is None:
    raise RuntimeError('LOCAL_RANK is not set; launch this script with torchrun')
local_rank = int(_local_rank_env)

# Bind this process to its own GPU *before* creating the NCCL process group, so that
# NCCL collectives (e.g. the per-epoch barrier) use this GPU rather than a guessed one.
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

torch.distributed.init_process_group(backend="nccl")
# Must follow init_process_group: presumably builds a DistributedSampler, which needs
# the process group (train_sampler.set_epoch is called each epoch) — TODO confirm.
train_loader, test_loader, train_sampler = prepare_CIFAR100_data(batch_size=args.batch_size)

# NOTE(review): 101 outputs presumably means 100 CIFAR-100 classes + 1 defer-to-expert
# logit — confirm against WideResNet / Asym_CE.
model = WideResNet(28, 101, 4, dropRate=0)

num_gpus = torch.cuda.device_count()
if num_gpus >= 1:
    print('use {} gpus!'.format(num_gpus))
    # Each process drives exactly one device (its local rank).
    model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank])

# NOTE(review): the meaning of k (here 20 * (T + 1)) is defined by synth_expert in
# utils_*; not visible from this file.
expert = synth_expert(k=20*(T+1), n_classes=101)

optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=0.9, nesterov=True, weight_decay=args.weight_decay)

# Loss scaling for float16 autocast training.
scaler = torch.cuda.amp.GradScaler(enabled=True)

# T_max is counted in optimizer steps, because scheduler.step() is called once per batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * args.epochs)
# Train for exactly args.epochs epochs. The cosine schedule's T_max equals
# len(train_loader) * args.epochs steps. The former range(1, args.epochs + 2) ran one
# extra epoch past T_max, during which CosineAnnealingLR raises the LR again.
for epoch in range(1, args.epochs + 1):
    # Let the sampler derive this epoch's shuffle from the epoch number
    # (presumably a DistributedSampler — TODO confirm in prepare_CIFAR100_data).
    train_sampler.set_epoch(epoch=epoch)

    for data, label in train_loader:
        data = data.cuda()
        labels = label.cuda()
        optimizer.zero_grad()
        # Only the forward pass and the loss run under autocast. PyTorch's AMP guidance
        # is to keep backward() and the optimizer step outside the autocast region.
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
            output = model(data)
            loss = Asym_CE(output=output, label=labels, expert=expert)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Per-batch step, matching T_max measured in iterations.
        scheduler.step()

    # Only the process on local CUDA device 0 evaluates and prints.
    if torch.cuda.current_device() == 0:
        # Evaluate a fresh, unwrapped copy of the current weights, so whatever
        # check_01d does to the module it receives cannot touch the DDP training model.
        model2 = WideResNet(28, 101, 4, dropRate=0)
        model2.load_state_dict(model.module.state_dict())

        err, coverage, ECE_classifier, ECE_expert, ECE_total, err10, err20, err30 = check_01d(loader=test_loader, model=model2, expert=expert)
        err, err10, err20, err30 = err.to("cpu"), err10.to("cpu"), err20.to("cpu"), err30.to("cpu")
        coverage = coverage.to("cpu")
        ECE_classifier, ECE_expert, ECE_total = ECE_classifier.to("cpu"), ECE_expert.to("cpu"), ECE_total.to("cpu")

        print('Epoch: {}. system error: {:.4f}. coverage: {:.4f}. ECE_classifier: {:.4f}. ECE_expert: {:.4f}. ECE_total: {:.4f}.'.format(epoch, err, coverage, ECE_classifier, ECE_expert, ECE_total))
        print('Budgeted-10% error: {:.4f}. Budgeted-20% error: {:.4f}. Budgeted-30% error: {:.4f}.'.format(err10, err20, err30))
        print('')

    # Every rank reaches this barrier exactly once per epoch. Non-evaluating ranks
    # therefore wait here until evaluation has finished before the next epoch starts.
    torch.distributed.barrier()

# Release NCCL resources once all ranks are done.
torch.distributed.destroy_process_group()
