import jax
import jax.numpy as np

from jax import jit, grad, value_and_grad, vmap, pmap
from jax import random

import haiku as hk

from jax.tree_util import tree_multimap, tree_map
from jax.experimental import optimizers

from regularizers_copy import input_jacobian_regularized_net_loss, gradient_penalty_regularized_net_loss

import numpy as onp
import operator
import time

def apply_grads(params, grads):
    """Subtract every gradient leaf from the matching parameter leaf.

    `params` and `grads` are walked in lockstep by `tree_multimap`, and each
    pair of leaves `(p, g)` is replaced by `p - g`. No step size is applied
    here, so any scaling has to be folded into `grads` by the caller.
    """
    return tree_multimap(operator.sub, params, grads)

def net_loss(params, state, rng, inputs, labels, loss, net_apply, regularizer=None):
    """Evaluate the (optionally regularized) loss of a network on one batch.

    Runs `net_apply(params, state, rng, inputs)`, which must return the pair
    `(outputs, new_state)`, and scores the outputs with `loss(outputs, labels)`.
    When a truthy `regularizer` is supplied, `regularizer(params)` is added to
    the loss; the regularizer is assumed to depend on the parameters only.

    Returns:
        `(loss_value, new_state)`, so the function can be differentiated with
        `value_and_grad(..., has_aux=True)`.
    """
    outputs, new_state = net_apply(params, state, rng, inputs)
    total = loss(outputs, labels)
    if regularizer:
        # Penalty term is a function of the parameters alone.
        total = total + regularizer(params)
    return total, new_state

"""
Training utility functions for supervised learning
"""
def step_grad_optimizer(i, opt_state, net_state, rng, x, y, loss, get_params, net_apply, regularizer, opt_update):
    """Perform one jit-compiled optimizer step on a single batch.

    Args:
        i: step index, forwarded unchanged to `opt_update`.
        opt_state: optimizer state; `get_params(opt_state)` yields the
            parameters that are differentiated.
        net_state: network state forwarded to `net_apply`.
        rng: random key forwarded to `net_apply`.
        x, y: inputs and labels of the batch.
        loss: callable `loss(outputs, labels)`.
        get_params: callable extracting parameters from `opt_state`.
        net_apply: callable `(params, state, rng, inputs) -> (outputs, new_state)`.
        regularizer: optional callable of the parameters, or None.
        opt_update: callable `opt_update(i, grads, opt_state)` returning the
            updated optimizer state.

    The arguments from `loss` onwards (positions 6-10) are static under `jit`,
    so they must be hashable; per the `jit` documentation, passing a different
    object there triggers a re-trace.

    Returns:
        `(new_opt_state, new_net_state, loss_value)`.
    """
    params = get_params(opt_state)
    loss_and_grad = value_and_grad(net_loss, has_aux=True)
    (l, new_net_state), g = loss_and_grad(params, net_state, rng, x, y, loss, net_apply, regularizer)
    return opt_update(i, g, opt_state), new_net_state, l


# Wrap explicitly rather than decorating with `jax.partial(jit, ...)`:
# `jax.partial` was an accidental re-export of `functools.partial` that current
# JAX releases no longer provide, while calling `jit` directly is equivalent.
step_grad_optimizer = jit(step_grad_optimizer, static_argnums=(6, 7, 8, 9, 10))

def step_distributed_grad_optimizer(i, opt_state, net_state, rng, x, y, loss, get_params, net_apply, regularizer, opt_update):
    """Perform one data-parallel optimizer step under `pmap`.

    `i` is broadcast to every replica (`in_axes` entry None); `opt_state`,
    `net_state`, `rng`, `x` and `y` are mapped over their leading axis. The
    arguments from `loss` onwards (positions 6-10) are static and broadcast.

    Each replica computes the loss and gradients on its own shard; the
    gradients are then averaged across the mapped axis `'i'` before the
    optimizer update. The loss and the new network state are NOT averaged and
    come back with one entry per replica.

    Returns:
        `(new_opt_state, new_net_state, loss_value)`.
    """
    params = get_params(opt_state)
    loss_and_grad = value_and_grad(net_loss, has_aux=True)
    (l, new_net_state), g = loss_and_grad(params, net_state, rng, x, y, loss, net_apply, regularizer)

    # Average gradients over replicas so every replica applies the same update.
    g = jax.lax.pmean(g, axis_name='i')

    return opt_update(i, g, opt_state), new_net_state, l


