"""Run evaluation for LPE-BSR of CLIP-ViT models."""
import sys
sys.path.append('./') # pylint: disable=wrong-import-position

import os
from argparse import ArgumentParser

import jax
import jax.numpy as jnp
import jaxlib
import numpy as np
import qax
import tensorflow
import tensorflow_datasets
import transformers
from einshard import einshard
from jax_smi import initialise_tracking
initialise_tracking()

from scripts.default import get_args
from scripts.input_pipeline import create_val_iter
from src.quantization import BernoulliSymmetricQuantizedArray
from src.models.clip_vision.default import \
    load_jx_config, load_jx_params
from src.models.clip_vision.modeling import \
    forward_fn as _forward_fn, CLIPVisionInputs
from src.tree_util import load, save


if __name__ == '__main__':

    # ----------------------------------------------------------------------- #
    # Command line arguments
    # ----------------------------------------------------------------------- #
    parser = ArgumentParser()

    # Model selection: a model identifier and an optional checkpoint path.
    parser.add_argument(
        '--model_name',
        type=str,
        default='openai/clip-vit-large-patch14',
        help='(default: openai/clip-vit-large-patch14)',
    )
    parser.add_argument(
        '--model_ckpt',
        type=str,
        default=None,
        help='load params from checkpoint if specified (default: None)',
    )

    # Evaluation data: dataset name and evaluation batch size.
    parser.add_argument(
        '--data_name',
        type=str,
        default='imagenet2012',
        help='(default: imagenet2012)',
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=64,
        help='a size of mini-batch for each evaluation step (default: 64)',
    )

    # Quantization format and the number of samples to ensemble.
    parser.add_argument(
        '--quantization',
        type=str,
        default='Q5_0',
        help='apply fake quantization if specified (default: Q5_0)',
    )
    parser.add_argument(
        '--num_samples',
        type=int,
        default=20,
        help='the number of samples for ensembling (default: 20)',
    )

    # `get_args` (scripts.default) hands back the parsed arguments together
    # with the printing function used for all messages below.
    # NOTE(review): it presumably also registers extra options such as
    # `--save`, which is read later in this script -- confirm.
    library_modules = (
        jax, jaxlib, tensorflow, tensorflow_datasets, transformers)
    args, print_fn = get_args(
        parser,
        exist_ok=False,
        dot_log_file=False,
        libraries=library_modules,
    )

    # ----------------------------------------------------------------------- #
    # Prepare dataset
    # ----------------------------------------------------------------------- #
    dataset_builder = tensorflow_datasets.builder(args.data_name)

    # Fixed evaluation geometry: image shape and number of output classes.
    input_shape = (224, 224, 3)
    num_classes = 1000
    # First entry is the number of local devices; the -1 entry is passed to
    # `create_val_iter` as-is (presumably an inferred per-device batch
    # dimension -- confirm against scripts.input_pipeline).
    shard_shape = (jax.local_device_count(), -1)

    val_split = 'validation'
    val_iter, val_dataset_size, val_steps_per_epoch = create_val_iter(
        dataset_builder,
        args.batch_size,
        shard_shape,
        split=val_split,
    )
    print_fn(
        f'It will go through {val_steps_per_epoch} steps to handle '
        f'{val_dataset_size} validation data.')

    # ----------------------------------------------------------------------- #
    # Load model
    # ----------------------------------------------------------------------- #
    config = None # pylint: disable=invalid-name
    params = None # pylint: disable=invalid-name

    # Model identifiers this script knows how to load.
    SUPPORTED_MODEL_NAMES = (
        'openai/clip-vit-large-patch14',
        'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
    )

    if args.model_name in SUPPORTED_MODEL_NAMES:
        config = load_jx_config(args.model_name)
        params = load_jx_params(args.model_name)
        # Array stored next to the scripts; it is the final matrix applied
        # to the normalized features in `forward_fn`.
        head = jnp.load(f'./scripts/{args.model_name}.npy')

        if args.model_ckpt:
            # Checkpoint loading is done with the CPU as default device.
            with jax.default_device(jax.devices('cpu')[0]):
                position_mean_path = os.path.join(
                    args.model_ckpt, 'position_mean.pickle')
                if os.path.exists(position_mean_path):
                    # The checkpoint is a directory holding a dedicated
                    # `position_mean.pickle` file.
                    params = load(position_mean_path)
                else:
                    # Otherwise load the checkpoint itself and take its
                    # `position` attribute as the parameters.
                    params = load(args.model_ckpt).position

    if config is None:
        raise NotImplementedError(f'Unknown args.model_name={args.model_name}')

    # ----------------------------------------------------------------------- #
    # Setup model
    # ----------------------------------------------------------------------- #
    # Per-channel mean and standard deviation, shaped (1, 1, 1, 3) so that
    # they broadcast over the last axis of the image batch.
    IMAGE_MEAN = jnp.array(
        [0.48145466, 0.45782750, 0.40821073]).reshape(1, 1, 1, 3)
    IMAGE_STD = jnp.array(
        [0.26862954, 0.26130258, 0.27577711]).reshape(1, 1, 1, 3)

    @qax.use_implicit_args
    def forward_fn(params, images): # pylint: disable=redefined-outer-name
        """Maps a batch of images to one score vector per instance.

        Pixels are divided by 255, standardized with IMAGE_MEAN / IMAGE_STD,
        and passed through the vision model. The projected hidden states are
        scaled to unit L2 norm along the last axis and multiplied by `head`.

        Args:
            params: model parameters handed to the underlying forward pass.
            images: image batch; presumably pixel values in [0, 255] with
                channels last -- TODO confirm against the input pipeline.

        Returns:
            The unit-norm projected features multiplied by `head`.
        """
        scaled_pixels = images / 255.0
        standardized = (scaled_pixels - IMAGE_MEAN) / IMAGE_STD
        model_inputs = CLIPVisionInputs(input_pixels=standardized)
        features = _forward_fn(params, model_inputs, config).proj_hidden_states
        feature_norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
        unit_features = features / feature_norms
        return unit_features @ head

    # ----------------------------------------------------------------------- #
    # Evaluating model
    # ----------------------------------------------------------------------- #
    def accuracy(probs, labels):
        """Returns the fraction of instances whose argmax equals the label.

        Args:
            probs: per-instance scores; the argmax is taken over the last
                axis.
            labels: integer class labels, one per instance.

        Returns:
            The mean of the element-wise match indicator as a Python float.
        """
        predicted = jnp.argmax(probs, axis=-1)
        is_correct = jnp.equal(predicted, labels)
        return float(jnp.mean(is_correct))

    def categorical_negative_log_likelihood(probs, labels, eps=1e-12):
        """Returns the mean negative log-probability of the true labels.

        Args:
            probs: per-instance class probabilities; classes on the last
                axis.
            labels: integer class labels, one per instance.
            eps: constant added to `probs` before taking the logarithm so
                that zero probabilities do not produce -inf.

        Returns:
            The average over instances as a Python float.
        """
        num_categories = probs.shape[-1]
        targets = jax.nn.one_hot(labels, num_categories)
        log_probs = jnp.log(probs + eps)
        # The one-hot mask keeps only the log-probability of the true class.
        per_instance = jnp.negative(jnp.sum(targets * log_probs, axis=-1))
        return float(jnp.mean(per_instance))

    def expected_calibration_error(probs, labels, n_bins=15):
        """Computes expected calibration error.

        The confidence of an instance is its maximum probability. The
        interval [0, 1] is split into `n_bins` equal-width bins of the form
        (lower, upper], and the result is the sum over bins of
        |mean confidence - accuracy| weighted by the fraction of instances
        in the bin.

        Confidences that fall outside every (lower, upper] interval are
        clamped to the nearest bin: values at or below 0 go to the first bin
        and values above 1 go to the last bin.

        Args:
            probs: array-like of shape (num_instances, num_classes).
            labels: array-like of integer labels, one per instance.
            n_bins: number of equal-width confidence bins (must be >= 1).

        Returns:
            The expected calibration error as a Python float; 0.0 when there
            are no instances.

        Raises:
            ValueError: if `n_bins` is smaller than 1.
        """
        if n_bins < 1:
            raise ValueError(f'n_bins must be at least 1, got {n_bins}')

        probs = np.asarray(probs)
        labels = np.asarray(labels).reshape(-1)
        num_instances = probs.shape[0]
        if num_instances == 0:
            return 0.0

        # Accumulate in float64 regardless of the input precision.
        confidences = probs.max(axis=-1).astype(np.float64)
        hits = (probs.argmax(axis=-1) == labels).astype(np.float64)

        # With side='left', searchsorted returns j such that
        # boundaries[j-1] < value <= boundaries[j], so j - 1 is the index of
        # the (lower, upper] bin holding the value.
        bin_boundaries = np.linspace(0, 1, n_bins + 1)
        bin_ids = np.searchsorted(
            bin_boundaries, confidences, side='left') - 1
        bin_ids = np.clip(bin_ids, 0, n_bins - 1)

        # |mean_conf - accuracy| * count / N equals |sum_conf - sum_hits| / N
        # per bin, so the per-bin sums suffice; empty bins contribute zero.
        confidence_sums = np.bincount(
            bin_ids, weights=confidences, minlength=n_bins)
        hit_sums = np.bincount(bin_ids, weights=hits, minlength=n_bins)
        return float(np.abs(confidence_sums - hit_sums).sum() / num_instances)

    def _make_val_predictions(_params):
        """Runs one pass over the validation iterator with `_params`.

        For each of the `val_steps_per_epoch` steps, the batch images are
        flattened to (-1, 224, 224, 3), pushed through `forward_fn`, and the
        outputs and labels are moved to the first CPU device. The per-step
        results are concatenated and cut to the first `val_dataset_size`
        rows (presumably to drop padding in the last batch -- confirm
        against the input pipeline).

        Args:
            _params: parameters passed to `forward_fn`.

        Returns:
            A `(logits, labels)` pair of concatenated arrays.
        """
        cpu_device = jax.devices('cpu')[0]
        step_logits = []
        step_labels = []
        for _ in range(val_steps_per_epoch):
            batch = next(val_iter)
            images = batch['images'].reshape(-1, 224, 224, 3)
            batch_logits = forward_fn(_params, images).reshape(-1, num_classes)
            step_logits.append(jax.device_put(batch_logits, cpu_device))
            step_labels.append(
                jax.device_put(batch['labels'].reshape(-1), cpu_device))
        with jax.default_device(cpu_device):
            logits = jnp.concatenate(step_logits)[:val_dataset_size]
            labels = jnp.concatenate(step_labels)[:val_dataset_size]
        return logits, labels

    # Parameters whose dictionary path contains one of these substrings are
    # left unquantized.
    SKIP_PATTERNS = ('embeddings', 'projection')

    # Fail early with a clear message instead of hitting an undefined
    # `_quantizer` inside the sampling loop.
    SUPPORTED_QUANTIZATIONS = tuple(f'Q{e}_0' for e in range(3, 9))
    if args.quantization not in SUPPORTED_QUANTIZATIONS:
        raise NotImplementedError(
            f'Unknown args.quantization={args.quantization}')
    # The bit width is the single digit following 'Q', e.g. 'Q5_0' -> 5.
    BITS = int(args.quantization[1])

    def _quantizer(path, param, key):
        """Returns `param` quantized with `key`, or unchanged if skipped."""
        # Arrays with fewer than two dimensions are kept as they are.
        if param.ndim < 2:
            return param
        # Skip any parameter with a dictionary key matching SKIP_PATTERNS.
        if any(isinstance(e1, jax.tree_util.DictKey) and any(
                e2 in e1.key for e2 in SKIP_PATTERNS) for e1 in path):
            return param
        qaram = BernoulliSymmetricQuantizedArray.quantize(
            key, param, bits=BITS,
            contraction_axis=0, group_size=1)
        return qaram.materialize()

    # The tree structure and the CPU device do not change between samples.
    tree = jax.tree_util.tree_structure(params)
    cpu_device = jax.devices('cpu')[0]

    # Running averages of the per-sample probabilities and logits.
    all_ens_probs = None
    all_ens_logit = None
    for sample_idx in range(args.num_samples):

        # One PRNG key per parameter leaf, seeded by the sample index so
        # that every sample is reproducible.
        keys = jax.random.PRNGKey(sample_idx)
        keys = jax.tree_util.tree_unflatten(
            tree, jax.random.split(keys, tree.num_leaves))

        qarams = jax.tree_util.tree_map_with_path(_quantizer, params, keys)
        # NOTE(review): the einshard expression presumably shards the last
        # axis of every array across the available devices -- confirm.
        qarams = jax.tree_util.tree_map(
            lambda e: einshard(e, '... O -> ... O*'), qarams)
        logits, labels = _make_val_predictions(qarams)

        with jax.default_device(cpu_device):

            # NOTE(review): `args.save` is not declared in this file; it is
            # presumably added by `get_args` -- confirm.
            if args.save:
                save(os.path.join(
                    args.save, f'val_logits_{sample_idx:03d}'), logits)
                if sample_idx == 0:
                    save(os.path.join(args.save, 'val_labels'), labels)

            # Computed once and reused for both the running average and the
            # individual-sample metrics.
            ind_probs = jax.nn.softmax(logits)

            if sample_idx == 0:
                all_ens_probs = ind_probs
                all_ens_logit = logits
            else:
                # Incremental mean: new = (old * k + x) / (k + 1).
                all_ens_probs = all_ens_probs * sample_idx
                all_ens_probs = all_ens_probs + ind_probs
                all_ens_probs = all_ens_probs / (sample_idx + 1)
                all_ens_logit = all_ens_logit * sample_idx
                all_ens_logit = all_ens_logit + logits
                all_ens_logit = all_ens_logit / (sample_idx + 1)

            ens_logit_probs = jax.nn.softmax(all_ens_logit)

            # ind: current sample only; ens: averaged probabilities;
            # ens_logit: softmax of the averaged logits.
            val_summarized = {
                'ind/acc': accuracy(ind_probs, labels),
                'ind/nll': categorical_negative_log_likelihood(
                    ind_probs, labels),
                'ind/ece': expected_calibration_error(ind_probs, labels),
                'ens/acc': accuracy(all_ens_probs, labels),
                'ens/nll': categorical_negative_log_likelihood(
                    all_ens_probs, labels),
                'ens/ece': expected_calibration_error(all_ens_probs, labels),
                'ens_logit/acc': accuracy(ens_logit_probs, labels),
                'ens_logit/nll': categorical_negative_log_likelihood(
                    ens_logit_probs, labels),
                'ens_logit/ece': expected_calibration_error(
                    ens_logit_probs, labels)}
            log_str = 'Validation: ' + ', '.join(
                f'{k} {v:.3e}' for k, v in val_summarized.items())
            print_fn(log_str)
