from __future__ import print_function
from textwrap import indent
from turtle import color
from matplotlib.pyplot import axis
from numpy.lib.function_base import append
import random
from requests import get
from sympy import plotting
import torch,pickle
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import config as cf
from datasets import ImagenetNoise

import torchvision.transforms as transforms
from torchvision import datasets

import argparse
import numpy as np


import matplotlib.pyplot as plt


from utils import get_pretrained_model, check_dir, prepare_dset, maha, \
    get_maha_distance, get_maha_distance_cov, get_relative_maha_distance, MahaDistNormalizer, ranking_loss, \
        get_maha_predict, get_relative_maha_predict, get_gda_posterior
from networks import *
from torch.autograd import Variable
from torch.nn.functional import one_hot, softmax
import torchvision
from metrics.ood_metrics import OOD_METRICS


def setup_seed(seed):
    """Seed every random number generator this script draws from.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU
    and all CUDA devices) with the same ``seed`` so that these generators
    produce the same sequence on every run. cuDNN determinism is not
    enforced here (see the commented-out line below).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True


setup_seed(20)

# Command-line configuration. Arguments are parsed at import time, so this
# module is meant to be run as a script.
# NOTE(review): the description string looks copied from a training script;
# this script only evaluates.
parser = argparse.ArgumentParser(description='Ensemble Training')
# Pretrained feature-extractor settings.
# --maha_file: .npy file whose dict holds 'class_cov_invs', 'class_means',
# 'cov_inv' and 'mean' (read further below).
# --maha_file_m0: only referenced by commented-out code below.
parser.add_argument('--maha_file', default='./ssl/maha_dict.npy', type=str)
parser.add_argument('--maha_file_m0', default='./ssl/maha_dict.npy', type=str)
# --arch: names starting with 'clip' or 'hug' get special handling during
# model loading and in evaluate().
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50')
parser.add_argument('--pretrained', default='', type=str,
                    help='path to moco pretrained checkpoint')
parser.add_argument('--pretrained_model', default='vit', type=str, help='SSL feature map type')
parser.add_argument('--comp_dis', action='store_true', default=False)

# Classifier under evaluation (non-ImageNet datasets only): a torch checkpoint
# with the module under 'net', plus the pickled training arguments.
parser.add_argument('--model_path', default='', type=str, help='Model trained with in_dataset')
parser.add_argument('--args_path', default='', type=str, help='Arguments for training model')

# --gpu: only used by the commented-out CUDA_VISIBLE_DEVICES line below.
parser.add_argument('--gpu', default='0', type=str)
parser.add_argument('--batch_size', default=500, type=int)
# --dataset: 'imagenet' switches to ImagenetNoise and a torchvision ResNet-34;
# any other value goes through prepare_dset().
parser.add_argument('--dataset', default='cifar10', type=str, help='cifar10/cifar100')
# --num_classes: used directly for ImageNet; other datasets take the class
# count from the training set below.
parser.add_argument('--num_classes', default=10, type=int)


# Label (y) and input (x) noise settings. For ImageNet the noise/rate/state
# values are forwarded to ImagenetNoise; presumably prepare_dset() reads them
# (and the trigger_* options) for the other datasets -- confirm.
parser.add_argument('--ynoise_type', default='symmetric', type=str, help='symmetric/pairflip')
parser.add_argument('--ynoise_rate', default=0.0, type=float, help='label noise rate')
parser.add_argument('--xnoise_type', default='blur', type=str, help='gaussian/blur')
parser.add_argument('--xnoise_arg', default=1, type=float)
parser.add_argument('--xnoise_rate', default=0.0, type=float)
parser.add_argument('--trigger_size', type=int, default=3)
parser.add_argument('--trigger_ratio', type=float, default=0.)


parser.add_argument('--random_state', type=int, default=0)
args = parser.parse_args()

print(args)
# os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# Hyper Parameter settings
# use_cuda only gates moving the classifier inputs in evaluate(); other code
# in this file calls .cuda() unconditionally.
use_cuda = torch.cuda.is_available()
# NOTE(review): best_acc and optim_type are not referenced elsewhere in the
# visible code.
best_acc = 0
batch_size, optim_type = args.batch_size, cf.optim_type

# Build the train/test datasets. Non-ImageNet data comes from prepare_dset();
# ImageNet uses ImagenetNoise with explicit transforms.
if args.dataset != 'imagenet':
    trainset, testset, trainvalset = prepare_dset(args)
    num_classes = trainset.nb_classes
else:
    # ImageNet per-channel mean/std normalisation.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Training split: random crop/flip augmentation plus the x/y noise
    # settings from the command line.
    trainset = ImagenetNoise(
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        xnoise_rate=args.xnoise_rate,
        xnoise_arg=args.xnoise_arg,
        xnoise_type=args.xnoise_type,
        ynoise_type=args.ynoise_type,
        ynoise_rate=args.ynoise_rate,
        random_state=args.random_state,
        num_classes=args.num_classes
    )
    num_classes = args.num_classes
    # Test split: deterministic resize + center crop; no noise arguments
    # are passed.
    testset = ImagenetNoise(
        train=False,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
        num_classes=args.num_classes
    )

# NOTE(review): trainloader is not used by evaluate(), which iterates
# testloader; it only appears in commented-out loops there.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,num_workers=4)
# metricloader = torch.utils.data.DataLoader(trainvalset, batch_size=1000, shuffle=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


print('loading the checkpoint')
# Load the classifier under evaluation.
if args.dataset != 'imagenet':
    # checkpoint['net'] is a full nn.Module (it is moved to the GPU directly).
    # NOTE(review): torch.load and pickle.load unpickle arbitrary objects --
    # only point --model_path / --args_path at trusted files.
    checkpoint = torch.load(args.model_path)
    # Use a context manager so the file handle is closed deterministically.
    with open(args.args_path, 'rb') as args_file:
        args_load = pickle.load(args_file)
    net = checkpoint['net'].cuda()
else:
    net = torchvision.models.resnet34(pretrained=True).cuda()

# Feature extractor whose outputs feed the GDA / Mahalanobis scores in
# evaluate().
pretrain_model = get_pretrained_model(args)
pretrain_model.cuda()
if not args.arch.startswith('clip'):
    # CLIP models stay unwrapped: evaluate() calls their get_image_features()
    # method, which a DataParallel wrapper would not expose.
    pretrain_model = torch.nn.DataParallel(pretrain_model)
pretrain_model.eval()
net = torch.nn.DataParallel(net)
cudnn.benchmark = True

# Load the fitted Gaussian parameters: a dict saved with np.save, which
# .item() unwraps from the loaded object array.
# NOTE(review): allow_pickle=True unpickles arbitrary objects -- only load
# trusted files.
maha_intermediate_dict = np.load(args.maha_file, allow_pickle=True)
# m0_maha_dict = np.load(args.maha_file_m0, allow_pickle=True)
maha_stats = maha_intermediate_dict.item()  # unwrap the dict once
# Inverse covariances and means presumably indexed by class (see
# get_gda_posterior(..., num_classes) in evaluate()).
class_cov_invs = maha_stats['class_cov_invs']
print(class_cov_invs[0].shape)
class_means = maha_stats['class_means']
# Inverse covariance and mean without a per-class index; used by the
# relative Mahalanobis distance in evaluate().
cov_invs = maha_stats['cov_inv']
means = maha_stats['mean']
# cov_invs = m0_maha_dict.item()['cov_inv']
# means = m0_maha_dict.item()['mean']

# Corruption names evaluated by test_c(), which loads
# '<base_path><name>.npy' for each of them in this order.
CORRUPTIONS = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
]
def test(net, test_loader):
  """Evaluate network on given dataset."""
  net.eval()
  total_loss = 0.
  total_correct = 0
  with torch.no_grad():
    for images, targets in test_loader:
      images, targets = images.cuda(), targets.cuda()
      logits = net(images)
      loss = F.cross_entropy(logits, targets)
      pred = logits.data.max(1)[1]
      total_loss += float(loss.data)
      total_correct += pred.eq(targets.data).sum().item()

  return total_loss / len(test_loader.dataset), total_correct / len(
      test_loader.dataset)
def test_c(net, test_data, base_path):
    """Evaluate ``net`` on every corruption in ``CORRUPTIONS``.

    For each corruption this loads ``<base_path><corruption>.npy`` into
    ``test_data.data`` and ``<base_path>labels.npy`` into
    ``test_data.targets``. It then evaluates with :func:`test` and prints
    the loss and the error in percent.

    ``test_data`` is mutated in place; after the call it holds the images of
    the last corruption in ``CORRUPTIONS``.

    Args:
        net: model to evaluate (passed to :func:`test`).
        test_data: dataset with writable ``data`` and ``targets`` attributes.
        base_path: path prefix of the ``.npy`` files. It is concatenated
            directly with the file names, so it must end with a separator.

    Returns:
        Mean accuracy (fraction of correct samples) over all corruptions.
    """
    # labels.npy does not depend on the corruption, so read it only once.
    labels = torch.LongTensor(np.load(base_path + 'labels.npy'))

    corruption_accs = []
    for corruption in CORRUPTIONS:
        # Swap the corrupted images into the dataset (mutates test_data).
        test_data.data = np.load(base_path + corruption + '.npy')
        test_data.targets = labels

        test_loader = torch.utils.data.DataLoader(
            test_data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True)

        test_loss, test_acc = test(net, test_loader)
        corruption_accs.append(test_acc)
        print('{}\n\tTest Loss {:.3f} | Test Error {:.3f}'.format(
            corruption, test_loss, 100 - 100. * test_acc))

    return np.mean(corruption_accs)


def softmax(x, axis=0):
    """Numerically stable softmax of a NumPy array along ``axis``.

    This module-level definition shadows the ``softmax`` imported from
    ``torch.nn.functional`` at the top of the file.

    Args:
        x: array of scores.
        axis: axis to normalise over. Defaults to 0, which is what callers
            already rely on: for a 2-D array each *column* sums to 1.

    Returns:
        Array with the shape of ``x`` whose entries along ``axis`` sum to 1.
    """
    # Subtracting the max along `axis` prevents overflow in exp() and does
    # not change the result.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
# Resizes non-ImageNet batches to 224x224 before evaluate() feeds them to the
# pretrained feature extractor. align_corners=False is spelled out; it is
# already PyTorch's default for bilinear mode.
up_sample = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False)
# Per-dataset path prefix, e.g. '/cifar100/cifar100_' for --dataset cifar100.
# NOTE(review): not referenced elsewhere in the visible code.
file_name = '/' + str(args.dataset) + '/' + str(args.dataset) + '_'
def _plot_per_class_correct(acc_log, acc_log_maha, out_path='acc_resnet_clip'):
    """Save a grouped bar chart of per-class correct counts for both classifiers.

    Args:
        acc_log: per-class number of test samples ``net`` classified correctly.
        acc_log_maha: per-class number of test samples the GDA prediction
            classified correctly.
        out_path: file name passed to ``plt.savefig``.
    """
    x = np.arange(len(acc_log))  # x positions of the class tick labels
    width = 0.3  # bar width
    fig = plt.figure(figsize=(32, 8), dpi=100)
    # Shift the two groups by half a bar width so each pair of bars is
    # centred on its class tick.
    plt.bar(x - width / 2, acc_log, width, label='resnet')
    plt.bar(x + width / 2, acc_log_maha, width, label='clip-maha')
    plt.ylabel('ACC')
    plt.xticks(x, labels=range(len(acc_log_maha)))
    plt.legend()
    plt.savefig(out_path)
    plt.close(fig)  # free the figure; it is not shown interactively


def evaluate():
    """Compare ``net`` with a GDA classifier built on ``pretrain_model`` features.

    Iterates over ``testloader``, which yields ``(id, inputs, targets)``
    triples; the first element is ignored. For each batch it runs the
    classifier ``net`` and the feature extractor ``pretrain_model``. The
    features are scored with ``get_gda_posterior`` and
    ``get_relative_maha_distance`` using the Gaussian parameters loaded at
    module level.

    After the loop it prints the overall accuracy of both predictions and
    their per-class correct counts, and saves a grouped bar chart of those
    counts to ``acc_resnet_clip``.

    Returns:
        ``(succ_scores, err_scores)``: lists holding ``-conf`` for the test
        samples ``net`` classified correctly and incorrectly, where ``conf``
        is the largest raw output of ``net`` for that sample.
    """
    targets_list = []
    maha_dis_list = []
    corr_err_list = []
    corr_err_maha_list = []
    conf_list = []
    pretrain_model.eval()
    net.eval()
    with torch.no_grad():
        for _, inputs, targets in testloader:
            # Non-ImageNet images are upsampled to 224x224 (``up_sample``)
            # before feature extraction; the extractor always runs on GPU.
            if args.dataset != 'imagenet':
                pretrain_inputs = up_sample(inputs).cuda()
            else:
                pretrain_inputs = inputs.cuda()
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            if args.arch.startswith('clip'):
                pre_feature = pretrain_model.get_image_features(pretrain_inputs)
            else:
                pre_feature = pretrain_model(pretrain_inputs)
            outputs = net(inputs)
            if args.arch.startswith('hug'):
                # 'hug*' models presumably return a Hugging Face output
                # object; its .logits are used as the features.
                pre_feature = pre_feature.logits
            pre_feature = pre_feature.cpu().numpy()
            targets_np = targets.cpu().numpy()

            # GDA prediction. softmax() normalises along axis 0 and is
            # monotone there, so this equals argmin of gda_prob along axis 0.
            # NOTE(review): argmin of a "posterior" looks inverted unless
            # get_gda_posterior returns a distance-like score -- confirm.
            gda_prob = get_gda_posterior(pre_feature, class_cov_invs, class_means, num_classes)
            pred_maha = np.argmin(softmax(gda_prob), axis=0)

            maha_dis_list.append(get_relative_maha_distance(
                pre_feature, cov_invs, class_cov_invs, means, class_means, targets_np))

            # conf is the largest raw output of net, not a softmax probability.
            conf, predicted = torch.max(outputs, 1)
            targets_list.append(targets_np)
            corr_err_list.append(predicted.cpu().numpy() == targets_np)
            corr_err_maha_list.append(pred_maha == targets_np)
            conf_list.append(conf.cpu().numpy())

    targets_list = np.concatenate(targets_list, axis=0).squeeze()
    maha_dis_list = np.concatenate(maha_dis_list, axis=0).squeeze()
    corr_err_list = np.concatenate(corr_err_list, axis=0).squeeze()
    corr_err_maha_list = np.concatenate(corr_err_maha_list, axis=0).squeeze()
    conf_list = np.concatenate(conf_list, axis=0).squeeze()

    # Normalise by the real test-set size rather than a fixed 10000 (the
    # size of the CIFAR test splits only).
    num_samples = len(testloader.dataset)
    acc = float(corr_err_list.sum()) / num_samples
    acc_maha = float(corr_err_maha_list.sum()) / num_samples
    print(acc, acc_maha)
    print(maha_dis_list.shape)
    print(corr_err_list.shape)

    # Negated confidence, split by whether net classified the sample correctly.
    succ_scores = list(-conf_list[corr_err_list])
    err_scores = list(-conf_list[~corr_err_list])

    # Per-class number of correctly classified samples (counts, not rates).
    # Sized by num_classes instead of a fixed 100, which was wrong for
    # CIFAR-10 and overflowed for ImageNet.
    acc_log = np.bincount(targets_list[corr_err_list], minlength=num_classes)
    acc_log_maha = np.bincount(targets_list[corr_err_maha_list], minlength=num_classes)
    print(acc_log.tolist(), acc_log_maha.tolist())

    _plot_per_class_correct(acc_log, acc_log_maha)

    return succ_scores, err_scores


# Run the evaluation. evaluate() returns -conf scores for the test samples
# that net classified correctly and incorrectly. The commented-out metric
# code below labels the former 0 ("in") and the latter 1 ("out") for
# misclassification detection.
in_confidence_score, ood_confidence_score = evaluate()
# kwargs = dict(histtype='stepfilled', alpha=0.3, bins=40)
# plt.hist(in_confidence_score, **kwargs)
# plt.hist(ood_confidence_score, **kwargs)
# plt.savefig('in_out')
# # in_confidence_score = np.concatenate(in_confidence_score)
# # ood_confidence_score = np.concatenate(ood_confidence_score)
    

# print("Evaluating misclassification Detection Performance...")
# scores = np.concatenate((in_confidence_score, ood_confidence_score), axis=0).astype(np.float128)
    
# in_labels = np.zeros_like(in_confidence_score)
# out_labels = np.ones_like(ood_confidence_score)
# domain_labels = np.concatenate((in_labels, out_labels), axis=0)

# tpr95_score = OOD_METRICS["tpr95"](domain_labels, scores)
# auroc_score = OOD_METRICS["auroc"](domain_labels, scores)
# auprIn_score = OOD_METRICS["auprIn"](domain_labels, scores)
# auprOut_score = OOD_METRICS["auprOut"](domain_labels, scores)
# de_score = OOD_METRICS["detection_err"](domain_labels, scores)
    
# print("{:20}{:13.1f}% ".format("FPR at TPR 95%:", tpr95_score*100))
# print("{:20}{:13.1f}% ".format("Detection error:", de_score*100))
# print("{:20}{:13.1f}% ".format("AUROC:",auroc_score*100))
# print("{:20}{:13.1f}% ".format("AUPR In:",auprIn_score*100))
# print("{:20}{:13.1f}% ".format("AUPR Out:",auprOut_score*100))

# evaluating acc on the clean data and corrupted data
# test_transform = transforms.Compose(
#       [transforms.ToTensor(),
#        transforms.Normalize([0.5] * 3, [0.5] * 3)])
# test_data = datasets.CIFAR100('./data/cifar', train=False, transform=test_transform, download=True)
# test_loader = torch.utils.data.DataLoader(test_data,batch_size=args.batch_size,shuffle=False,num_workers=4,pin_memory=True)
# base_c_path = '/data/cuipeng/CIFAR-100-C/'
# num_classes = 100
# test_loss, test_acc = test(net, test_loader)
# print('Clean\n\tTest Loss {:.3f} | Test Error {:.2f}'.format(test_loss, 100 - 100. * test_acc))

# test_c_acc = test_c(net, test_data, base_c_path)
# print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))