"""Model code for interaction prediction."""
import cPickle as pickle
import logging
import math
import os
import time

import numpy as np

import src.learning.epoch_tracker as et
import src.learning.subgrid_generation as sg


# Max number of GPUs we would use this model for (on one node).  The graph
# builders in this module always define this many towers, one per
# '/gpu:{i}' device; the number of towers actually used is chosen later,
# when a model is loaded.
MAX_TOWERS = 8


def _is_new_format():
    """Report whether the default graph uses the new tower naming scheme.

    Towers are now stored as ``TOWER0`` instead of ``TOWER_0``; the graph
    counts as new-format as soon as one node name contains ``TOWER0``.

    Returns:
        True if any node of the default graph has ``TOWER0`` in its name.
    """
    import tensorflow as tf
    graph_def = tf.get_default_graph().as_graph_def()
    return any('TOWER0' in node.name for node in graph_def.node)


def _tower_num(i):
    """Return the name of tower ``i`` in the default graph's naming format."""
    new_format = _is_new_format()
    return _tower_num_with_format(i, new_format)


def _tower_num_with_format(i, is_new_format):
    """Get tower number of given format."""
    if is_new_format:
        return 'TOWER' + str(i)
    else:
        return 'TOWER_' + str(i)


def _add_training_ops(towers, train_params):
    """Define trainable model across gpus.

    Builds the learning rate, the optimizer, one set of gradient ops per
    tower and a single ``train_step`` op that applies the tower-averaged
    gradients after the selected towers' update ops have run.

    Args:
        towers: Number of towers to compute gradients on; tower ``i`` is
            placed on ``/gpu:i``.
        train_params: Dict read for the keys ``seq_src``, ``lr_decay_type``,
            ``learning_rate`` and ``optimizer`` and, depending on those,
            ``lr_decay_steps``, ``lr_decay_rate``, ``lr_decay_staircase``
            and ``nesterov``.

    Returns:
        Tuple ``(optim, learning_rate)``: the optimizer and the learning rate
        it was built with (a non-trainable variable when ``lr_decay_type`` is
        ``'none'``, otherwise the tensor produced by the decay schedule).

    Raises:
        RuntimeError: If the decay type or the optimizer is not recognized.
    """
    import tensorflow as tf
    has_seq = train_params['seq_src'] != ""

    # Resolve the tower naming format once, before anything is added to the
    # graph.  The gradient ops below are created under 'TOWER{i}' name scopes,
    # so re-detecting the format afterwards would report the new format even
    # for a graph whose towers are named 'TOWER_{i}', and the update-op filter
    # further down would then match nothing.  It also avoids serializing the
    # graph once per tower.
    is_new_format = _is_new_format()
    tower_names = [
        _tower_num_with_format(i, is_new_format) for i in range(towers)]

    decay_type = train_params['lr_decay_type']
    if decay_type == 'none':
        learning_rate_tensor = tf.convert_to_tensor(
            train_params['learning_rate'])
        learning_rate = tf.get_variable(
            'learning_rate', initializer=learning_rate_tensor, trainable=False)
    elif decay_type in ('exponential', 'inverse_time', 'natural_exp'):
        # These three schedules share one signature and are named
        # tf.train.<decay_type>_decay.
        decay_fn = getattr(tf.train, decay_type + '_decay')
        learning_rate = decay_fn(
            train_params['learning_rate'],
            tf.train.get_or_create_global_step(),
            train_params['lr_decay_steps'],
            train_params['lr_decay_rate'],
            train_params['lr_decay_staircase'])
    elif decay_type == 'cosine':
        # Note that lr_decay_rate is passed as the fourth positional argument
        # of cosine_decay here.
        learning_rate = tf.train.cosine_decay(
            train_params['learning_rate'],
            tf.train.get_or_create_global_step(),
            train_params['lr_decay_steps'],
            train_params['lr_decay_rate'])
    else:
        raise RuntimeError('Unrecognized decay type {:}'.format(
            train_params['lr_decay_type']))

    optimizer_type = train_params['optimizer']
    if optimizer_type == 'RMS':
        optim = tf.train.RMSPropOptimizer(
            learning_rate, epsilon=1e-8)
    elif optimizer_type == 'SGD':
        optim = tf.train.MomentumOptimizer(
            learning_rate, 0.9, use_nesterov=train_params['nesterov'])
    elif optimizer_type == 'Adam':
        optim = tf.train.AdamOptimizer(
            learning_rate)
    else:
        # Fail here with a clear message instead of with an unbound `optim`
        # further down.
        raise RuntimeError('Unrecognized optimizer {:}'.format(
            optimizer_type))

    if has_seq:
        prefix = 'seq/'
    else:
        prefix = ''
    graph = tf.get_default_graph()
    tower_losses = [
        graph.get_tensor_by_name(
            prefix + '{:}/cross_entropy:0'.format(tower_name))
        for tower_name in tower_names]

    tower_grads = []
    to_train = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=prefix)
    for i, loss in enumerate(tower_losses):
        with tf.device('/gpu:{:}'.format(i)):
            with tf.name_scope('TOWER{:}'.format(i)):
                grads = optim.compute_gradients(loss, var_list=to_train)
                tower_grads.append(grads)

    with tf.device('/cpu:0'):
        grads = _average_gradients(tower_grads)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=prefix)
        # Only run the update ops that belong to the towers in use.
        update_ops_to_use = []
        for tower_name in tower_names:
            update_ops_to_use += \
                [x for x in update_ops if tower_name in x.name]
        updates = tf.group(*update_ops_to_use)
        with tf.control_dependencies([updates]):
            optim.apply_gradients(grads, name='train_step',
                                  global_step=tf.train.get_global_step())
    return optim, learning_rate


