import argparse_utils
import jax_utils as ju
from jax_utils import load_from_cache_or_compute, get_loss, append
import numpy as onp
import jax
import jax.random as random
import jax.numpy as np
import neural_tangents as nt
from data import get_datasets
import os
from pathlib import Path

def get_predictive_variance(kernel_fn, emp_kernel_fns, train_data, dloader, mode='ntk'):
    """Estimate a per-test-point predictive variance term from empirical kernels.

    For each empirical kernel function a correction term is built from the
    train/train and train/test empirical kernels and the corresponding
    ``kernel_fn`` kernels. The terms are averaged over ``emp_kernel_fns`` and
    contracted with a kernel cross term:
    ``trtr_nngp @ inv(trtr_ntk) @ trte_ntk + trte_nngp``.
    The diagonal of the result, scaled by -2, is returned.

    Args:
        kernel_fn: callable ``kernel_fn(x1, x2, mode)`` returning the kernel
            matrix between ``x1`` and ``x2``; called with ``mode`` and 'nngp'.
        emp_kernel_fns: non-empty iterable of callables with the same
            signature, each returning one empirical kernel.
        train_data: training inputs, already in the layout ``kernel_fn`` expects.
        dloader: object whose ``dataset[0]`` is the test-input tensor. It is
            permuted with ``[0, 2, 3, 1]`` (presumably NCHW -> NHWC — confirm)
            and converted to a jax array.
        mode: kernel mode used for the NTK terms (default 'ntk').

    Returns:
        The diagonal of the averaged ``first @ second`` matrix times -2
        (one value per test example, assuming kernels are (len(x1), len(x2))).

    Raises:
        ValueError: if ``emp_kernel_fns`` is empty.
    """
    emp_kernel_fns = list(emp_kernel_fns)
    if not emp_kernel_fns:
        raise ValueError("emp_kernel_fns must contain at least one kernel function")

    test_data = dloader.dataset[0]
    test_data = np.array(test_data.permute([0, 2, 3, 1]).numpy())

    trtr_ntk = kernel_fn(train_data, train_data, mode)
    trte_ntk = kernel_fn(train_data, test_data, mode)

    trtr_nngp = kernel_fn(train_data, train_data, 'nngp')
    trte_nngp = kernel_fn(train_data, test_data, 'nngp')

    inv_ntk = jax.numpy.linalg.inv(trtr_ntk)

    # Cross term shared by every empirical kernel.
    second = trtr_nngp @ inv_ntk @ trte_ntk + trte_nngp

    first = 0
    for emp_kernel_fn in emp_kernel_fns:
        trtr_emp = emp_kernel_fn(train_data, train_data, mode)
        trte_emp = emp_kernel_fn(train_data, test_data, mode)
        first += (trte_emp.T + trte_ntk.T @ inv_ntk @ trtr_emp) @ (inv_ntk @ (trtr_ntk - trtr_emp) @ inv_ntk)

    # Average over the ensemble of empirical kernels. The divisor is the number
    # of kernel functions, not len() of the last function (which raised TypeError).
    return -2 * jax.numpy.diag(first @ second / len(emp_kernel_fns))


def get_ntk_name(specialist, data):
    """Build the name under which a specialist's NTK for split ``data`` is cached.

    Example: ``get_ntk_name(3, "train") == "spec3_train_ntk"``.
    """
    return f"spec{specialist}_{data}_ntk"

