"""
Model parameters for interaction prediction model.

To update params, do the following:

1) Bump the current version number returned by _curr_version.
2) In _create_params, add the new params behind a version flag.
3) In _attempt_patch, add the appropriate patch behind a version flag.
4) In the version history below, document what the updated params are.

== 1.0 update ==

Initial release version.

"""
import src.util.versioned_params as vp


class ModelParams(vp.VersionedParams):
    """Interaction prediction model params.

    Version 1.0 is the only version defined so far, so the deletion, patch
    and update hooks below are no-ops that always return False.
    """

    def _attempt_deletion(self, key, new_vp):
        """Attempt to delete the provided key in params.

        Args:
            key: Name of the param being considered for deletion.
            new_vp: The params object being migrated to.
                NOTE(review): presumably another VersionedParams instance --
                confirm against the base class.

        Returns:
            Always False; no deletions are defined yet.
            NOTE(review): presumably False tells the base class the key was
            not handled -- confirm.
        """
        return False

    def _attempt_patch(self, key):
        """Attempt to patch the provided key in params.

        Args:
            key: Name of the param being considered for patching.

        Returns:
            Always False; no patches are defined yet. New patches go here,
            behind a version flag.
        """
        return False

    def _attempt_update(self, key, new_vp):
        """Check whether the provided key in params can be updated.

        Args:
            key: Name of the param being considered for an update.
            new_vp: The params object being migrated to.

        Returns:
            Always False; no updates are defined yet.
        """
        return False

    @classmethod
    def _create_params(cls, inputs, version):
        """Initialize and check parameters.

        Args:
            inputs: Mapping of user-supplied values.
                NOTE(review): presumably the mapping handed to ``create``
                (``init_params`` passes ``vars(args)``) -- confirm.
            version: Version number stored under the ``'version'`` key.

        Returns:
            dict of model params, including ``'version'``.

        NOTE(review): ``_set_or_default`` is defined on the base class;
        presumably it stores ``inputs[key]`` in ``params`` when present and
        the given default otherwise -- confirm.
        """
        params = {'version': version}
        # Hard-coded: never read from ``inputs``, so callers cannot override
        # it. (Reason for the value 4 is not visible here -- TODO confirm.)
        params['channel_size'] = 4
        cls._set_or_default(params, inputs, 'conv_size', 3)
        cls._set_or_default(params, inputs, 'conv_stride', 1)
        cls._set_or_default(params, inputs, 'batch_norm', False)
        cls._set_or_default(params, inputs, 'dropout', False)
        cls._set_or_default(params, inputs, 'radius_ang', 4)
        cls._set_or_default(params, inputs, 'dense_layers', 2)
        cls._set_or_default(params, inputs, 'border_mode_valid', True)
        cls._set_or_default(params, inputs, 'num_filters', [32])
        # The number of conv layers is taken from num_filters; the per-layer
        # max-pool lists below default to one entry per conv layer.
        num_convs = len(params['num_filters'])
        cls._set_or_default(params, inputs, 'resolution', 1.0)
        cls._set_or_default(params, inputs, 'max_pool_positions',
                            [-1] * num_convs)
        cls._set_or_default(params, inputs, 'max_pool_sizes',
                            [2] * num_convs)
        cls._set_or_default(params, inputs, 'max_pool_strides',
                            [2] * num_convs)
        cls._set_or_default(params, inputs, 'tower_fc_nodes', [256])
        # By default, one 512-node entry per dense layer.
        cls._set_or_default(params, inputs, 'top_fc_nodes',
                            [512] * params['dense_layers'])
        return params

    def _get_creation_inputs(self):
        """Get arguments used to create the param file from existing params.

        Returns:
            An empty dict; no creation inputs are recovered from the
            existing params.
        """
        return {}

    @classmethod
    def _curr_version(cls):
        """Return the current version of params (bump when adding one)."""
        return 1.0


def init_params(args, version=None):
    """Create new model params from an argument object.

    Args:
        args: Object whose attributes supply the param inputs. It is
            converted with ``vars()``, so it must have a ``__dict__``
            (e.g. an ``argparse.Namespace``).
        version: Param version to create, forwarded to
            ``ModelParams.create``.
            NOTE(review): presumably None selects the current version --
            confirm against VersionedParams.create.

    Returns:
        Whatever ``ModelParams.create`` returns (presumably a ModelParams
        instance -- confirm).
    """
    creation_inputs = vars(args)
    return ModelParams.create(creation_inputs, version=version)


def load_params(param_json, new_version=None):
    """Load saved model params via ``ModelParams.load_updated``.

    Args:
        param_json: Saved params, handed straight to ``load_updated``.
            NOTE(review): not visible here whether this is a file path, a
            JSON string or a parsed dict -- confirm against VersionedParams.
        new_version: Target version, forwarded positionally to
            ``load_updated``.
            NOTE(review): presumably None means the current version --
            confirm.

    Returns:
        The result of ``ModelParams.load_updated``.
    """
    loaded = ModelParams.load_updated(param_json, new_version)
    return loaded