def _initialize_towers(model_params):
    """Build the shared placeholders and all MAX_TOWERS structural towers.

    The dropout keep-probability placeholders and the ``is_training`` flag
    are created once; every tower is then defined on its own GPU inside a
    ``TOWER{i}`` name scope, with variable reuse switched on after each
    tower is built.

    Args:
        model_params: Model hyper-parameters, forwarded to ``_define_tower``.
    """
    import tensorflow as tf
    with tf.device('/cpu:0'):
        for keep_name in ('fc_keep', 'conv_keep', 'top_nn_keep'):
            tf.placeholder(tf.float32, name=keep_name)
        tf.placeholder(tf.bool, shape=[], name='is_training')
        for tower_idx in range(MAX_TOWERS):
            gpu = '/gpu:{:}'.format(tower_idx)
            tower_scope = 'TOWER{:}'.format(tower_idx)
            with tf.device(gpu), tf.name_scope(tower_scope):
                _define_tower(model_params)
                tf.get_variable_scope().reuse_variables()


def _add_sequence_network(seqmodel_params):
    """Add MAX_TOWERS sequence towers under the ``seq`` variable scope.

    Each tower is defined on its own GPU inside a ``TOWER{i}`` name scope,
    with variable reuse switched on after each tower is built.

    Args:
        seqmodel_params: Sequence model hyper-parameters, forwarded to
            ``_add_seq_tower``.
    """
    import tensorflow as tf

    with tf.device('/cpu:0'):
        # NOTE(review): _batch_norm looks up 'is_training:0' by name; if the
        # graph already holds a tensor of that name this placeholder would
        # presumably be created under a different name -- confirm it is
        # needed.
        tf.placeholder(tf.bool, shape=[], name='is_training')
        with tf.variable_scope('seq'):
            for tower_idx in range(MAX_TOWERS):
                gpu = '/gpu:{:}'.format(tower_idx)
                tower_scope = 'TOWER{:}'.format(tower_idx)
                with tf.device(gpu), tf.name_scope(tower_scope):
                    _add_seq_tower(seqmodel_params)
                    tf.get_variable_scope().reuse_variables()


def _add_averaging_ops(towers, has_seq):
    """Get average tensors for all towers provided.

    Looks up each tower's accuracy, output, label and loss tensors by name
    and adds, under the ``all`` name scope (``seq/all`` when ``has_seq``),
    the mean accuracy, the concatenated outputs and labels, and the mean
    loss.  The summaries whose name contains one of the selected tower names
    are then merged into a tensor named ``merged``.

    Args:
        towers: Number of towers to aggregate over.
        has_seq: Whether to read the towers of the sequence network (names
            prefixed with ``seq/``) instead of the structural ones.
    """
    import tensorflow as tf
    if has_seq:
        prefix = 'seq/'
    else:
        prefix = ''

    # Detecting the naming format serializes the whole graph, so do it once
    # here rather than once per tower in each of the loops below.
    is_new_format = _is_new_format()
    tower_names = [
        _tower_num_with_format(i, is_new_format) for i in range(towers)]

    tower_accs = []
    tower_losses = []
    tower_labels = []
    tower_outs = []
    graph = tf.get_default_graph()
    for tower_name in tower_names:
        tower_accs.append(graph.get_tensor_by_name(
            '{:}{:}/evaluation/accuracy:0'.format(prefix, tower_name)))
        tower_outs.append(graph.get_tensor_by_name(
            '{:}{:}/y_out:0'.format(prefix, tower_name)))
        tower_labels.append(graph.get_tensor_by_name(
            '{:}{:}/y:0'.format(prefix, tower_name)))
        tower_losses.append(graph.get_tensor_by_name(
            '{:}{:}/cross_entropy:0'.format(prefix, tower_name)))

    with tf.name_scope(prefix + 'all'):
        with tf.name_scope('accuracy'):
            expanded_accs = tf.expand_dims(tower_accs, 0)
            all_accs = tf.concat(expanded_accs, 0, name='all')
            final_acc = tf.reduce_mean(all_accs, name='mean')
            tf.summary.scalar('summary', final_acc)

        with tf.name_scope('out'):
            all_outs = tf.concat(tower_outs, 0, name='all')
            _variable_summaries(all_outs)

        with tf.name_scope('y'):
            all_labels = tf.concat(tower_labels, 0, name='all')
            _variable_summaries(all_labels)

        with tf.name_scope('loss'):
            expanded_losses = tf.expand_dims(tower_losses, 0)
            all_losses = tf.concat(expanded_losses, 0, name='all')
            final_loss = tf.reduce_mean(all_losses, name='mean')
            tf.summary.scalar('summary', final_loss)

    # Keep only the summaries that belong to the towers in use.
    # NOTE(review): the summaries created just above live under the 'all'
    # scope, whose names contain no tower name, so they look like they are
    # not selected here -- confirm this is intended.
    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
    for tower_name in tower_names:
        for summary in summaries:
            if tower_name in summary.name:
                tf.add_to_collection('summaries_to_use', summary)

    # Older TensorFlow versions expose merge_all_summaries instead.
    if hasattr(tf.summary, 'merge_all'):
        merged = tf.summary.merge_all('summaries_to_use')
    else:
        merged = tf.merge_all_summaries('summaries_to_use')
    tf.identity(merged, 'merged')


def _average_gradients(tower_grads):
    """
    Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.
    Taken from:
    https://github.com/tensorflow/tensorflow/blob/v0.9.0/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over towers. The inner list is over the variables of one
            tower, in the same order for every tower.

    Returns:
        List of pairs of (gradient, variable) where the gradient has been
        averaged across all towers. A variable for which no tower produced a
        gradient (all its gradients are None) is left out.
    """
    import tensorflow as tf

    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))

        # Add a leading 'tower' dimension to every gradient so they can be
        # concatenated and averaged over it.  A gradient is None when the
        # loss does not depend on the variable; such entries cannot be
        # expanded and are skipped.
        grads = [tf.expand_dims(g, 0)
                 for g, _ in grad_and_vars if g is not None]
        if not grads:
            continue

        # Average over the 'tower' dimension.
        grad = tf.concat(grads, 0, name="concated")
        grad = tf.reduce_mean(grad, 0, name="reduced")

        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        average_grads.append((grad, v))
    return average_grads


