import jax
import wandb
from functools import partial
from jax import random
from jax.lax import top_k
from jax.tree_util import tree_leaves
import jax.numpy as jnp
import optax
from data.loaders import get_lep_loaders, split_batch, get_iter, unsplit_batch, unreplicate
from networks.mlp import MLP
from flax.training.train_state import TrainState
from models.losses import cross_entropy_loss
from metrics.metrics import cosine_push_loss, participation_ratio
from tqdm import tqdm





def _concat_first_batches(batched, limit):
    """Concatenate the first element of up to ``limit`` batches along axis 0."""
    # A single iterator is advanced here. ``zip`` stops as soon as ``range``
    # is exhausted, so no batch beyond ``limit`` is pulled from ``batched``.
    heads = [batch[0] for _, batch in zip(range(limit), batched)]
    return jnp.concatenate(heads, axis=0)


def eval_step(state, lep_ds, test_lep_ds, epoch, epochs=100, bsz=512, seed=0, num_classes=10, do_offline_lep=False, parallel=False):
    """Compute evaluation metrics for the current model state.

    The returned dictionary always contains the online accuracies reported by
    ``get_online_test_accs`` and representation statistics (cosine push loss,
    participation ratio, mean activation norm and eigenvalue summaries) for
    both outputs of ``get_lep_loaders``. When ``do_offline_lep`` is true, it
    also contains the results of classifiers trained from scratch with
    ``train_classifier`` on each of the two representation loaders.

    Args:
        state: Model state; forwarded to ``get_lep_loaders`` and
            ``get_online_test_accs``. Its ``params`` and ``target_params``
            attributes are read here.
        lep_ds: Dataset used to build the training representation loaders.
        test_lep_ds: Dataset used for the test representations and for the
            online accuracies.
        epoch: Unused by this function; kept for interface compatibility.
        epochs: Number of epochs passed to ``train_classifier``.
        bsz: Batch size passed to ``get_lep_loaders`` for ``lep_ds``.
        seed: Seed passed to ``train_classifier``.
        num_classes: Number of classes passed to ``train_classifier``.
        do_offline_lep: Whether to train and evaluate the offline classifiers.
        parallel: Forwarded to ``get_lep_loaders`` and
            ``get_online_test_accs``.

    Returns:
        dict mapping metric names to values.
    """
    # Obtain loaders of representations computed with the current model.
    # NOTE(review): -1 presumably means "all batches" -- confirm against
    # get_lep_loaders.
    num_batches = -1 if do_offline_lep else 10

    print('Obtaining representations...')
    batched_ds, batched_proj_ds = get_lep_loaders(bsz, state, lep_ds, parallel, num_batches=num_batches)

    metrics = {}
    if do_offline_lep:
        print('Obtaining test representations...')
        batched_test_ds, batched_test_proj_ds = get_lep_loaders(2000, state, test_lep_ds, parallel)

        print('Training classifier...')
        test_acc, top5, train_acc, top_5_train, train_loss = train_classifier(
            batched_ds, batched_test_ds, num_classes, epochs, seed)
        proj_test_acc, proj_top5, proj_train_acc, proj_top_5_train, proj_train_loss = train_classifier(
            batched_proj_ds, batched_test_proj_ds, num_classes, epochs, seed)

        metrics.update({
            'offline_lep_test_acc': test_acc, 'offline_lep_top5': top5,
            'offline_lep_proj_test_acc': proj_test_acc, 'offline_lep_proj_top5': proj_top5,
            'offline_lep_train_acc': train_acc, 'offline_lep_top_5_train': top_5_train,
            'offline_lep_proj_train_acc': proj_train_acc, 'offline_lep_proj_top_5_train': proj_top_5_train,
            'offline_lep_train_loss': train_loss, 'offline_lep_proj_train_loss': proj_train_loss,
        })

    online_lep_accs, online_lep_top5 = get_online_test_accs(state, test_lep_ds, parallel)

    # Gather up to ten *distinct* batches of representations. Creating a new
    # iterator for every batch (``next(iter(ds))`` in a loop) would return the
    # first batch again and again for any re-iterable loader.
    z = _concat_first_batches(batched_ds, 10)
    projs = _concat_first_batches(batched_proj_ds, 10)

    cos_push = cosine_push_loss(z)
    pr, eigvals = participation_ratio(z)
    act_norm = jnp.linalg.norm(z, axis=1).mean()

    cos_push_projs = cosine_push_loss(projs)
    pr_projs, eigvals_proj = participation_ratio(projs)
    act_norm_projs = jnp.linalg.norm(projs, axis=1).mean()

    # Replace NaNs once so every eigenvalue summary sees the same values.
    eigvals = jnp.nan_to_num(eigvals)
    eigvals_proj = jnp.nan_to_num(eigvals_proj)

    metrics['online_lep_test_acc'] = online_lep_accs['embd']
    metrics['online_lep_top5'] = online_lep_top5['embd']
    metrics['online_lep_proj_test_acc'] = online_lep_accs['proj']
    metrics['online_lep_proj_top5'] = online_lep_top5['proj']
    metrics['collapse_metric'] = cos_push
    metrics['participation_ratio'] = pr
    metrics['act_norm'] = act_norm
    metrics['collapse_metric_projs'] = cos_push_projs
    metrics['participation_ratio_projs'] = pr_projs
    metrics['act_norm_projs'] = act_norm_projs
    metrics['min_eigval'] = jnp.min(eigvals)
    metrics['mean_eigval'] = jnp.mean(eigvals)
    metrics['max_eigval'] = jnp.max(eigvals)
    metrics['min_eigval_projs'] = jnp.min(eigvals_proj)
    metrics['mean_eigval_projs'] = jnp.mean(eigvals_proj)
    metrics['max_eigval_projs'] = jnp.max(eigvals_proj)

    # Full spectra, sorted in descending order.
    metrics['eigvals'] = jnp.sort(eigvals)[::-1]
    metrics['eigvals_proj'] = jnp.sort(eigvals_proj)[::-1]

    if state.target_params is not None:
        # Only the first leaf of each parameter tree is compared.
        metrics['dist_with_target'] = jnp.linalg.norm(tree_leaves(state.params)[0] - tree_leaves(state.target_params)[0])

    return metrics




