"""Parameters for a single training run of the interaction prediction model.

To update params, do the following:

1) Bump the current version number (returned by TrainParams._curr_version).
2) In _create_params, add the new params behind a version flag.
3) In _attempt_patch, add the appropriate patch behind a version flag.
4) In the version notes at the end of this docstring, document what the
   updated params are.

== 1.0 update ==

release version.

"""
import random

import src.util.versioned_params as vp


class TrainParams(vp.VersionedParams):
    """Interaction training params.

    Versioned parameter set for one training run. See the module docstring
    for the steps to follow when adding or changing parameters.
    """

    # Resolve the "any text string" type once. On Python 2 this is
    # basestring (str and unicode); on Python 3 the bare name basestring
    # raises NameError, so fall back to str.
    try:
        _STRING_TYPES = (basestring,)  # noqa: F821 (Python 2 only)
    except NameError:
        _STRING_TYPES = (str,)

    # Parsers applied to stop criteria whose values arrive as strings (for
    # example from the command line). Keys not listed here, and values that
    # are not strings, are kept unchanged.
    _STOP_CRITERIA_PARSERS = {
        # Only the exact string 'True' is truthy; any other string is False.
        'converge': lambda value: value == 'True',
        'train_threshold': float,
        'val_threshold': float,
        'lr': float,
    }

    @classmethod
    def _parse_stop_criteria(cls, raw):
        """Return a new stop-criteria dict with string values converted.

        Args:
            raw: anything ``dict()`` accepts, such as a mapping or an
                iterable of (key, value) pairs.

        Returns:
            A new dict. Values that are not strings are left as-is, so
            re-parsing already-converted criteria changes nothing. If no
            criteria are given, the result is ``{'converge': True}``.

        Raises:
            ValueError: if a float-valued criterion is not a valid number.
        """
        criteria = dict(raw)
        # Iterate over a snapshot so that rebinding values is safe on both
        # Python 2 and Python 3.
        for key, value in list(criteria.items()):
            parser = cls._STOP_CRITERIA_PARSERS.get(key)
            if parser is not None and isinstance(value, cls._STRING_TYPES):
                criteria[key] = parser(value)

        if not criteria:
            # Default is converge=True.
            criteria['converge'] = True
        return criteria

    @classmethod
    def _create_params(cls, inputs, version):
        """Initialize and check parameters.

        Args:
            inputs: mapping of raw inputs (e.g. ``vars(args)``). It must
                contain 'dataset_tfrecords', 'num_training', 'num_validation'
                and 'stop_criteria'. Every other key is handled by
                ``_set_or_default`` (from the base class), which presumably
                uses the default given below when the key is absent.
            version: params version recorded under the 'version' key.

        Returns:
            dict of populated params.
        """
        params = {
            'version': version,
            'dataset_tfrecords': inputs['dataset_tfrecords'],
            'num_training': inputs['num_training'],
            'num_validation': inputs['num_validation'],
            'stop_criteria': cls._parse_stop_criteria(inputs['stop_criteria']),
        }

        cls._set_or_default(params, inputs, 'num_directions', 20)
        cls._set_or_default(params, inputs, 'num_rolls', 20)
        cls._set_or_default(params, inputs, 'batch_size', 40)
        cls._set_or_default(params, inputs, 'optimizer', 'RMS')
        cls._set_or_default(params, inputs, 'learning_rate', 0.01)
        cls._set_or_default(params, inputs, 'nesterov', False)
        cls._set_or_default(params, inputs, 'towers', 1)
        cls._set_or_default(params, inputs, 'check_nans', False)
        cls._set_or_default(params, inputs, 'max_epochs', -1)
        # The defaults below depend on values resolved earlier, so keep
        # them after the params they reference.
        cls._set_or_default(params, inputs, 'rolls_per_pass',
                            params['num_rolls'])
        cls._set_or_default(params, inputs, 'val_dataset_tfrecords',
                            params['dataset_tfrecords'])
        cls._set_or_default(params, inputs, 'num_interleaved', 40)
        cls._set_or_default(params, inputs, 'prune_file_training', '')
        cls._set_or_default(params, inputs, 'keep_file_training', '')
        cls._set_or_default(params, inputs, 'prune_file_validation', '')
        cls._set_or_default(params, inputs, 'keep_file_validation', '')
        cls._set_or_default(params, inputs, 'shuffle_buffer', 10000)
        cls._set_or_default(params, inputs, 'seed',
                            random.randint(0, 10000000))
        cls._set_or_default(params, inputs, 'lr_decay_type', 'none')
        cls._set_or_default(params, inputs, 'lr_decay_steps', 100000)
        cls._set_or_default(params, inputs, 'lr_decay_rate', 0.96)
        cls._set_or_default(params, inputs, 'lr_decay_staircase', False)
        cls._set_or_default(params, inputs, 'loose', False)
        cls._set_or_default(params, inputs, 'seq_src', "")
        cls._set_or_default(params, inputs, 'keep_file_pairs_training', '')
        cls._set_or_default(params, inputs, 'keep_file_pairs_validation', '')
        return params

    def _get_creation_inputs(self):
        """Get arguments used to create the param file from existing params.

        Only the required inputs of ``_create_params`` are returned. The
        optional params set via ``_set_or_default`` are not included.
        """
        return {
            'num_training': self.params['num_training'],
            'num_validation': self.params['num_validation'],
            'dataset_tfrecords': self.params['dataset_tfrecords'],
            # Materialize as a list of (key, value) pairs. On Python 3,
            # dict.items() is a live view rather than a list, and it is not
            # JSON serializable. _create_params rebuilds a dict from the
            # pairs either way.
            'stop_criteria': list(self.params['stop_criteria'].items()),
        }

    @classmethod
    def _curr_version(cls):
        """Current version of params."""
        return 1.0


def init_params(args, version=None):
    """Create a new TrainParams from an arguments object.

    Args:
        args: object whose attributes, read via ``vars``, become the raw
            inputs (e.g. an argparse Namespace).
        version: params version to create. It is forwarded unchanged to
            ``TrainParams.create``, which decides what ``None`` means.

    Returns:
        Whatever ``TrainParams.create`` returns.
    """
    raw_inputs = vars(args)
    return TrainParams.create(raw_inputs, version=version)


def load_params(param_json, new_version=None):
    """Load TrainParams from ``param_json``, updated to ``new_version``.

    This is a thin wrapper over ``TrainParams.load_updated``. How
    ``new_version=None`` is interpreted is decided there.
    """
    loaded = TrainParams.load_updated(param_json, new_version)
    return loaded