def _define_tower(model_params):
    """Define siamese model.

    Creates the ``grid_left``, ``grid_right`` and ``y`` placeholders, runs
    both grids through the same base network (variables shared), concatenates
    the two embeddings and feeds them through the ``top_nn`` dense layers and
    a final single-unit ``prediction`` layer.  Also defines the ``y_out``
    sigmoid output, the ``cross_entropy`` loss and ``evaluation/accuracy``.

    Args:
        model_params: Dict read here for ``dense_layers``, ``top_fc_nodes``,
            ``batch_norm`` and ``dropout``; it is also passed on to
            ``sg.grid_shape`` and ``_create_base_network``.

    Returns:
        The mean sigmoid cross-entropy tensor of the tower.

    Raises:
        AssertionError: If ``dense_layers`` does not equal the number of
            entries in ``top_fc_nodes``.
    """
    import tensorflow as tf

    grid_left = tf.placeholder(
        tf.float32,
        shape=(None,) + sg.grid_shape(model_params),
        name='grid_left')
    grid_right = tf.placeholder(
        tf.float32,
        shape=(None,) + sg.grid_shape(model_params),
        name='grid_right')
    y = tf.placeholder(tf.float32, shape=[None], name='y')

    with tf.variable_scope("base_networks") as scope:
        processed_left = _create_base_network(grid_left, model_params)
        # Second call reuses the variables created by the first one.
        scope.reuse_variables()
        processed_right = _create_base_network(grid_right, model_params)

    concat = tf.concat([processed_left, processed_right], 1, name='concat')
    with tf.name_scope('concat'):
        x = concat
        _variable_summaries(concat)

    curr_in_size = concat.shape[1]
    assert model_params['dense_layers'] == \
        len(model_params['top_fc_nodes'])
    # Deep non siamese.
    with tf.variable_scope("top_nn"):
        for i, curr_size in enumerate(model_params['top_fc_nodes']):
            with tf.variable_scope("fc{:d}".format(i)):
                W_fc = tf.get_variable(
                    "weights",
                    [curr_in_size, curr_size],
                    initializer=tf.contrib.layers.xavier_initializer())
                b_fc = tf.get_variable(
                    "biases",
                    [curr_size],
                    initializer=tf.constant_initializer(0.0))

                x = tf.nn.relu(tf.matmul(x, W_fc) + b_fc, name='fcrelu')
                if model_params['batch_norm']:
                    x = _batch_norm(x)
                if model_params['dropout']:
                    x = tf.nn.dropout(
                        x, tf.get_default_graph().get_tensor_by_name(
                            'top_nn_keep:0'))
                curr_in_size = curr_size

        _variable_summaries(x)

    with tf.variable_scope("prediction"):
        # curr_in_size is the width of x at this point: the last dense layer
        # size, or the concat width when there are no dense layers (the loop
        # variable would be unbound in that case).
        W_final = tf.get_variable(
            "weights",
            [curr_in_size, 1],
            initializer=tf.contrib.layers.xavier_initializer())
        b_final = tf.get_variable(
            "biases",
            [1],
            initializer=tf.constant_initializer(0.0))
        x = tf.squeeze(tf.matmul(x, W_final) + b_final, squeeze_dims=[1])
        _variable_summaries(x)

    y_out = tf.sigmoid(x, name='y_out')
    _variable_summaries(y_out)
    cross_entropy = \
        tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y),
            name='cross_entropy')

    with tf.variable_scope("evaluation"):
        # A prediction is correct when output and label fall on the same side
        # of 0.5.
        correct_prediction =  \
            tf.logical_or(
                tf.logical_and(
                    tf.less_equal(y_out, 0.5),
                    tf.less_equal(y, 0.5)),
                tf.logical_and(
                    tf.greater(y_out, 0.5),
                    tf.greater(y, 0.5)))

        tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                       name='accuracy')

    return cross_entropy


def _add_seq_tower(seqmodel_params):
    """Define one sequence tower on top of the matching structural tower.

    Creates the PSSM/PSFM and label placeholders, builds the shared base
    sequence network for both sides and, depending on
    ``seqmodel_params['source']``, predicts from the sequence embedding, the
    structural embedding (the first input of the structural tower's
    ``prediction/MatMul`` op) or the concatenation of both.  Also defines
    ``y_out``, ``cross_entropy`` and ``evaluation/accuracy`` for the tower.

    Args:
        seqmodel_params: Dict read here for ``cons_window_radius`` and
            ``source``; it is also passed on to ``_create_base_seq_network``.

    Returns:
        The mean sigmoid cross-entropy tensor of the tower.

    Raises:
        RuntimeError: If ``source`` names neither a sequence input
            (``pssm``/``psfm``) nor ``struct``.
    """
    import tensorflow as tf
    window = seqmodel_params['cons_window_radius'] * 2 + 1
    profile_shape = (None, window, 20)
    pssm_left = tf.placeholder(
        tf.int32, shape=profile_shape, name='pssm_left')
    pssm_right = tf.placeholder(
        tf.int32, shape=profile_shape, name='pssm_right')
    psfm_left = tf.placeholder(
        tf.float32, shape=profile_shape, name='psfm_left')
    psfm_right = tf.placeholder(
        tf.float32, shape=profile_shape, name='psfm_right')
    y = tf.placeholder(tf.float32, shape=[None], name='y')

    with tf.variable_scope("base_networks") as base_scope:
        processed_left = _create_base_seq_network(
            pssm_left, psfm_left, seqmodel_params)
        # The right side shares the variables created for the left side.
        base_scope.reuse_variables()
        processed_right = _create_base_seq_network(
            pssm_right, psfm_right, seqmodel_params)

    source = seqmodel_params['source']
    use_seq = 'pssm' in source or 'psfm' in source
    use_struct = 'struct' in source

    if use_seq:
        with tf.variable_scope("concat"):
            seq_embedding = tf.concat(
                [processed_left, processed_right], 1, name='concat')

    # The second component of the current name scope is the tower name; use
    # it to find the input of the structural tower's prediction layer.
    tower_name = tf.contrib.framework.get_name_scope().split("/")[1]
    prediction_matmul = tf.get_default_graph().get_operation_by_name(
        tower_name + '/prediction/MatMul')
    struct_embedding = prediction_matmul.inputs[0]

    if not (use_seq or use_struct):
        raise RuntimeError(
            "Need to specify one of struct or sequence as source!")
    if use_seq and use_struct:
        with tf.variable_scope("seq_struct_concat"):
            embedding = tf.concat([seq_embedding, struct_embedding], 1)
    elif use_struct:
        embedding = struct_embedding
    else:
        embedding = seq_embedding

    embedding_size = embedding.shape[1]

    with tf.variable_scope("prediction"):
        final_weights = tf.get_variable(
            "weights",
            [embedding_size, 1],
            initializer=tf.contrib.layers.xavier_initializer())
        final_biases = tf.get_variable(
            "biases",
            [1],
            initializer=tf.constant_initializer(0.0))
        logits = tf.squeeze(
            tf.matmul(embedding, final_weights) + final_biases,
            squeeze_dims=[1])
        _variable_summaries(logits)

    y_out = tf.sigmoid(logits, name='y_out')
    _variable_summaries(y_out)
    per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=y)
    cross_entropy = tf.reduce_mean(per_example_loss, name='cross_entropy')

    with tf.variable_scope("evaluation"):
        # Correct when output and label fall on the same side of 0.5.
        both_negative = tf.logical_and(
            tf.less_equal(y_out, 0.5), tf.less_equal(y, 0.5))
        both_positive = tf.logical_and(
            tf.greater(y_out, 0.5), tf.greater(y, 0.5))
        correct_prediction = tf.logical_or(both_negative, both_positive)

        tf.reduce_mean(tf.cast(correct_prediction, tf.float32),
                       name='accuracy')

    return cross_entropy


