"""Code for training interaction model."""
import glob
import json
import logging
import os

import google
import h5py
import numpy as np
import tqdm
from collections import defaultdict

import src.feat.sequence as seq
import src.learning.interaction.pair_to_tfrecord as ptt
import src.learning.interaction.model as im
import src.learning.interaction.model_params as ip
import src.learning.interaction.model_runner as run
import src.learning.interaction.seqmodel_params as qp
import src.learning.interaction.train_params as tp
import src.learning.subgrid_generation as sg


def _run_and_log(mr, epoch_idx, writer, training, num_batches):
    """Pull `num_batches` batches from a model runner and log the results.

    Every batch's summary is forwarded to `writer`, a running mean of the
    loss is shown on the progress bar, and once all batches are done the
    epoch-level loss/accuracy summaries and timing statistics are emitted.

    Args:
        mr: model runner; each `mr.next()` call must return a batch-info
            object exposing `check_sum`, `elapsed_generator`,
            `elapsed_learning`, `run_metadata`, `summary`, `loss` and `acc`.
        epoch_idx: index of the current epoch.  Used to derive the global
            step and as the step of the epoch-level summaries.
        writer: summary writer receiving per-batch and per-epoch summaries.
        training: True for a training pass, False for validation.  Selects
            the progress-bar label and whether run metadata is recorded.
        num_batches: number of batches to run.

    Returns:
        Tuple `(epoch_loss, epoch_acc, losses, accs, generator_times,
        learning_times)`: the mean loss and accuracy over the batches,
        followed by the per-batch lists.
    """
    label = 'Loss: {:6.4f}' if training else 'Val Loss: {:6.4f}'

    check_sum = 0
    mean_loss = 0
    mean_acc = 0
    losses = []
    accs = []
    generator_times = []
    learning_times = []

    progress = tqdm.trange(num_batches, desc=label.format(0))
    for i in progress:
        bi = mr.next()
        check_sum += bi.check_sum
        generator_times.append(bi.elapsed_generator)

        # Global step across epochs (in place of tf.train.get_global_step()).
        step = epoch_idx * num_batches + i

        if training and bi.run_metadata is not None:
            writer.add_run_metadata(bi.run_metadata, 'step{}'.format(step))

            # Disabled debugging aid: dumps a Chrome trace and the raw run
            # metadata for this step.
            if False:
                # This is weird, can't seem to use tf.python.client
                # directly, suspect it is some sort of bug in TF or
                # because it is an experimental feature.
                if not os.path.exists(mr.model.model_name):
                    os.mkdir(mr.model.model_name)

                from tensorflow.python.client import timeline
                tl = timeline.Timeline(bi.run_metadata.step_stats)
                ctf = tl.generate_chrome_trace_format()
                with open('{}/timeline_parallel_{}.json'
                          .format(mr.model.model_name, step), 'w') as f:
                    f.write(ctf)
                with open('{}/run_metadata_{}.txt'
                          .format(mr.model.model_name, step), 'w') as f:
                    f.write(google.protobuf.text_format.MessageToString(
                        bi.run_metadata))

        writer.add_summary(bi.summary, step)
        learning_times.append(bi.elapsed_learning)

        # Incremental mean: after batch i the value is the mean of the
        # first (i + 1) batches.
        mean_loss += (bi.loss - mean_loss) / (i + 1)
        mean_acc += (bi.acc - mean_acc) / (i + 1)
        losses.append(bi.loss)
        accs.append(bi.acc)
        progress.set_description(label.format(mean_loss))

    _epoch_summary(writer, epoch_idx, 'epoch loss', mean_loss)
    _epoch_summary(writer, epoch_idx, 'epoch acc', mean_acc)

    logging.info('check_sum: {}'.format(check_sum))
    # The histogram skips the first two batches; the median below does not.
    counts, edges = np.histogram(learning_times[2:])
    logging.info('Histogram: {}'.format(counts))
    logging.info('Histogram: {}'.format(edges))
    logging.info('Median time for session: {:5.2f}'
                 .format(np.median(learning_times)))
    return (mean_loss, mean_acc, losses, accs, generator_times,
            learning_times)