def average_kernel(first_specialist, num_specialist, tr_data, te_data, emp_k_fn, mode="ntk", data="foo"):
    """Average the (cached) empirical kernels of a contiguous range of seeds.

    For every seed in ``[first_specialist, first_specialist + num_specialist)``
    the network is initialised with ``init_fn(PRNGKey(seed))``. The kernel
    between ``te_data`` and ``tr_data`` is then loaded from, or computed into,
    the cache via ``load_from_cache_or_compute``, and the arithmetic mean over
    seeds is returned.

    NOTE(review): relies on the module-level globals ``init_fn``, ``dir_name``,
    ``hidden_depth`` and ``hidden_width`` bound by the ``__main__`` block.

    Args:
        first_specialist: first PRNG seed.
        num_specialist: number of consecutive seeds to average (must be >= 1).
        tr_data: training inputs; its shape is used to initialise the network.
        te_data: inputs on the other side of the kernel.
        emp_k_fn: empirical kernel function passed to the cache helper.
        mode: kernel mode forwarded to the cache helper (default "ntk").
        data: split name used in the cache key (see ``get_ntk_name``).

    Returns:
        The mean of the per-seed kernels.

    Raises:
        ValueError: if ``num_specialist`` < 1 (previously this silently
            returned the integer 0).
    """
    if num_specialist < 1:
        raise ValueError("num_specialist must be >= 1, got {}".format(num_specialist))

    cache_dir = dir_name + "/d{}_w{}".format(hidden_depth, hidden_width)
    K = 0
    print("Computing average Kernel")
    for specialist in range(first_specialist, first_specialist + num_specialist):
        print(specialist)
        key = random.PRNGKey(specialist)
        # Initialise from the shape of the data actually passed in rather than
        # the module-level ``train_data`` global.
        _, params = init_fn(key, input_shape=tr_data.shape)
        curr_K = load_from_cache_or_compute(cache_dir, get_ntk_name(specialist, data), emp_k_fn,
                                            te_data, tr_data, mode, params)
        K += curr_K / num_specialist
    return K

def get_ntk_list(specialist, params, train_dloader, ind_dloader, ood_dloaders):
    """Load or compute one specialist's empirical NTK against every data split.

    Splits are processed in the order ``[train, test, ood0, ood1, ...]``. Each
    split's inputs (``dloader.dataset[0]``) are permuted with ``[0, 2, 3, 1]``
    and passed to ``load_from_cache_or_compute`` together with ``train_data``
    in "ntk" mode.

    NOTE(review): relies on the module-level globals ``dir_name``,
    ``hidden_depth``, ``hidden_width``, ``emp_kernel_fn`` and ``train_data``
    bound by the ``__main__`` block.

    Args:
        specialist: seed used in the cache key (see ``get_ntk_name``).
        params: network parameters forwarded to the cache helper.
        train_dloader, ind_dloader: train and in-distribution test loaders.
        ood_dloaders: list of out-of-distribution loaders.

    Returns:
        List of kernels, one per split, in the order above.
    """
    split_names = ["train", "test"] + ["ood{}".format(ii) for ii in range(len(ood_dloaders))]
    dloaders = [train_dloader, ind_dloader] + ood_dloaders
    cache_dir = dir_name + "/d{}_w{}".format(hidden_depth, hidden_width)

    ntk_tetrs = []
    for split_name, dloader in zip(split_names, dloaders):
        test_data = np.array(dloader.dataset[0].permute([0, 2, 3, 1]).numpy())
        ntk_tetrs.append(load_from_cache_or_compute(cache_dir, get_ntk_name(specialist, split_name),
                                                    emp_kernel_fn, test_data, train_data, "ntk", params))
    return ntk_tetrs