def _variable_summaries(var):
    """Attach mean, standard deviation, max and min summaries to a Tensor."""
    import tensorflow as tf

    mean = tf.reduce_mean(var)
    with tf.name_scope('stddev'):
        variance = tf.reduce_mean(tf.square(var - mean))
        stddev = tf.sqrt(variance)
    for tag, value in (('mean', mean), ('sttdev', stddev)):
        tf.summary.scalar(tag, value)
    for tag, reducer in (('max', tf.reduce_max), ('min', tf.reduce_min)):
        tf.summary.scalar(tag, reducer(var))


def _conv3d(x, kernel_shape, bias_shape, model_params, residual_input=None):
    """Apply a stride-1 3D convolution, a bias and a ReLU.

    Args:
        x: Input tensor of the convolution.
        kernel_shape: Shape of the ``weights`` variable.
        bias_shape: Shape of the ``biases`` variable.
        model_params: Dict read for ``border_mode_valid``, which selects
            'VALID' padding when truthy and 'SAME' padding otherwise.
        residual_input: Optional tensor added to the convolution output
            before the ReLU.  If its last dimension differs from the
            output's, it first goes through a 1x1x1 convolution with its own
            ``lin_proj_weights``/``lin_proj_biases`` variables.

    Returns:
        The activated output tensor.
    """
    import tensorflow as tf

    unit_strides = [1, 1, 1, 1, 1]
    weights = tf.get_variable(
        "weights",
        kernel_shape,
        initializer=tf.contrib.layers.xavier_initializer())
    biases = tf.get_variable(
        "biases", bias_shape, initializer=tf.constant_initializer(0.0))

    padding = 'VALID' if model_params['border_mode_valid'] else 'SAME'
    conv_out = tf.add(
        tf.nn.conv3d(x, weights, unit_strides, padding), biases)
    if residual_input is None:
        return tf.nn.relu(conv_out)

    out_channels = conv_out.shape[-1]
    in_channels = residual_input.shape[-1]
    if out_channels != in_channels:
        # Project the residual onto the output's channel count.
        proj_weights = tf.get_variable(
            "lin_proj_weights",
            [1, 1, 1, in_channels, out_channels],
            initializer=tf.contrib.layers.xavier_initializer())
        proj_biases = tf.get_variable(
            "lin_proj_biases", [out_channels],
            initializer=tf.constant_initializer(0.0))
        projected = tf.nn.conv3d(
            residual_input, proj_weights, unit_strides, 'VALID')
        residual_input = tf.add(projected, proj_biases)
    return tf.nn.relu(tf.add(conv_out, residual_input))


def _batch_norm(x):
    """Batchnormalization layer.

    Normalizes ``x`` inside a ``bn`` variable scope, using the graph's
    ``is_training:0`` tensor to switch between training and inference
    behaviour, and checks the result for non-finite values.

    Args:
        x: Tensor of static rank 5 (3D convolution output) or rank 2 (fully
            connected output).

    Returns:
        The normalized tensor, with the same rank as ``x``.

    Raises:
        AssertionError: If ``x`` has any other rank.
    """
    import tensorflow as tf
    is_training = tf.get_default_graph().get_tensor_by_name('is_training:0')

    with tf.variable_scope('bn') as scope:
        # Since fused_batch_norm only takes in 2D or 4D vectors, reshape our 5D
        # vector to 4D, as per:
        # https://github.com/tensorflow/tensorflow/issues/5694
        x_shape = x.get_shape().as_list()
        if len(x_shape) == 5:
            r1 = x_shape[1]
            r2 = x_shape[2]
            r3 = x_shape[3]
            channels = x_shape[4]
            x = tf.reshape(x, [-1, r1, r2 * r3, channels])

            x = tf.contrib.layers.batch_norm(
                x, fused=True, scope=scope, is_training=is_training,
                scale=True, decay=0.99)

            x = tf.reshape(x, [-1, r1, r2, r3, channels])
        elif len(x_shape) == 2:
            # Don't do fused for 2D vectors because otherwise this happens:
            # https://github.com/tensorflow/tensorflow/issues/5988
            # TODO: This has now been fixed.  Make sure we can switch.
            x = tf.contrib.layers.batch_norm(
                x, scope=scope, is_training=is_training, scale=True,
                decay=0.99)
        else:
            # Only support 3D and FC layers.  Raised explicitly so the check
            # is not stripped when Python runs with -O.
            raise AssertionError(
                'Batch norm only supports rank 5 or rank 2 inputs, '
                'got rank {:}'.format(len(x_shape)))

        x = tf.check_numerics(x, 'non finite number found!')
        return x