# Wrap explicitly rather than decorating with `jax.partial(pmap, ...)`:
# `jax.partial` was an accidental re-export of `functools.partial` that current
# JAX releases no longer provide, while calling `pmap` directly is equivalent.
step_distributed_grad_optimizer = pmap(
    step_distributed_grad_optimizer,
    axis_name='i',
    in_axes=(None, 0, 0, 0, 0, 0),
    static_broadcasted_argnums=(6, 7, 8, 9, 10),
)

def train_epoch(i, opt_state, net_state, rng, np_ds, loss, get_params, net_apply, opt_update, regularizer=None, distributed=True):
    """Run one optimizer step for every batch yielded by `np_ds`.

    Args:
        i: step index passed to the optimizer update; the same value is used
            for every batch of this call.
        opt_state: optimizer state, threaded through the batches.
        net_state: network state, threaded through the batches.
        rng: random key; split once per batch.
        np_ds: iterable of `(x, y)` batches, converted with `np.array`.
        loss, get_params, net_apply, opt_update, regularizer: forwarded to the
            step function.
        distributed: when True, `step_distributed_grad_optimizer` is used and
            the leading axis of `x` is taken as the number of devices (one
            random key is generated per entry); otherwise `step_grad_optimizer`
            is used.

    Returns:
        `(opt_state, net_state, train_loss)` where `train_loss` is the mean of
        the loss returned for the LAST batch only (not an epoch average).

    Raises:
        ValueError: if `np_ds` yields no batches.
    """
    train_loss = None
    for x, y in np_ds:
        rng, cur_rng = random.split(rng, 2)
        x, y = np.array(x), np.array(y)
        if distributed:
            # One key per slice of the leading (device) axis.
            n_devices = x.shape[0]
            cur_rng = random.split(cur_rng, n_devices)
            step = step_distributed_grad_optimizer
        else:
            step = step_grad_optimizer
        opt_state, net_state, train_loss = step(
            i, opt_state, net_state, cur_rng, x, y,
            loss, get_params, net_apply, regularizer, opt_update,
        )

    if train_loss is None:
        # Without this guard an empty dataset surfaces as an UnboundLocalError.
        raise ValueError("train_epoch received an empty dataset: np_ds yielded no batches")

    return opt_state, net_state, train_loss.mean()

def train_epoch_online(i, opt_state, net_state, rng, np_ds, loss, get_params, net_apply, opt_update, regularizer=None, distributed=True):
    """Train for one pass over `np_ds`, recording pre-update predictions.

    For every batch the network is first evaluated with the current parameters
    and state (before the optimizer step) and the outputs and labels are
    stored; then one optimizer step is taken on the same batch with the same
    random key(s).

    Args:
        i: step index passed to the optimizer update; the same value is used
            for every batch of this call.
        opt_state: optimizer state, threaded through the batches.
        net_state: network state, threaded through the batches.
        rng: random key; split once per batch.
        np_ds: iterable of `(x, y)` batches, converted with `np.array`.
        loss, get_params, net_apply, opt_update, regularizer: forwarded to the
            step function.
        distributed: when True, the forward pass runs under `pmap`, the
            leading axis of `x` is taken as the number of devices, and only
            the first slice along that axis (`outputs[0]`, `y[0]`) is
            recorded; otherwise the whole batch is recorded.

    Returns:
        `(opt_state, net_state, train_loss, all_logits, all_labels)` where
        `train_loss` is the mean of the loss returned for the LAST batch only
        and the two lists hold one entry per batch.

    Raises:
        ValueError: if `np_ds` yields no batches.
    """
    # Build the pmapped forward pass once instead of on every batch.
    sharded_apply = pmap(net_apply) if distributed else None

    all_logits = []
    all_labels = []
    train_loss = None
    for x, y in np_ds:
        rng, cur_rng = random.split(rng, 2)
        x, y = np.array(x), np.array(y)
        params = get_params(opt_state)
        if distributed:
            # One key per slice of the leading (device) axis.
            n_devices = x.shape[0]
            cur_rng = random.split(cur_rng, n_devices)
            outputs = sharded_apply(params, net_state, cur_rng, x)[0]
            # Keep only the first device's shard of predictions and labels.
            all_logits.append(outputs[0])
            all_labels.append(y[0])
            step = step_distributed_grad_optimizer
        else:
            all_logits.append(net_apply(params, net_state, cur_rng, x)[0])
            all_labels.append(y)
            step = step_grad_optimizer
        opt_state, net_state, train_loss = step(
            i, opt_state, net_state, cur_rng, x, y,
            loss, get_params, net_apply, regularizer, opt_update,
        )

    if train_loss is None:
        # Without this guard an empty dataset surfaces as an UnboundLocalError.
        raise ValueError("train_epoch_online received an empty dataset: np_ds yielded no batches")

    return opt_state, net_state, train_loss.mean(), all_logits, all_labels
