import argparse

import jax
import jax.numpy as np

from jax import grad, jit, vmap, pmap, value_and_grad
from jax import random

from jax.tree_util import tree_multimap, tree_map
from utils import optimizers
from utils import adaptation_utils
from utils.regularizers import weighted_parameter_loss
import haiku as hk

import numpy as onp

import tensorflow_datasets as tfds
import tensorflow as tf

# hides GPUs, maybe causing issues with memory with Jax and TF
tf.config.set_visible_devices([], 'GPU')

from jax.config import config

import os
import requests

import pickle
import time

from models.util import get_model

from utils.training_utils import train_epoch, train_epoch_online
from utils.eval import eval_ds_all

from utils.losses import nll, accuracy, entropy, brier, ece
from utils.misc import get_single_copy, manual_pmap_tree

from posteriors.utils import sample_weights_diag
from posteriors.swag import init_swag, update_swag, collect_posterior

import resource

# Read the process's (soft, hard) limits on open file descriptors. The values
# are kept as module globals but are not consulted anywhere else in this script.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

parser = argparse.ArgumentParser(description='Runs basic train loop on a supervised learning task')


def _add_valued_options(specs):
    """Register optional ``--flag VALUE`` arguments on the module-level ``parser``.

    Each spec is ``(flag, type, default, help)``. A ``help`` of ``None`` is
    argparse's own default, i.e. the option is listed without help text.
    """
    for flag, value_type, default, help_text in specs:
        parser.add_argument(flag, type=value_type, default=default, required=False, help=help_text)


def _add_switches(flags):
    """Register boolean ``store_true`` switches; each dest is the flag minus its leading dashes."""
    for flag in flags:
        parser.add_argument(flag, dest=flag[2:], action='store_true')


# Registration order is preserved so that --help lists options in the same order.
_add_valued_options((
    ("--dir", str, None, "Training directory for logging results"),
    ("--log_prefix", str, None, "Name prefix for logging results"),
    ("--data_dir", str, 'datasets', "Directory for storing datasets"),
    ("--seed", int, 0, None),
    ("--wd", float, 0., None),
    ("--model", str, "ResNet50", "Model class"),
    ("--corruption_type", str, "brightness", None),
    ("--corruption_level", int, 1, None),
    ("--n_epochs", int, 1, None),
    ("--batch_size", int, 64, None),
    ("--lr", float, 0.00025, None),
))
_add_switches(("--adapt_bn_only", "--use_swag_posterior", "--use_data_augmentation"))
_add_valued_options((
    ("--swag_posterior_weight", float, 1e-3, None),
    ("--swag_posterior_damp", float, 1e-4, None),
))

args = parser.parse_args()

### ImageNet per-channel RGB normalization statistics.
# Applied in preprocess_inputs after pixels are scaled to [0, 1] by dividing by 255.
channel_means = (0.485, 0.456, 0.406)
channel_stds = (0.229, 0.224, 0.225)

# Width of the one-hot label encoding produced by preprocess_inputs.
n_classes = 1000

# Total number of JAX devices visible to the process.
# NOTE(review): on multi-host TPU setups jax.local_device_count() may be the
# intended value — confirm.
n_devices = jax.device_count()

batch_size = args.batch_size
def preprocess_inputs(datapoint):
    """Map a raw TFDS example dict to a normalized ``(image, one_hot_label)`` pair.

    Pixel values are divided by 255 and then standardized per channel using
    ``channel_means`` / ``channel_stds``; the integer ``label`` is one-hot
    encoded into ``n_classes`` columns with ``tf.one_hot``.
    """
    scaled = datapoint['image'] / 255
    standardized = (scaled - channel_means) / channel_stds
    one_hot = tf.one_hot(datapoint['label'], n_classes)
    return standardized, one_hot

def augment_train_data(image, label):
    """Augmentation hook for the input pipeline; currently a pass-through.

    Returns ``(image, label)`` unchanged.

    NOTE(review): ``--use_data_augmentation`` is parsed but not consulted here,
    and this function is mapped over the evaluation pipeline as well as the
    training one — any real augmentation added here must be limited to
    training data.
    """
    passthrough = (image, label)
    return passthrough

model = get_model(args.model, n_classes)
### Jitted forward-pass wrappers around model.apply; they differ only in the RNG and the trailing flag.
@jit
def net_apply(params, state, rng, x):
    """Jitted ``model.apply(params, state, rng, x, True)``: forward pass with an explicit RNG key.

    The trailing positional flag is True (training mode, per the original note).
    NOTE(review): the flag is assumed to be ``is_training`` — confirm against the
    model returned by ``get_model``. Returns whatever ``model.apply`` returns.
    """
    return model.apply(params, state, rng, x, True)

@jit
def net_apply_eval(params, state, x):
    """Jitted ``model.apply`` with ``rng=None`` and the trailing flag False (evaluation mode).

    NOTE(review): presumably batch norm uses its stored statistics here — confirm.
    """
    return model.apply(params, state, None, x, False)