def _create_base_network(x, model_params):
    """Create one side of the siamese network: 3D convs, then dense layers.

    The input goes through one ``conv{i}`` block per entry of
    ``num_filters`` (convolution, optional max pooling, optional batch norm,
    optional dropout), is flattened, and then goes through one ``fc{i}``
    block per entry of ``tower_fc_nodes``.

    Args:
        x: Input grid tensor.
        model_params: Dict read for ``radius_ang``, ``resolution``,
            ``batch_norm``, ``dropout``, ``num_filters``, ``conv_size``,
            ``border_mode_valid``, ``max_pool_positions``,
            ``max_pool_sizes``, ``max_pool_strides`` and ``tower_fc_nodes``;
            it is also passed on to ``sg.channel_size`` and ``_conv3d``.

    Returns:
        The output of the last dense layer, as a tensor named ``fcfinal``.
    """
    import tensorflow as tf
    # NOTE(review): under Python 2 this division truncates when both values
    # are ints -- confirm radius_ang/resolution are given as expected.
    grid_size = int(round((model_params['radius_ang'] * 2 + 1) /
                          model_params['resolution']))
    num_channels = sg.channel_size(model_params)
    with tf.name_scope('input'):
        _variable_summaries(x)
    if model_params['batch_norm']:
        x = _batch_norm(x)
    # Spatial extent of the grid, tracked by hand through the conv blocks so
    # that the flattened size is known below.
    output_grid_size = grid_size

    filter_sizes = model_params['num_filters']

    # Convs.
    for i, num_filters in enumerate(filter_sizes):
        with tf.variable_scope("conv{:d}".format(i)):
            cs = model_params['conv_size']
            keep_prob = \
                tf.get_default_graph().get_tensor_by_name('conv_keep:0')

            x = _conv3d(
                x,
                [cs, cs, cs, num_channels, num_filters],
                [num_filters],
                model_params)
            # If border mode is valid, grid size is reduced at every level.
            if model_params['border_mode_valid']:
                output_grid_size = output_grid_size - (cs - 1)

            if model_params['max_pool_positions'][i] == 1:
                size = model_params['max_pool_sizes'][i]
                stride = model_params['max_pool_strides'][i]
                x = tf.nn.max_pool3d(
                    x,
                    [1, size, size, size, 1],
                    [1, stride, stride, stride, 1],
                    padding="SAME",
                    name="fpool{:}".format(i))
                # With SAME padding the pooled extent is ceil(in / stride);
                # use the configured stride rather than assuming it is 2.
                output_grid_size = int(
                    math.ceil(output_grid_size / float(stride)))

            if model_params['batch_norm']:
                x = _batch_norm(x)
            if model_params['dropout']:
                x = tf.nn.dropout(x, keep_prob)
            num_channels = num_filters
            _variable_summaries(x)

    curr_in_size = output_grid_size ** 3 * num_channels
    x = tf.reshape(x, [-1, curr_in_size])
    # FC 1.
    for i, curr_size in enumerate(model_params['tower_fc_nodes']):
        with tf.variable_scope("fc{:}".format(i)):
            W_fc1 = tf.get_variable(
                "weights",
                [curr_in_size, curr_size],
                initializer=tf.contrib.layers.xavier_initializer())
            b_fc1 = tf.get_variable("biases", [curr_size],
                                    initializer=tf.constant_initializer(0.0))
            x = tf.nn.relu(tf.matmul(x, W_fc1) + b_fc1, name='fcrelu')
            if model_params['batch_norm']:
                x = _batch_norm(x)
            if model_params['dropout']:
                x = tf.nn.dropout(
                    x, tf.get_default_graph().get_tensor_by_name('fc_keep:0'))
            _variable_summaries(x)
            curr_in_size = curr_size

    x = tf.identity(x, name='fcfinal')

    return x


def _conv1d(x, kernel_shape, bias_shape, seqmodel_params):
    """Apply a stride-1, SAME-padded 1D convolution, a bias and a ReLU.

    Args:
        x: Input tensor of the convolution.
        kernel_shape: Shape of the ``weights`` variable.
        bias_shape: Shape of the ``biases`` variable.
        seqmodel_params: Not read by this function.

    Returns:
        The activated output tensor.
    """
    import tensorflow as tf

    kernel = tf.get_variable(
        "weights",
        kernel_shape,
        initializer=tf.contrib.layers.xavier_initializer())
    offset = tf.get_variable(
        "biases", bias_shape, initializer=tf.constant_initializer(0.0))

    convolved = tf.nn.conv1d(x, kernel, 1, 'SAME')
    return tf.nn.relu(tf.add(convolved, offset))


def _create_base_seq_network(pssm, psfm, seqmodel_params):
    """Flatten the PSSM and/or PSFM input into one feature vector per example.

    Args:
        pssm: PSSM tensor; cast to float32 before use.
        psfm: PSFM tensor.
        seqmodel_params: Dict read for ``source`` (which inputs to use) and
            ``batch_norm`` (whether to normalize the flattened features).

    Returns:
        A rank-2 tensor holding the selected input(s), flattened.
    """
    import tensorflow as tf

    source = seqmodel_params['source']
    has_pssm = 'pssm' in source
    has_psfm = 'psfm' in source

    if has_pssm == has_psfm:
        # Either both inputs were requested, or neither was.  In the latter
        # case something still has to be built here, but the calling function
        # ignores it.
        features = tf.concat(
            [tf.cast(pssm, tf.float32), psfm], 1, name='seq_in')
    elif has_pssm:
        features = tf.cast(pssm, tf.float32, name='seq_in')
    else:
        features = tf.cast(psfm, tf.float32, name='seq_in')

    flat_size = features.shape[1] * features.shape[2]

    flat = tf.reshape(features, [-1, flat_size])
    if seqmodel_params['batch_norm']:
        flat = _batch_norm(flat)
    return flat