def _epoch_summary(writer, epoch_idx, tag, value):
    """Write one scalar summary value for an epoch.

    Args:
        writer: summary writer exposing `add_summary(summary, step)`.
        epoch_idx: step under which the value is recorded.
        tag: tag of the scalar.
        value: scalar to record.
    """
    # TensorFlow is imported locally, as elsewhere in this file.
    import tensorflow as tf
    scalar = tf.Summary.Value(tag=tag, simple_value=value)
    writer.add_summary(tf.Summary(value=[scalar]), epoch_idx)


def parse_counts_file(fpath):
    """Parse a per-tfrecord counts file.

    Each non-blank line is expected to look like::

        <tfrecord index> <name>,<count> <name>,<count> ...

    with fields separated by whitespace.  Only the part of `<name>` before
    its first underscore is kept as the id, and the counts of all entries
    sharing an id within one tfrecord are summed.  Blank lines are ignored.

    Args:
        fpath: path to the counts file.

    Returns:
        Nested `defaultdict` mapping tfrecord index (int) -> id (str) ->
        summed count (int).

    Raises:
        ValueError: if a tfrecord index or a count is not an integer.
        IndexError: if an entry has no `,<count>` part.
    """
    all_cts = defaultdict(lambda: defaultdict(int))

    with open(fpath, 'r') as f:
        for line in f:
            # split() without an argument tolerates repeated spaces, tabs
            # and trailing whitespace, and yields [] for a blank line.
            fields = line.split()
            if not fields:
                continue
            tfr = int(fields[0])
            for entry in fields[1:]:
                parts = entry.split(',')
                pdbid = parts[0].split('_')[0]
                all_cts[tfr][pdbid] += int(parts[1])
    return all_cts


def _get_tfrecords_to_use(train_params):
    """Select which tfrecord chunks to use for training and validation.

    In loose mode, or when either keep-file is unset, every `*.tfrecord`
    file under `dataset_tfrecords` (training) and `val_dataset_tfrecords`
    (validation) is used.

    Otherwise the per-chunk counts in `dataset_written.txt` (looked up in
    `os.path.dirname(train_params['dataset_tfrecords'])`) are accumulated
    chunk by chunk, in increasing chunk index, until both the requested
    number of training examples and of validation examples are covered.
    Chunks `0..last` (inclusive) are then returned for both training and
    validation.

    Args:
        train_params: mapping providing 'loose', 'keep_file_training',
            'keep_file_validation', 'range_training', 'range_validation',
            'dataset_tfrecords' and 'val_dataset_tfrecords'.

    Returns:
        Tuple `(tr_tfrecords, val_tfrecords)` of lists of file paths.

    Raises:
        ValueError: if the counts file contains no entries.
    """
    use_keep_files = (not train_params['loose'] and
                      train_params['keep_file_training'] != '' and
                      train_params['keep_file_validation'] != '')
    if not use_keep_files:
        tr_tfrecords = glob.glob(
            train_params['dataset_tfrecords'] + '/*.tfrecord')
        val_tfrecords = glob.glob(
            train_params['val_dataset_tfrecords'] + '/*.tfrecord')
        return tr_tfrecords, val_tfrecords

    # Read the keep-files through context managers so the handles are
    # closed deterministically.
    with open(train_params['keep_file_training'], 'r') as f:
        train_pdbids = set(line.strip() for line in f)
    with open(train_params['keep_file_validation'], 'r') as f:
        val_pdbids = set(line.strip() for line in f)

    num_train_examples = train_params['range_training'][1] - \
        train_params['range_training'][0]
    num_val_examples = train_params['range_validation'][1] - \
        train_params['range_validation'][0]

    dataset_dir = os.path.dirname(train_params['dataset_tfrecords'])
    counts_file = os.path.join(dataset_dir, 'dataset_written.txt')
    all_cts = parse_counts_file(counts_file)

    train_examples_ctr = 0
    val_examples_ctr = 0
    last_tfr = None
    satisfied = False
    for tfr in sorted(all_cts):
        last_tfr = tfr
        for pdbid, count in all_cts[tfr].items():
            # An id listed in both keep-files counts towards training only.
            if pdbid in train_pdbids:
                train_examples_ctr += count
            elif pdbid in val_pdbids:
                val_examples_ctr += count
        if train_examples_ctr >= num_train_examples and \
                val_examples_ctr >= num_val_examples:
            satisfied = True
            break

    if last_tfr is None:
        raise ValueError(
            'No tfrecord counts found in {}'.format(counts_file))
    if not satisfied:
        logging.warning(
            'Only {} train and {} val examples available, fewer than the '
            '{} train and {} val examples requested; using all chunks.'
            .format(train_examples_ctr, val_examples_ctr,
                    num_train_examples, num_val_examples))

    # The chunk at index `last_tfr` contributed to the totals above, so it
    # has to be part of the selection: chunks 0..last_tfr inclusive.
    num_files = last_tfr + 1
    logging.info(
        'Using {} TFRecord files for {} train examples and {} val examples'
        .format(num_files, num_train_examples, num_val_examples))
    tr_tfrecords = [
        os.path.join(dataset_dir, 'dataset_{:03}.tfrecord'.format(i))
        for i in range(num_files)]
    # Training and validation read the same chunks; return an independent
    # list so callers may mutate one without affecting the other.
    val_tfrecords = list(tr_tfrecords)
    return tr_tfrecords, val_tfrecords