if __name__ == '__main__':
    # Experiment driver. For every (depth, width) pair it:
    #   1. builds the network and its batched empirical-kernel function,
    #   2. averages the cached empirical NTKs of the ensemble seeds
    #      (stored as ``inf_ntk_*``) for every data split,
    #   3. for each specialist seed, computes gradient_descent_mse predictions
    #      under several kernel / initial-output combinations,
    #   4. derives per-split ensemble variances and their decomposition,
    #   5. logs losses, accuracies, predictive variances and AUROCs and writes
    #      them to a per-(depth, width) result file.
    #
    # NOTE: average_kernel / get_ntk_list read dir_name, hidden_depth,
    # hidden_width, init_fn, emp_kernel_fn and train_data as module globals,
    # so those names must stay bound at module level here (not inside main()).
    parser = ju.get_argparse()
    arg = parser.parse_args()
    run_id = parser.print_arg(arg)
    dir_name = arg.out_dir + "/" + run_id
    print(dir_name)

    dataset = arg.dataset
    num_train_data = arg.num_train_data
    num_ensemble = arg.num_ensemble

    # NOTE(review): ``bias`` maps negative values to None, but the network
    # constructors below receive ``arg.bias`` unchanged; confirm which form
    # ju.get_mlp / ju.get_miniminiconv / ju.get_wrn expect.
    bias = arg.bias if arg.bias >= 0 else None
    activation = arg.activation

    hidden_depths = [int(d) for d in arg.hidden_depths.split(",")]
    hidden_widths = [int(w) for w in arg.hidden_widths.split(",")]

    # A negative --num_test_data means "same as num_train_data".
    num_test_data = num_train_data if arg.num_test_data < 0 else arg.num_test_data
    train_dloader, ind_dloader, ood_dloaders, in_shape, out_shape = get_datasets(
        dataset, num_train_data, num_test_data=num_test_data, binary=arg.binary)

    train_data, train_labels = train_dloader.dataset
    train_data = np.array(train_data.permute([0, 2, 3, 1]).numpy())
    train_labels = np.array(train_labels.numpy())

    test_data, test_labels = ind_dloader.dataset
    test_data = np.array(test_data.permute([0, 2, 3, 1]).numpy())
    test_labels = np.array(test_labels.numpy())

    # Fixed split order used to index every per-split list below:
    # 0 = train, 1 = in-distribution test, 2.. = OOD sets.
    all_dloaders = [train_dloader, ind_dloader] + ood_dloaders
    split_names = ["train", "test"] + ["ood{}".format(ii) for ii in range(len(ood_dloaders))]

    mode = 'ntk'
    result = dict()
    for hidden_depth in hidden_depths:
        print("Depth", hidden_depth)
        result[hidden_depth] = dict()
        for hidden_width in hidden_widths:
            print("Width", hidden_width)
            print("Training")

            result[hidden_depth][hidden_width] = dict()
            curr_result = result[hidden_depth][hidden_width]

            if arg.net == 'mlp':
                init_fn, apply_fn, kernel_fn = ju.get_mlp(out_shape, hidden_width, hidden_depth, bias=arg.bias, activation=arg.activation)
            elif arg.net == 'conv':
                init_fn, apply_fn, kernel_fn = ju.get_miniminiconv(out_shape, hidden_width, hidden_depth, bias=arg.bias, activation=arg.activation)
            elif arg.net == 'wrn':
                if not (hidden_width == 4 and hidden_depth == 4):
                    raise ValueError("net 'wrn' requires hidden_width == 4 and hidden_depth == 4")
                init_fn, apply_fn, kernel_fn = ju.get_wrn(out_shape, hidden_width, hidden_depth, bias=arg.bias, activation=arg.activation, data=dataset)
            else:
                raise ValueError("Unknown net {!r}; expected 'mlp', 'conv' or 'wrn'".format(arg.net))

            emp_kernel_fn = nt.empirical_kernel_fn(apply_fn)
            emp_kernel_fn = nt.batch(emp_kernel_fn, batch_size=arg.batch if arg.batch > 0 else arg.num_train_data)

            key0 = random.PRNGKey(0)
            _, params0 = init_fn(key0, input_shape=train_data.shape)

            # Ensemble-averaged empirical NTK (``inf_ntk_*``) and seed-0 NTK for every split.
            # NOTE(review): emp0_ntk_tetrs is not read below; its only visible
            # effect is populating the on-disk cache.
            inf_ntk_tetrs = []
            emp0_ntk_tetrs = []
            cache_dir = dir_name + "/d{}_w{}".format(hidden_depth, hidden_width)
            for split_name, dloader in zip(split_names, all_dloaders):
                test_data = np.array(dloader.dataset[0].permute([0, 2, 3, 1]).numpy())
                inf_ntk_tetrs.append(average_kernel(arg.first_specialist, arg.num_ensemble, train_data, test_data,
                                                    emp_kernel_fn, data=split_name))
                emp0_ntk_tetrs.append(load_from_cache_or_compute(cache_dir, get_ntk_name(0, split_name),
                                                                 emp_kernel_fn, test_data, train_data, mode, params0))

            inf_ntk_trtr = inf_ntk_tetrs[0]
            inv_ntk = jax.numpy.linalg.inv(inf_ntk_trtr)
            eye_train = np.eye(inf_ntk_trtr.shape[0])

            # A gradient_descent_mse predictor depends only on (k_train_train, y_train).
            # Build each one once instead of once per (specialist, split, key).
            inf_predictor = nt.predict.gradient_descent_mse(inf_ntk_trtr, train_labels)
            inf_zero_predictor = nt.predict.gradient_descent_mse(inf_ntk_trtr, 0 * train_labels)

            preds = dict()
            preds["Y"] = [[] for _ in all_dloaders]

            init_var = 0
            for specialist in range(arg.first_specialist, arg.first_specialist + arg.num_ensemble):
                print(specialist)

                spec_key = random.PRNGKey(specialist)
                _, params = init_fn(spec_key, input_shape=train_data.shape)
                emp_ntk_tetrs = get_ntk_list(specialist, params, train_dloader, ind_dloader, ood_dloaders)

                # Second, independently seeded network (seed offset 1000).
                key_rand = random.PRNGKey(specialist + 1000)
                _, params_rand = init_fn(key_rand, input_shape=train_data.shape)
                emp_rand_ntk_tetrs = get_ntk_list(specialist + 1000, params_rand, train_dloader, ind_dloader, ood_dloaders)

                emp_ntk_trtr = emp_ntk_tetrs[0]
                emp_rand_ntk_trtr = emp_rand_ntk_tetrs[0]

                # Mean relative norm distance between the averaged and the specialist train-train NTK.
                init_var += np.linalg.norm(inf_ntk_trtr - emp_ntk_trtr) / np.linalg.norm(inf_ntk_trtr) / arg.num_ensemble

                emp_predictor = nt.predict.gradient_descent_mse(emp_ntk_trtr, train_labels)
                emp_rand_predictor = nt.predict.gradient_descent_mse(emp_rand_ntk_trtr, train_labels)

                # These do not depend on the split, so compute them once per specialist.
                y_tr0 = apply_fn(params, train_data)
                _, y_tr0_0_relax = inf_zero_predictor(None, y_tr0, 0 * y_tr0, emp_rand_ntk_trtr)

                for i, dloader in enumerate(all_dloaders):
                    test_data = np.array(dloader.dataset[0].permute([0, 2, 3, 1]).numpy())
                    emp_ntk_tetr = emp_ntk_tetrs[i]
                    emp_rand_ntk_tetr = emp_rand_ntk_tetrs[i]

                    y_te0 = apply_fn(params, test_data)

                    preds["Y"][i].append((emp_rand_ntk_tetr - inf_ntk_tetrs[i] @ inv_ntk @ emp_rand_ntk_trtr) @ (inv_ntk @ (emp_rand_ntk_trtr - inf_ntk_trtr) - eye_train) @ inv_ntk @ (y_tr0 - train_labels))

                    _, y_te0_0_relax = inf_zero_predictor(None, y_tr0, 0 * y_te0, emp_rand_ntk_tetr)

                    # NOTE: the disabled "sub_approx_decorr" variant would need
                    # inf_predictor(None, y_tr0, 0*y_{tr,te}0, emp_rand_ntk_{trtr,tetr});
                    # those values used to be computed here but were never consumed.
                    # Each entry: (key, predictor, initial train outputs, initial test outputs, k_test_train).
                    for (key, predictor, y_tr, y_te, ntk_tetr) in [
                            ("f_lin_a", inf_predictor, y_tr0, y_te0, inf_ntk_tetrs[i]),
                            ("emp_corr", emp_predictor, y_tr0, y_te0, emp_ntk_tetr),
                            ("f_lin", emp_rand_predictor, y_tr0, y_te0, emp_rand_ntk_tetr),
                            ("f_lin_i", inf_predictor, y_tr0_0_relax, y_te0_0_relax, inf_ntk_tetrs[i]),
                            ("f_lin_c", emp_predictor, np.zeros(y_tr0.shape), np.zeros(y_te0.shape), emp_ntk_tetr)]:
                        if key not in preds:
                            preds[key] = [[] for _ in all_dloaders]
                        # t=None: predictions at the end of training.
                        _, y_test_t = predictor(None, y_tr, y_te, ntk_tetr)
                        preds[key][i].append(y_test_t)

            # Stack each split's per-specialist predictions along a new axis 1 (the ensemble axis).
            preds = {key: [np.stack(per_spec, axis=1) for per_spec in per_split]
                     for key, per_split in preds.items()}

            preds["sub_gt_decorr"] = [vz - ved for (ved, vz) in zip(preds["f_lin_a"], preds["f_lin"])]

            def get_vars(preds):
                """Per-split variances across the ensemble axis plus derived decomposition terms."""
                out = dict()
                for key in preds:
                    # Variance over axis 1 (ensemble), then summed over axis 1 of the result.
                    out[key] = [np.sum(np.var(split_preds, axis=1), axis=1) for split_preds in preds[key]]

                out["f_lin_cov_gt"] = [2 * ju.get_covariance(i, y) for (i, y) in zip(preds["f_lin_a"], preds["sub_gt_decorr"])]
                out["f_lin_cov"] = [2 * ju.get_covariance(i, y) for (i, y) in zip(preds["f_lin_a"], preds["Y"])]
                out["noise_covarr"] = [vz - ved for (ved, vz) in zip(out["f_lin"], out["emp_corr"])]

                # residual = var(f_lin) - var(f_lin_a) - var(f_lin_i) - var(f_lin_c) - f_lin_cov
                out["residual"] = [
                    lin - lin_a - lin_i - lin_c - cov
                    for (lin, lin_a, lin_i, lin_c, cov)
                    in zip(out["f_lin"], out["f_lin_a"], out["f_lin_i"], out["f_lin_c"], out["f_lin_cov"])
                ]
                return out

            variances = get_vars(preds)

            # Log NTK Norm change
            append(curr_result, "ntk_init_norm", init_var)

            loss_fn = get_loss(arg.loss)
            for key in preds:
                for idx, dloader in enumerate([train_dloader, ind_dloader]):
                    split_preds = preds[key][idx]
                    ens_pred = np.mean(split_preds, axis=1)
                    split_labels = np.array(dloader.dataset[1].numpy())
                    append(curr_result, 'loss_ens_' + key, loss_fn(ens_pred, split_labels))
                    append(curr_result, 'loss_spec_' + key, onp.mean([loss_fn(split_preds[:, i], split_labels) for i in range(arg.num_ensemble)]))
                    append(curr_result, 'acc_ens_' + key, ju.acc_fn(ens_pred, split_labels))
                    append(curr_result, 'acc_spec_' + key, onp.mean([ju.acc_fn(split_preds[:, i], split_labels) for i in range(arg.num_ensemble)]))

            # Predictive variance per split, and AUROC of in-distribution test (index 1) vs each OOD split.
            for key in variances:
                for v in variances[key]:
                    append(curr_result, "predvar_" + key, np.mean(v))
                for v in variances[key][2:]:
                    auroc = ju.compute_auroc(onp.array(variances[key][1]), onp.array(v))
                    append(curr_result, "auroc_" + key, auroc)

            Path(dir_name).mkdir(parents=True, exist_ok=True)
            out_path = os.path.join(dir_name, "ntk_d{}_w{}_ens{}-{}_result.txt".format(hidden_depth, hidden_width, arg.first_specialist, arg.num_ensemble))
            with open(out_path, 'w', encoding="utf-8") as fh:
                for k, values in curr_result.items():
                    line = "{}\t{}\t{}\t{}".format(hidden_depth, hidden_width, k, "\t".join(str(val) for val in values))
                    fh.write(line + "\n")

    for k1, by_width in result.items():
        for k2, metrics in by_width.items():
            for k3, values in metrics.items():
                print("{}\t{}\t{}\t{}".format(k1, k2, k3, "\t".join(str(val) for val in values)))