class InteractionModel(object):
    """Wrapper for inference and training."""

    def __init__(self, model_dir):
        """Remember the model directory and create its epoch tracker.

        No graph or session exists after construction; ``initialize`` or
        ``load`` still has to be called.
        """
        self.model_dir = model_dir
        self.tracker = et.EpochTracker(self.model_dir)

    def _get_feed_dict(self, training):
        """Build the base feed dictionary for a training or inference run.

        Dropout keep probabilities are below 1 only while training; the
        ``is_training:0`` entry mirrors the truthiness of ``training``.
        """
        if training:
            top_nn_keep, conv_keep, fc_keep = 0.5, 0.9, 0.75
        else:
            top_nn_keep = conv_keep = fc_keep = 1.0
        return {
            'top_nn_keep:0': top_nn_keep,
            'conv_keep:0': conv_keep,
            'fc_keep:0': fc_keep,
            'is_training:0': bool(training)}

    def is_initialized(self):
        """Tell whether the tracker reports an init file or a checkpoint."""
        has_init = self.tracker.has_init()
        return has_init or self.tracker.has_checkpoint()

    def initialize(self, model_params):
        """Initialize a brand new model.

        Builds the tower graph, runs the variable initializers and saves the
        result as ``init.ckpt`` in the model directory.  The session is
        closed and the default graph is reset afterwards, whether or not the
        steps above succeeded.

        Args:
            model_params: Model hyper-parameters, forwarded to
                ``_initialize_towers``.
        """
        import tensorflow as tf

        try:
            # Create tower graph.
            _initialize_towers(model_params)

            # Initialize variables.
            sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True))
            try:
                # Older TensorFlow versions only have
                # initialize_all_variables.
                if hasattr(tf, 'global_variables_initializer'):
                    sess.run(tf.global_variables_initializer())
                else:
                    sess.run(tf.initialize_all_variables())

                # Save initializations.
                saver = tf.train.Saver(max_to_keep=10000000,
                                       keep_checkpoint_every_n_hours=0.5)
                saver.save(sess, os.path.join(self.model_dir, 'init.ckpt'))
            finally:
                sess.close()
        finally:
            # Cleanup, also on failure, so that a half-built graph does not
            # leak into whatever is built next.
            tf.reset_default_graph()

    def add_seq(self, seqmodel_params):
        """Add sequence network to existing structure model.

        Imports the structural model named by
        ``seqmodel_params['struct_model']``, restores its best checkpoint,
        adds the sequence towers, initializes only the newly created
        variables and saves everything as ``init.ckpt`` in this model's
        directory.  The session is closed and the default graph is reset
        afterwards, whether or not the steps above succeeded.

        Args:
            seqmodel_params: Dict read here for ``struct_model``; it is also
                forwarded to ``_add_sequence_network``.
        """
        import tensorflow as tf

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        try:
            src_tracker = et.EpochTracker(seqmodel_params['struct_model'])

            logging.info("Adding sequence network.")
            # Load previous structural model.
            saver = tf.train.import_meta_graph(
                src_tracker.metagraph_filename(-1))
            logging.info("Using structural weights from {:}".format(
                src_tracker.best_checkpoint_filename()))
            saver.restore(
                sess, src_tracker.best_checkpoint_filename())

            # Get already defined structural variables to make sure we don't
            # redefine them.
            temp = set(tf.global_variables())

            # Add actual sequence ops.
            _add_sequence_network(seqmodel_params)

            # Initialize new sequence variables.
            sess.run(
                tf.variables_initializer(
                    set(tf.global_variables()) - temp))

            # Save initializations.
            saver = tf.train.Saver(max_to_keep=10000000,
                                   keep_checkpoint_every_n_hours=0.5)
            saver.save(sess, os.path.join(self.model_dir, 'init.ckpt'))
        finally:
            # Cleanup, also on failure.
            sess.close()
            tf.reset_default_graph()

    def load(self, towers, train_params=None, get_best=True, log_metadata=True,
             checkpoint_filename="", iterator=None, has_seq=False):
        """Load pre-existing model.

        Imports the saved tower graph, wiring the towers' inputs either to
        ``iterator`` or to freshly created placeholders, restores weights
        into a new session stored as ``self.sess``, and adds the averaging
        ops and, if requested, the training ops.

        Args:
            towers: Number of towers to use.
            train_params: Training parameters for ``_add_training_ops``; no
                training ops are added when None.
            get_best: When a checkpoint exists and no explicit
                ``checkpoint_filename`` is given, restore the tracker's best
                checkpoint if True and its latest one otherwise.
            log_metadata: Stored as ``self.log_metadata``.
            checkpoint_filename: Checkpoint to restore; when empty, one is
                chosen through the tracker (or ``init.ckpt`` is used if no
                checkpoint exists yet).
            iterator: Optional object whose ``get_next()`` returns a dict
                with ``grid`` and ``label`` entries (plus ``pssm`` and
                ``psfm`` when ``has_seq``); it is called once per tower.
            has_seq: Whether the model includes the sequence network.
        """
        import tensorflow as tf

        self.has_seq = has_seq

        logging.info("Starting model loading")
        # Re-create tower graph.
        if self.tracker.has_init():
            input_map = {}
            if iterator is not None:
                # Thanks to
                # https://stackoverflow.com/questions/38618960/tensorflow-how-to-insert-custom-input-to-existing-graph
                # answer at bottom.
                gridded_split = \
                    [iterator.get_next() for _ in range(towers)]

                # Remove division by tower.
                self.gridded_pairs = {}
                for key in gridded_split[0].keys():
                    curr = []
                    for tower in range(towers):
                        curr.append(gridded_split[tower][key])
                    self.gridded_pairs[key] = tf.concat(curr, 0)

                for i, gridded in enumerate(gridded_split):
                    tower = "TOWER{:}".format(i)
                    # Index j along the second axis selects the left (0) or
                    # right (1) member of each pair.
                    for j, side in enumerate(("left", "right")):
                        input_map["{:}/grid_{:}:0".format(tower, side)] = \
                            gridded['grid'][:, j]
                    input_map["{:}/y:0".format(tower)] = gridded['label']
                    if self.has_seq:
                        prefix = 'seq/'
                        for j, side in enumerate(("left", "right")):
                            input_map["{:}{:}/pssm_{:}:0".format(
                                prefix, tower, side)] = gridded['pssm'][:, j]
                            input_map["{:}{:}/psfm_{:}:0".format(
                                prefix, tower, side)] = gridded['psfm'][:, j]
                        input_map["{:}{:}/y:0".format(prefix, tower)] = \
                            gridded['label']
            else:
                grid_left = tf.placeholder(
                    tf.float32,
                    shape=(None, None, None, None, None),
                    name="grid_left")
                grid_right = tf.placeholder(
                    tf.float32,
                    shape=(None, None, None, None, None),
                    name="grid_right")
                y = tf.placeholder(tf.float32, shape=[None], name='y')
                size = tf.shape(grid_left)

                # Ceiling to closest multiple of towers.
                pad_size = tf.cast(tf.ceil(tf.cast(
                    tf.expand_dims(size[0], 0), tf.float32) / towers) * towers,
                    tf.int32) - size[0]
                end_pad = tf.concat((pad_size, tf.zeros_like(size[1:])), 0)
                begin_pad = tf.zeros_like(end_pad)
                pad = tf.stack((begin_pad, end_pad), axis=1)
                self.gridded_pairs = grid_left, grid_right
                grid_left = tf.pad(grid_left, pad)
                grid_right = tf.pad(grid_right, pad)
                # Pad the labels by the same number of rows as the grids so
                # that they too split evenly across the towers.
                y_pad = tf.stack((tf.zeros_like(pad_size), pad_size), axis=1)
                y = tf.pad(y, y_pad)
                gls = tf.split(grid_left, towers)
                grs = tf.split(grid_right, towers)
                yts = tf.split(y, towers)
                # TODO: Add pssm and psfm here.
                for i, (gl, gr, yt) in enumerate(zip(gls, grs, yts)):
                    tower = "TOWER{:}".format(i)
                    input_map["{:}/grid_left:0".format(tower)] = gl
                    input_map["{:}/grid_right:0".format(tower)] = gr
                    # Every tower gets its own slice of the labels.
                    input_map["{:}/y:0".format(tower)] = yt

            # We use init.ckpt.meta here since we don't want any of the
            # input/gradient information, and subsequent metagraphs have that
            # stuff.
            saver = tf.train.import_meta_graph(
                self.tracker.metagraph_filename(-1),
                input_map=input_map)
        else:
            # We use init.ckpt.meta here since we don't want any of the
            # input/gradient information, and subsequent metagraphs have that
            # stuff.
            saver = tf.train.import_meta_graph(
                self.tracker.metagraph_filename(-1))
        self.sess = tf.Session(
            config=tf.ConfigProto(allow_soft_placement=True))
        # Load tower variables.
        if checkpoint_filename != "":
            saver.restore(self.sess, checkpoint_filename)
        else:
            if not self.tracker.has_checkpoint():
                saver.restore(
                    self.sess, os.path.join(self.model_dir, 'init.ckpt'))
            elif get_best:
                logging.info("Restoring {}".format(
                    self.tracker.best_checkpoint_filename()))
                saver.restore(
                    self.sess, self.tracker.best_checkpoint_filename())
            else:
                saver.restore(
                    self.sess, self.tracker.latest_checkpoint_filename())

        self._is_new_format = _is_new_format()
        # Add averaging and (if needed) training ops.  We don't save these so
        # that we can deal with varying number of towers at runtime.
        _add_averaging_ops(towers, has_seq)
        if train_params is not None:
            # Only the variables created by the training ops still need to be
            # initialized; everything else was restored above.
            temp = set(tf.global_variables())
            optim, learning_rate_var = _add_training_ops(towers, train_params)
            self.optim = optim
            self.learning_rate_var = learning_rate_var
            self.sess.run(
                tf.variables_initializer(
                    set(tf.global_variables()) - temp))

        logging.info("End model loading")

        self.towers = towers
        self.log_metadata = log_metadata
        self._set_tensor_values()

    def slow_lr_by_factor(self, factor):
        """Divide the current learning rate by ``factor`` and log the change.

        Args:
            factor: divisor applied to the value currently held by
                ``self.learning_rate_var``.
        """
        import tensorflow as tf
        (old_lr,) = self.sess.run([self.learning_rate_var])
        logging.info('Curr LR is {}'.format(old_lr))

        # NOTE: a fresh assign op is built on every call.
        shrink_op = tf.assign(self.learning_rate_var, old_lr / factor)
        self.sess.run(shrink_op)

        # Read back what the optimizer actually sees after the update.
        step_tensor = tf.train.get_or_create_global_step()
        step_now, lr_now = self.sess.run(
            [step_tensor, self.optim._learning_rate])
        logging.info('New LR at step {} is {}; old was {}'.format(
            step_now, lr_now, old_lr))

    def train(self, handle):
        """Run a single training step and return its outputs.

        Args:
            handle: value fed to the graph under the key 'handle:0'
                (presumably a dataset iterator handle -- confirm against
                the caller).

        Returns:
            Tuple ``(predictions, acc, loss, summary, run_metadata,
            gridded_pairs)``.  ``run_metadata`` is a ``tf.RunMetadata`` when
            ``self.log_metadata`` is set and ``None`` otherwise.

        Raises:
            tf.errors.InvalidArgumentError: re-raised after a crash
                checkpoint and the pickled feed dict have been written under
                ``self.tracker.out_dir``.
        """
        import tensorflow as tf
        feed_dict = self._get_feed_dict(True)
        feed_dict['handle:0'] = handle
        train_step = self.graph.get_operation_by_name('train_step')
        if self.log_metadata:
            # Full tracing is only requested when metadata logging is on.
            run_options = tf.RunOptions(
                trace_level=tf.RunOptions.FULL_TRACE,
                output_partition_graphs=True)
            run_metadata = tf.RunMetadata()
            kwargs = {'options': run_options, 'run_metadata': run_metadata}
        else:
            kwargs = {}
            run_metadata = None

        try:
            # train_step and check_op are fetched only so that they execute;
            # their results are discarded.
            predictions, summary, acc, loss, _, _, gridded_pairs = \
                self.sess.run(
                    [self.y_out, self.merged, self.accuracy,
                     self.cross_entropy, train_step, self.check_op,
                     self.gridded_pairs],
                    feed_dict=feed_dict, **kwargs)

            # Global step and learning rate can be read with:
            #   self.sess.run([tf.train.get_global_step(),
            #                  self.optim._learning_rate])
        except tf.errors.InvalidArgumentError as err:
            logging.error("Caught crash, saving state and dying.")
            logging.error(err)
            crash_prefix = self.tracker.out_dir + '/crash-'
            saver = tf.train.Saver()
            saver.save(self.sess, crash_prefix + self.model_name + '.ckpt')
            # Binary mode plus a context manager: the dump is flushed and
            # closed before the exception propagates.
            dump_path = crash_prefix + 'input-' + self.model_name + '.pkl'
            with open(dump_path, 'wb') as dump_file:
                pickle.dump(feed_dict, dump_file)
            # Bare raise keeps the original traceback.
            raise

        return predictions, acc, loss, summary, run_metadata, gridded_pairs

    def infer_examples(self, batch1, batch2, labels=None):
        """Do inference from explicit left/right batches.

        Args:
            batch1: array fed as "grid_left:0"; its ``shape[0]`` is used to
                trim padded predictions when no labels are given.
            batch2: array fed as "grid_right:0".
            labels: optional array fed as "y:0".  When given, accuracy and
                loss are computed too, and ``labels.shape[0]`` must be a
                multiple of ``self.towers``.

        Returns:
            With labels, the tuple from ``self.infer(..., with_acc=True)``
            unchanged.  Without labels, the tuple from
            ``self.infer(..., with_acc=False)`` with its first element
            truncated to ``batch1.shape[0]`` entries.

        Raises:
            AssertionError: labels were given but their count is not
                divisible by the number of towers.
        """
        with_acc = labels is not None
        if with_acc:
            # If we have labels, we can't use our dynamic padding since acc
            # will be off, so reject uneven batches before doing any work.
            # TODO: Fix this?
            assert labels.shape[0] % self.towers == 0, (
                "number of labels must be a multiple of the tower count")

        feed_dict = self._get_feed_dict(False)
        feed_dict["grid_left:0"] = batch1
        feed_dict["grid_right:0"] = batch2
        if with_acc:
            feed_dict["y:0"] = labels
            return self.infer(feed_dict, with_acc=True)

        # Without accuracy, we allow for padding in input, but the outputs
        # will be padded as well, so we fix that here.
        res = list(self.infer(feed_dict, with_acc=False))
        res[0] = res[0][:batch1.shape[0]]
        return tuple(res)

    def infer_handle(self, handle):
        """Run inference (with accuracy/loss) for an optional input handle.

        Args:
            handle: value to feed under "handle:0"; when ``None`` nothing is
                added to the base feed dict.

        Returns:
            Whatever ``self.infer`` returns with ``with_acc=True``.
        """
        feeds = self._get_feed_dict(False)
        has_handle = handle is not None
        if has_handle:
            feeds["handle:0"] = handle
        # TODO. Allow this to handle uneven batches too.
        return self.infer(feeds, with_acc=True)

    def infer(self, feed_dict, with_acc=False):
        """Evaluate the model on ``feed_dict`` without a training step.

        Args:
            feed_dict: feed dictionary handed to ``self.sess.run``.
            with_acc: when true, also fetch summary, accuracy and loss.

        Returns:
            ``(predictions, gridded_pairs)`` by default, or
            ``(predictions, acc, loss, summary, gridded_pairs)`` when
            ``with_acc`` is true.
        """
        # TODO: Handle case where uneven number of examples.
        if not with_acc:
            predictions, gridded_pairs = self.sess.run(
                [self.y_out, self.gridded_pairs],
                feed_dict=feed_dict)
            return predictions, gridded_pairs

        fetches = [self.y_out, self.merged, self.accuracy,
                   self.cross_entropy, self.gridded_pairs]
        predictions, summary, acc, loss, gridded_pairs = self.sess.run(
            fetches, feed_dict=feed_dict)
        # Note the return order differs from the fetch order.
        return predictions, acc, loss, summary, gridded_pairs

    def rearrange_output(self, output, source_dict):
        """Put model output back into the order given by ``source_dict``.

        Args:
            output: array-like supporting ``.copy()`` and integer indexing,
                ordered tower by tower as the model emitted it.
            source_dict: maps each tower name to the sequence of target
                indices for that tower's outputs.

        Returns:
            A copy of ``output`` where entry ``sources[k]`` holds
            ``output[k]``, with ``sources`` being the per-tower index lists
            concatenated in tower order.
        """
        tower_names = [
            _tower_num_with_format(tower, self._is_new_format)
            for tower in range(self.towers)]
        flat_sources = [
            idx for name in tower_names for idx in source_dict[name]]

        reordered = output.copy()
        for position, target in enumerate(flat_sources):
            reordered[target] = output[position]
        return reordered

    def save(self, path):
        """Checkpoint the current session to ``path`` with a new Saver."""
        import tensorflow as tf
        tf.train.Saver().save(self.sess, path)

    def _get_tensor(self, tensor_name):
        """Get tensor by name from model's graph."""
        return self.graph.get_tensor_by_name(tensor_name)

    def _set_tensor_values(self, check_nans=False):
        """Provide shortcuts to important tensors.

        Stores the default graph and handles to a few named tensors on the
        instance so other methods can fetch them directly.

        Args:
            check_nans: when true, ``self.check_op`` is the op returned by
                ``tf.add_check_numerics_ops()``; otherwise it aliases
                ``self.merged``.
        """
        import tensorflow as tf

        # Tensor names gain a "seq/" scope prefix when self.has_seq is set.
        if self.has_seq:
            prefix = "seq/"
        else:
            prefix = ""

        self.graph = tf.get_default_graph()
        self.merged = self._get_tensor('merged:0')
        # Timestamp taken at setup time; train() uses it in the names of the
        # crash files it writes.
        self.model_name = time.strftime("%Y-%m-%d-%H-%M-%S")
        if check_nans:
            self.check_op = tf.add_check_numerics_ops()
        else:
            # Set it to something that will be evaluated anyways.
            self.check_op = self.merged
        # Named tensors looked up under the (possibly prefixed) 'all' scope.
        self.cross_entropy = self._get_tensor(prefix + 'all/loss/mean:0')
        self.accuracy = self._get_tensor(prefix + 'all/accuracy/mean:0')
        self.y_out = self._get_tensor(prefix + 'all/out/all:0')
        self.y = self._get_tensor(prefix + 'all/y/all:0')