def train_classifier(batched_ds, batched_test_ds, num_classes, epochs, seed):
    """Train a classifier on batched representations and evaluate it.

    An ``MLP([num_classes], bnorm=[False])`` is initialised from ``seed`` and
    optimised with AdamW under a cosine-decay learning-rate schedule for
    ``epochs`` passes over ``batched_ds``. The trained classifier is then
    evaluated on ``batched_ds`` and ``batched_test_ds``.

    Args:
        batched_ds: Sized iterable of ``(inputs, labels)`` training batches.
        batched_test_ds: Sized iterable of ``(inputs, labels)`` test batches.
        num_classes: Output width of the classifier; also passed to the loss.
        epochs: Number of passes over ``batched_ds``.
        seed: Integer seed for parameter initialisation.

    Returns:
        Tuple ``(test_acc, top5, train_acc, top_5_train, train_loss)``; each
        value is the average of the per-batch means.
    """
    model = MLP([num_classes], bnorm=[False])
    init_key = random.PRNGKey(seed)
    init_params = model.init(init_key, next(iter(batched_ds))[0]).unfreeze()

    n_train_batches = len(batched_ds)
    schedule = optax.cosine_decay_schedule(1e-3, epochs * n_train_batches, 1e-5)
    optimizer = optax.adamw(schedule, weight_decay=5e-6)
    fitted = TrainState.create(apply_fn=model.apply, params=init_params, tx=optimizer)

    for _ in tqdm(range(epochs)):
        for inputs, labels in batched_ds:
            # The jitted step expects jax arrays, so convert each batch first.
            jax_batch = (jnp.array(inputs), jnp.array(labels))
            fitted = train_step_lep(fitted, jax_batch, num_classes)

    def _score(inputs, labels):
        """Return logits plus top-1 and top-5 accuracy for one batch."""
        logits = fitted.apply_fn(fitted.params, inputs)
        hit1 = jnp.mean(jnp.argmax(logits, axis=1) == labels)
        hit5 = jnp.mean(jnp.equal(top_k(logits, 5)[1], labels[:, None]).any(axis=1))
        return logits, hit1, hit5

    # Metrics on the training representations.
    train_acc, top_5_train, train_loss = 0., 0., 0.
    for batch in batched_ds:
        logits, hit1, hit5 = _score(batch[0], batch[1])
        train_acc = train_acc + hit1
        top_5_train = top_5_train + hit5
        train_loss = train_loss + cross_entropy_loss(logits=logits, labels=batch[1], num_classes=num_classes)

    train_acc = train_acc / n_train_batches
    top_5_train = top_5_train / n_train_batches
    train_loss = train_loss / n_train_batches

    # Metrics on the held-out representations.
    test_acc, top5 = 0., 0.
    for batch in batched_test_ds:
        _, hit1, hit5 = _score(batch[0], batch[1])
        test_acc = test_acc + hit1
        top5 = top5 + hit5

    n_test_batches = len(batched_test_ds)
    test_acc = test_acc / n_test_batches
    top5 = top5 / n_test_batches

    return test_acc, top5, train_acc, top_5_train, train_loss