def train(args):
    """Train an interaction predictor.

    Creates or restores the model under `args.model_dir`, builds the
    training and validation datasets, and then runs one training pass and
    one validation pass per epoch.  After every epoch a checkpoint, a JSON
    loss summary and an HDF5 file with per-batch losses and timings are
    written.

    Training ends once `train_params['max_epochs']` epochs have been run in
    this invocation, or earlier when an entry of
    `train_params['stop_criteria']` triggers:

    * 'lr' (presence only): when the validation loss increased, call
      `model.slow_lr_by_factor(10)`; this does not stop training.
    * 'converge' (truthy): stop when the validation loss increased.
    * 'train_threshold': stop when the training loss falls below it.
    * 'val_threshold': stop when the validation loss falls below it.

    Args:
        args: namespace providing `model_dir`, `model_json`, `train_json`
            and `seqmodel_json` (empty string to train without a sequence
            model).
    """
    import tensorflow as tf
    if not os.path.exists(args.model_dir):
        os.makedirs(args.model_dir)

    model_params = ip.load_params(args.model_json)
    train_params = tp.load_params(args.train_json)

    has_seq = args.seqmodel_json != ""
    if has_seq:
        assert os.path.exists(args.seqmodel_json), \
            'seqmodel json not found: {}'.format(args.seqmodel_json)
        seqmodel_params = qp.load_params(args.seqmodel_json)

    model = im.InteractionModel(args.model_dir)
    if not model.is_initialized():
        logging.info("initializing new model...")
        # NOTE(review): with a sequence model only add_seq() is called and
        # model_params is not passed to the model here -- confirm that this
        # is intended.
        if not has_seq:
            model.initialize(model_params)
        else:
            model.add_seq(seqmodel_params)

    # Create 3D grid generator.
    gen = sg.TFSubgridGenerator(
        model_params, train_params['num_directions'],
        train_params['num_rolls'])

    tr_tfrecords, val_tfrecords = _get_tfrecords_to_use(train_params)

    # Create tf.Dataset for reading in train and validation sets.
    tr_dataset, num_batches_training = ptt.create_tf_dataset(
        train_params, 'training', tr_tfrecords, gen.get_gridded_pair)
    val_dataset, num_batches_validation = ptt.create_tf_dataset(
        train_params, 'validation', val_tfrecords, gen.get_gridded_pair)

    if has_seq:
        tr_dataset = seq.add_seq_information(
            tr_dataset, train_params['seq_src'], seqmodel_params)
        val_dataset = seq.add_seq_information(
            val_dataset, train_params['seq_src'], seqmodel_params)

    tr_iterator = tr_dataset.make_one_shot_iterator()
    val_iterator = val_dataset.make_one_shot_iterator()

    # Feedable iterator: the model is built against a string handle so the
    # same graph can be driven by either of the iterators above.
    handle = tf.placeholder(tf.string, shape=[], name='handle')
    iterator = tf.data.Iterator.from_string_handle(
        handle, tr_dataset.output_types, tr_dataset.output_shapes)

    log_metadata = False
    model.load(train_params['towers'], train_params=train_params,
               log_metadata=log_metadata, iterator=iterator, has_seq=has_seq)

    logging.info("Creating data generators")

    tr_mr = run.ModelRunner(model, tr_iterator, True)
    val_mr = run.ModelRunner(model, val_iterator, False)

    logging.info("Done creating data generators")

    logging.info("Using {:} gpus.".format(train_params['towers']))
    start_epoch = model.tracker.get_next_epoch()
    epoch_format = model.tracker._get_checkpoint_format()
    prev_val_loss = float('inf')
    train_writer = tf.summary.FileWriter(args.model_dir + '/train',
                                         model.sess.graph)
    val_writer = tf.summary.FileWriter(args.model_dir + '/test')

    # The writers are closed in `finally` so that their event files are
    # closed however the loop ends (normal stop or exception).
    try:
        epoch_idx = start_epoch
        while True:
            if epoch_idx - start_epoch == train_params['max_epochs']:
                # If we have run the requested number of epochs.
                break
            logging.info("Epoch {}".format(epoch_idx))
            epoch_tr_loss, epoch_tr_acc, tr_losses, tr_accs, \
                tr_generator_times, tr_learning_times = \
                _run_and_log(tr_mr, epoch_idx, train_writer, True,
                             num_batches_training)
            epoch_val_loss, epoch_val_acc, val_losses, val_accs, \
                val_generator_times, val_learning_times = \
                _run_and_log(val_mr, epoch_idx, val_writer, False,
                             num_batches_validation)

            logging.info('Train loss: {:6.4f}, Val loss: {:6.4f}'
                         .format(epoch_tr_loss, epoch_val_loss))

            # Checkpoint, then the per-epoch loss summary.
            model.save(epoch_format.format(epoch=epoch_idx,
                                           val_loss=epoch_val_loss))
            summary_path = model.tracker.loss_summary_filename(epoch_idx)
            with open(summary_path, 'w') as f:
                json.dump({'loss': epoch_tr_loss,
                           'val_loss': epoch_val_loss}, f)

            # Per-batch losses and timings for later inspection.
            epoch_detail = model.tracker.epoch_detail_filename(epoch_idx)
            with h5py.File(epoch_detail, 'w') as f:
                grp = f.create_group('train')
                grp.create_dataset('loss', data=np.array(tr_losses))
                grp.create_dataset('session_time',
                                   data=np.array(tr_learning_times))
                grp.create_dataset('generator_time',
                                   data=np.array(tr_generator_times))
                grp = f.create_group('val')
                grp.create_dataset('loss', data=np.array(val_losses))
                grp.create_dataset('session_time',
                                   data=np.array(val_learning_times))
                grp.create_dataset('generator_time',
                                   data=np.array(val_generator_times))

            sc = train_params['stop_criteria']
            if 'lr' in sc and epoch_val_loss > prev_val_loss:
                logging.info(
                    'Loss stopped decreasing, slowing learning rate...')
                model.slow_lr_by_factor(10)
            if 'converge' in sc and sc['converge'] and \
                    epoch_val_loss > prev_val_loss:
                logging.info('Loss stopped decreasing, stopping...')
                break
            if 'train_threshold' in sc and \
                    epoch_tr_loss < sc['train_threshold']:
                logging.info('Train loss below threshold, stopping...')
                break
            if 'val_threshold' in sc and \
                    epoch_val_loss < sc['val_threshold']:
                logging.info('Validation loss below threshold, stopping...')
                break

            epoch_idx += 1
            prev_val_loss = epoch_val_loss
            # Blank line between epochs.  A bare `print` is a no-op
            # expression on Python 3; print('') emits the newline on both
            # Python 2 and Python 3.
            print('')
    finally:
        train_writer.close()
        val_writer.close()