@jit
def net_apply_eval_bn(params, state, x):
    """Jitted ``model.apply`` with ``rng=None`` but the trailing flag True.

    Used for evaluation where batch norm presumably normalizes with the current
    batch's statistics (training-mode BN without an RNG key) — confirm.
    """
    return model.apply(params, state, None, x, True)

# Load the SWAG posterior statistics and the SWA network state for this seed.
# NOTE: pickle.load can execute arbitrary code; these paths are expected to be
# trusted, locally produced checkpoints — never point them at untrusted files.
swag_filename = 'imagenet_models/swag_models/seed{}/saved_swag_state.pkl'.format(args.seed)
with open(swag_filename, 'rb') as f:
    swag_state = pickle.load(f)

swag_state_filename = 'imagenet_models/swag_models/seed{}/saved_swa_net_state.pkl'.format(args.seed)
with open(swag_state_filename, 'rb') as f:
    single_state = pickle.load(f)

# Reduce the raw SWAG state to (means, vars) trees (presumably per-parameter
# first and second moments — confirm in posteriors.swag), then drop the
# reference to the raw state so it can be garbage collected.
swag_means, swag_vars = collect_posterior(swag_state)
del swag_state

# Adaptation starts from the SWAG mean weights.
single_params = swag_means

# NOTE(review): assumes the pickled parameters/state match the architecture
# built by get_model(args.model, ...) — verify this is the correct architecture.
init_params, init_state = single_params, single_state
net_state = init_state



### hyperparameters
# Number of passes over the corrupted adaptation stream.
num_epochs = args.n_epochs


def step_size_schedule(i):
    """Step-size schedule handed to the momentum optimizer: a constant ``args.lr``.

    ``i`` (the step index) is part of the schedule interface but is ignored.
    """
    del i  # constant schedule, the step index is not needed
    return args.lr

if args.adapt_bn_only:
    # Trainable subset: parameters of modules whose name contains 'batchnorm'.
    # Everything else is frozen in `other_params`.
    bn_params, other_params = hk.data_structures.partition(
        lambda module_name, name, value: 'batchnorm' in module_name, init_params)
    orig_net_apply = net_apply
    orig_net_apply_eval = net_apply_eval
    orig_net_apply_eval_bn = net_apply_eval_bn

    # Each wrapper accepts only the BN subset, re-merges the frozen params and
    # calls the original jitted forward pass. Because net_params becomes the
    # BN subset below, the optimizer state only covers BN parameters.
    def bn_only_net_apply(bn_p, state, rng, x):
        return orig_net_apply(hk.data_structures.merge(bn_p, other_params), state, rng, x)
    net_apply = jit(bn_only_net_apply)

    def bn_only_net_apply_eval(bn_p, state, x):
        return orig_net_apply_eval(hk.data_structures.merge(bn_p, other_params), state, x)
    net_apply_eval = jit(bn_only_net_apply_eval)

    def bn_only_net_apply_eval_bn(bn_p, state, x):
        return orig_net_apply_eval_bn(hk.data_structures.merge(bn_p, other_params), state, x)
    net_apply_eval_bn = jit(bn_only_net_apply_eval_bn)

    net_params = bn_params
    print("Working with adapt bn only", flush=True)
else:
    net_params = init_params

if args.use_swag_posterior:
    print("Using swag posterior")

    def regularizer(params):
        """Return ``weighted_parameter_loss`` of ``params`` vs. the SWAG means/vars, scaled by the flag weight."""
        return args.swag_posterior_weight * weighted_parameter_loss(params, swag_means, swag_vars, args.swag_posterior_damp)

    # Called once on the full parameter tree with the result discarded —
    # presumably a smoke test that fails fast before the datasets are built.
    # NOTE(review): with --adapt_bn_only the optimizer holds only the BN subset;
    # confirm weighted_parameter_loss accepts that subset against the full-tree
    # swag_means/swag_vars.
    regularizer(single_params)
else:
    regularizer = None

opt_init, opt_update, get_params = optimizers.momentum(step_size=step_size_schedule, mass=0.9, wd=args.wd)
opt_state = opt_init(net_params)

corruption_str = '{}_{}'.format(args.corruption_type, args.corruption_level)

# The corrupted ImageNet validation split is both the online adaptation stream
# and the evaluation set. It is loaded once and fed into two pipelines.
ds_corrupted = tfds.load('imagenet2012_corrupted/{}'.format(corruption_str), split='validation', data_dir=args.data_dir)