@partial(jax.jit, static_argnums=(2,))
def train_step_lep(classifier_state, batch, num_classes):
    """Apply one gradient update to the classifier on a single batch.

    Args:
        classifier_state: Train state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        batch: Pair whose first item is the classifier input and whose second
            item holds the labels.
        num_classes: Number of classes forwarded to the loss; treated as a
            static argument by ``jax.jit``.

    Returns:
        The train state produced by ``apply_gradients``.
    """
    inputs, labels = batch[0], batch[1]

    def batch_loss(params):
        predictions = classifier_state.apply_fn(params, inputs)
        return cross_entropy_loss(logits=predictions, labels=labels, num_classes=num_classes)

    gradients = jax.grad(batch_loss)(classifier_state.params)
    return classifier_state.apply_gradients(grads=gradients)



def get_online_test_accs(state, test_ds, parallel):
    """Evaluate the model's own logits on the test set.

    ``state.apply_fn`` is expected to return a 4-tuple whose last element is
    a dict of logits. Top-1 and top-5 accuracy are accumulated separately
    for every key of that dict.

    Args:
        state: Model state providing ``apply_fn``, ``params``,
            ``batch_stats`` and ``direct_pred`` (which may be ``None``).
        test_ds: Iterable of ``(x, y)`` batches.
        parallel: If true, the forward pass is wrapped in ``jax.pmap`` and
            outputs and labels are merged back with ``unsplit_batch``.

    Returns:
        Tuple ``(online_lep_accs, online_lep_top5)`` of dicts keyed like the
        logits dict. Each value is the average of the per-batch means.

    Raises:
        ValueError: If ``get_iter`` yields no batches.
    """
    tot_params = {'params': state.params, 'batch_stats': state.batch_stats}
    if state.direct_pred is not None:
        tot_params['direct_pred'] = state.direct_pred

    if parallel:
        apply_fn = jax.pmap(state.apply_fn, static_broadcasted_argnums=2)
    else:
        apply_fn = state.apply_fn

    # Run one forward pass with the plain (non-pmapped) function purely to
    # discover the keys of the logits dict.
    # NOTE(review): the third positional argument (False) is presumably a
    # train/eval flag -- confirm against the model's apply function.
    x, _ = next(iter(test_ds))
    probe_params = unreplicate(tot_params) if parallel else tot_params
    _, _, _, logit_batch = state.apply_fn(probe_params, x, False)

    # Zero-initialised accumulators with the same structure as the logits.
    # jax.tree_util.tree_map is used because the top-level jax.tree_map
    # alias is deprecated.
    online_lep_accs = jax.tree_util.tree_map(lambda _: 0., logit_batch)
    online_lep_top5 = jax.tree_util.tree_map(lambda _: 0., logit_batch)

    test_it = get_iter(test_ds, parallel)

    size = 0
    num_batches = 0

    print('Iterating over the test dataset...')
    for x, y in test_it:
        _, _, _, logit_batch = apply_fn(tot_params, x, False)
        if parallel:
            logit_batch = unsplit_batch(logit_batch)
            y = unsplit_batch(y)
        size += y.shape[0]
        num_batches += 1

        for k, logits in logit_batch.items():
            online_lep_accs[k] += jnp.mean(jnp.argmax(logits, axis=-1) == y)
            online_lep_top5[k] += jnp.mean(jnp.equal(top_k(logits, 5)[1], y[:, None]).any(axis=1))

    if num_batches == 0:
        # Fail with a clear message instead of a NameError/ZeroDivisionError.
        raise ValueError('get_online_test_accs: the test iterator yielded no batches')

    # NOTE: this is a mean of per-batch means, so a smaller final batch
    # carries the same weight as a full one.
    online_lep_accs = jax.tree_util.tree_map(lambda total: total / num_batches, online_lep_accs)
    online_lep_top5 = jax.tree_util.tree_map(lambda total: total / num_batches, online_lep_top5)

    print('Evaluated on {} test samples'.format(size))

    return online_lep_accs, online_lep_top5
