import torch
import argparse

import torch.utils.data.distributed
import os
import numpy as np
from torchvision import datasets, transforms

# Expose no CUDA devices (CUDA_VISIBLE_DEVICES=-1) so evaluation runs on the
# CPU. This is set right before `import tensorflow` so TensorFlow sees no GPUs
# when it initializes -- keep this ordering.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf

# Only show TensorFlow log messages at ERROR severity or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Command-line settings for evaluating a TFLite model on a validation set.
parser = argparse.ArgumentParser(description='TFLite model evaluation',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# `type=os.path.expanduser` expands `~` in user-supplied paths (the shell does
# not expand it in the `--flag=~/path` form); argparse also applies `type` to
# string defaults.
parser.add_argument('--tflite_path', type=os.path.expanduser, required=True,
                    help='path to the .tflite model to evaluate')
parser.add_argument('--dataset', default='imagenet', type=str,
                    choices=['imagenet', 'vww', 'gcommands'],
                    help='validation dataset to evaluate on')
parser.add_argument('--val_dir', type=os.path.expanduser, default='/dataset/imagenet/val',
                    help='path to validation data')
parser.add_argument('--batch_size', type=int, default=256,
                    help='batch size used when loading validation data')
parser.add_argument('-j', '--workers', default=8, type=int,
                    help='number of DataLoader worker processes')
parser.add_argument('--eval_threads', type=int, default=32,
                    help='number of processes running the TFLite interpreter')
args = parser.parse_args()


def get_val_dataset(resolution):
    """Build a DataLoader over the validation split selected by ``args.dataset``.

    No mean/std normalization is applied: image tensors are left as produced
    by ``transforms.ToTensor()``, because ``eval_image`` converts them to int8
    itself (via ``* 255 - 128``).

    Args:
        resolution: Side length (in pixels) of the square model input.

    Returns:
        A non-shuffled ``torch.utils.data.DataLoader`` with batch size
        ``args.batch_size`` and ``args.workers`` worker processes.

    Raises:
        NotImplementedError: If ``args.dataset`` is not ``'imagenet'``,
            ``'vww'`` or ``'gcommands'``.
    """
    # CUDA devices are hidden at the top of this script and batches are only
    # ever consumed on the CPU, so pinned host memory would serve no purpose.
    kwargs = {'num_workers': args.workers, 'pin_memory': False}
    # The caller may pass a numpy integer taken from the model's input shape.
    resolution = int(resolution)
    if args.dataset == 'imagenet':
        # Resize with the 256/224 ratio, then center-crop to the model size.
        val_dataset = datasets.ImageFolder(
            args.val_dir,
            transform=transforms.Compose([
                transforms.Resize(int(resolution * 256 / 224)),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
            ]))
    elif args.dataset == 'vww':
        from dataset.vww import VWWDataset
        val_transform = transforms.Compose([
            transforms.Resize([resolution, resolution]),
            transforms.ToTensor(),
        ])
        val_dataset = VWWDataset(split='minival', transform=val_transform, dataset_dir=args.val_dir)
    elif args.dataset == 'gcommands':
        from dataset.gcommands import GoogleCommands
        val_dataset = GoogleCommands(args.val_dir, 'test', resolution=resolution)
    else:
        raise NotImplementedError('Unsupported dataset: {}'.format(args.dataset))
    return torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, **kwargs)


# Per-process cache of TFLite interpreters keyed by model path. Each pool
# worker fills its own copy on first use instead of reloading the model for
# every single image.
_INTERPRETER_CACHE = {}


def _get_interpreter(model_path):
    """Return a cached TFLite interpreter for ``model_path`` with tensors allocated."""
    interpreter = _INTERPRETER_CACHE.get(model_path)
    if interpreter is None:
        interpreter = tf.lite.Interpreter(model_path=model_path)
        interpreter.allocate_tensors()
        # `tf.logging` was removed in TensorFlow 2.x; use the compat.v1 API,
        # matching the module-level verbosity setting.
        tf.compat.v1.logging.info('input details: {}'.format(interpreter.get_input_details()))
        tf.compat.v1.logging.info('output details: {}'.format(interpreter.get_output_details()))
        _INTERPRETER_CACHE[model_path] = interpreter
    return interpreter


def eval_image(data):
    """Run the TFLite model on one sample and report whether it was classified correctly.

    Args:
        data: An ``(image, target)`` pair. ``image`` is a CHW (or NCHW) torch
            tensor; ``target`` is the ground-truth class index tensor.

    Returns:
        1 if the arg-max of the flattened model output equals ``target``,
        otherwise 0.
    """
    interpreter = _get_interpreter(args.tflite_path)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    input_shape = input_details[0]['shape']

    image, target = data
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    # NCHW -> NHWC for the TFLite input tensor.
    image_np = image.permute(0, 2, 3, 1).cpu().numpy()
    # Recover integer pixel values and shift them into the int8 range. Round
    # before casting: x / 255 * 255 is not always exact in float32 and astype()
    # truncates, which would occasionally produce an off-by-one pixel.
    image_np = np.clip(np.rint(image_np * 255) - 128, -128, 127).astype(np.int8)
    interpreter.set_tensor(input_details[0]['index'], image_np.reshape(*input_shape))
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    pred = np.argmax(output_data.reshape(-1))
    return int(pred == target.cpu().numpy())


def main():
    """Evaluate ``args.tflite_path`` on the validation set and print its accuracy.

    The input resolution is read from the model's first input tensor. The whole
    validation split is loaded into memory, and samples are scored in parallel
    by a process pool running :func:`eval_image`.

    Raises:
        RuntimeError: If the validation loader yields no batches.
    """
    from multiprocessing import Pool
    from tqdm import tqdm

    # Only used to read the expected input shape; index 1 is taken as the
    # spatial size (assumes NHWC, matching the permute in eval_image).
    interpreter = tf.lite.Interpreter(model_path=args.tflite_path)
    interpreter.allocate_tensors()
    input_shape = interpreter.get_input_details()[0]['shape']
    resolution = int(input_shape[1])

    batches = list(get_val_dataset(resolution))
    if not batches:
        raise RuntimeError('Validation dataset is empty: {}'.format(args.val_dir))
    images = torch.cat([batch[0] for batch in batches], dim=0)
    targets = torch.cat([batch[1] for batch in batches], dim=0)
    del batches  # free the per-batch copies before scoring
    samples = list(zip(images, targets))
    print(' * Dataset size:', len(samples))

    num_correct = 0
    num_done = 0
    # The context manager guarantees the worker processes are torn down.
    with Pool(args.eval_threads) as pool:
        progress = tqdm(pool.imap_unordered(eval_image, samples), total=len(samples),
                        desc='Evaluating...')
        for is_correct in progress:
            # Running counters keep each progress update O(1) instead of
            # re-summing every result seen so far.
            num_correct += is_correct
            num_done += 1
            progress.set_description(
                "Evaluating...Acc: {:.2f}".format(num_correct * 100. / num_done))
    final_acc = num_correct * 100. / num_done
    print(' * Final accuracy: {:.2f}%'.format(final_acc))


if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()