# NOTE(review): the shuffle seed is fixed at 0 instead of args.seed — confirm intended.
ds_train = (ds_corrupted
            .shuffle(50000, seed=0, reshuffle_each_iteration=True)
            .map(preprocess_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .map(augment_train_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.experimental.AUTOTUNE))

# Evaluation pipeline: unshuffled, fixed batch of 128. drop_remainder=True
# discards a trailing partial batch, so up to 127 examples are never evaluated.
ds_test = (ds_corrupted
           .map(preprocess_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .map(augment_train_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(128, drop_remainder=True)
           .prefetch(tf.data.experimental.AUTOTUNE))

# Threading options for the training pipeline. They only take effect once
# attached with with_options(). Mutating the object returned by
# Dataset.options() alone leaves ds_train unchanged.
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = 48
options.experimental_threading.max_intra_op_parallelism = 1
ds_train = ds_train.with_options(options)

def _evaluate(apply_fn, eval_params, eval_net_state, with_logits):
    """Run ``eval_ds_all`` over ``ds_test`` with ``apply_fn`` and the standard metric set.

    Args:
        apply_fn: forward function ``(params, state, x)`` to evaluate with.
        eval_params: parameters passed through to ``eval_ds_all``.
        eval_net_state: network state passed through to ``eval_ds_all``.
        with_logits: forwarded to ``eval_ds_all``; when true its result is
            unpacked as ``(test_results, logits)``.

    Returns:
        ``(test_results, logits)`` when ``with_logits`` is true, otherwise
        ``(test_results, test_results)``. Callers can always unpack two values.
    """
    outputs = eval_ds_all(tfds.as_numpy(ds_test),
                          eval_params,
                          eval_net_state,
                          apply_fn,
                          (nll, entropy, accuracy, brier, ece),
                          with_logits)
    if with_logits:
        test_results, logits = outputs
        return test_results, logits
    return outputs, outputs


# NOTE: shadows the builtin `eval`; the name is kept for compatibility.
def eval(eval_params, eval_net_state, with_logits=False):
    """Evaluate on ``ds_test`` with ``net_apply_eval`` (trailing model flag False)."""
    # net_apply_eval is looked up at call time, so the BN-only wrapper is used when active.
    return _evaluate(net_apply_eval, eval_params, eval_net_state, with_logits)


def eval_bn(eval_params, eval_net_state, with_logits=False):
    """Evaluate on ``ds_test`` with ``net_apply_eval_bn`` (trailing model flag True, no RNG)."""
    return _evaluate(net_apply_eval_bn, eval_params, eval_net_state, with_logits)

# Starting parameters/state snapshot (the eval helpers take these as arguments).
eval_params, eval_net_state = get_params(opt_state), net_state

bn_only_str = 'adaptbnonly_' if args.adapt_bn_only else ''

filename = 'logs/entropy_minimization_imagenet_online/{}/posteriorweight{}_posteriordamp{}_{}lr{}_batchsize{}/seed{}_{}.pkl'.format(args.model, args.swag_posterior_weight, args.swag_posterior_damp, bn_only_str, args.lr, args.batch_size, args.seed, corruption_str)
os.makedirs(os.path.dirname(filename), exist_ok=True)
print(filename, flush=True)

# Best-effort probe for a results file left by an earlier run. The loaded
# contents are discarded and the run proceeds either way. `except Exception`
# (rather than a bare except) still lets Ctrl-C / SystemExit through.
try:
    with open(filename, 'rb') as f:
        pickle.load(f)
except Exception:  # missing or unreadable file: nothing to report
    print(filename, 'file not found')
else:
    print(filename, 'file loaded')

log_dict = {}

# NOTE(review): the same `rng` key is passed to every epoch; confirm that
# train_epoch_online splits it internally if fresh randomness is expected.
rng = random.PRNGKey(args.seed)

# The final-logits file lives next to the results pickle, sharing its name stem.
final_logits_filename = os.path.splitext(filename)[0] + '_final_logits.npy'

for epoch in range(num_epochs):
    start = time.time()
    # Fresh numpy iterator over the (reshuffled) training pipeline each epoch.
    np_ds = tfds.as_numpy(ds_train)
    opt_state, net_state, train_loss, all_logits, all_labels = train_epoch_online(epoch,
                                                   opt_state,
                                                   net_state,
                                                   rng,
                                                   np_ds,
                                                   entropy,
                                                   get_params,
                                                   net_apply,
                                                   opt_update,
                                                   regularizer=regularizer,
                                                   distributed=False)
    print('Epoch {}: {} {}'.format(epoch, train_loss, time.time() - start), flush=True)

    # Metrics over the predictions made online during the adaptation pass.
    all_logits = np.concatenate(all_logits)
    all_labels = np.concatenate(all_labels)
    online_results = [metric(all_logits, all_labels) for metric in (nll, entropy, accuracy, brier, ece)]
    # Both keys hold the same online results (the stream is the test split).
    log_dict['Online_{} Test'.format(epoch)] = online_results
    log_dict['Online_{} Train'.format(epoch)] = online_results
    print("Online {}".format(epoch), online_results, time.time() - start)
    # Overwritten every epoch, so the file ends up holding the last epoch's logits.
    onp.save(final_logits_filename, all_logits)

print(corruption_str)

with open(filename, 'wb') as f:
    pickle.dump(log_dict, f)

