"""
Versioned parameter module.

To update params, do the following...

1) Increment the current version number (the value returned by _curr_version).
2) In _create_params, add the new params behind a version flag.
3) In _attempt_patch, add the appropriate patch behind a version flag.
4) In file test_interaction_params, add the appropriate fixture + test.
5) In docstring below, document what the updated params are.
"""

import abc
import logging
import os

import json


class VersionedParams(object):
    """
    Instance of parameters with a specific version.

    Subclasses supply the version-specific behaviour by overriding
    ``_create_params``, ``_curr_version``, ``_get_creation_inputs`` and the
    ``_attempt_patch`` / ``_attempt_update`` / ``_attempt_deletion`` hooks.

    Note: the hooks are decorated with ``abc.abstractmethod``, but because
    this class does not use ``abc.ABCMeta`` that is not enforced; the base
    implementations act as conservative defaults (every hook refuses).
    """

    def __init__(self, params, version):
        """
        Initializer.

        Args:
            params (dict string -> object):
                dictionary mapping from parameter key to value
            version (float):
                version of param file to create.
        """
        self.params = params
        self.version = version

    @classmethod
    def create(cls, inputs, version=None):
        """
        Create a new param file, or reconcile with an already existing one.

        Args:
            inputs (dict): creation inputs handed to ``_create_params``.  If
                it has a ``'param_json'`` entry that path is used, otherwise
                the path is read from the created params.
            version: version to create; defaults to ``_curr_version()``.

        Returns:
            VersionedParams or None: the newly created params, or -- when a
            param file already existed -- the loaded params strictly updated
            towards the new ones (the update is not written back to disk
            here).  None if ``_create_params`` fails.

        Raises:
            RuntimeError: if an existing param file cannot be reconciled.
        """
        if version is None:
            version = cls._curr_version()
        params = cls._create_params(inputs, version)
        if params is None:
            logging.error('Failed to create params!')
            return None

        vp = cls(params, version)

        if 'param_json' in inputs:
            param_json = inputs['param_json']
        else:
            param_json = params['param_json']
        out_dir = os.path.dirname(param_json)
        # A bare filename has an empty dirname: nothing to create then.
        if out_dir and not os.path.exists(out_dir):
            os.makedirs(out_dir)

        if not os.path.exists(param_json):
            with open(param_json, 'w') as f:
                json.dump(params, f)
            return vp

        # If there is already a param file present, we attempt a strict
        # update of the existing params towards the newly created ones.
        existing_params = cls.load(param_json)
        if not existing_params._update_params(vp, True):
            return None
        return existing_params

    @classmethod
    def load_updated(cls, param_json, new_version=None):
        """
        Load a param file and update it to ``new_version``.

        ``new_version`` defaults to ``_curr_version()``; see ``load``.
        """
        if new_version is None:
            new_version = cls._curr_version()
        return cls.load(param_json, new_version)

    @classmethod
    def load(cls, param_json, new_version=None):
        """
        Load an existing param file.

        Args:
            param_json (str): path of the JSON param file to read.
            new_version: if given, params are regenerated at this version from
                ``_get_creation_inputs()`` and the loaded params are updated
                towards them non-strictly (a value change the hooks refuse is
                resolved by keeping the loaded value).

        Returns:
            VersionedParams or None: the loaded (possibly updated) params, or
            None if regenerating params at ``new_version`` fails.

        Raises:
            RuntimeError: if the loaded params cannot be updated.
        """
        with open(param_json) as f:
            params = json.load(f)
        existing_vp = cls(params, params['version'])

        if new_version is None:
            return existing_vp

        # Update params to provided version.
        inputs = existing_vp._get_creation_inputs()
        new_params = cls._create_params(inputs, version=new_version)
        if new_params is None:
            # Mirror create(): without this, _update_params would fail with
            # an AttributeError on None.
            logging.error(
                'Failed to create params for version {}!'.format(new_version))
            return None
        new_vp = cls(new_params, new_version)
        if not existing_vp._update_params(new_vp, False):
            return None
        return existing_vp

    @classmethod
    def _set_or_default(cls, params, inputs, key, default):
        """
        Set ``params[key]`` from ``inputs``, falling back to ``default``.

        The default is used when ``key`` is absent from ``inputs`` or maps to
        None.  Otherwise the input value must have exactly the same type as
        ``default``, except that any two string types are interchangeable.

        Raises:
            AssertionError: if the input value's type does not match.
        """
        if key not in inputs or inputs[key] is None:
            params[key] = default
            return

        value = inputs[key]
        try:
            string_types = basestring  # noqa: F821 -- Python 2 only
        except NameError:
            string_types = str  # Python 3
        both_strings = (isinstance(value, string_types) and
                        isinstance(default, string_types))
        if type(value) is not type(default) and not both_strings:
            logging.error("Types do not match {:} vs {:} for {:}".format(
                type(value), type(default), key))
            # Raise explicitly: a bare ``assert`` is stripped under ``-O``.
            raise AssertionError("Types do not match for {}".format(key))
        params[key] = value

    def save(self, param_json):
        """
        Save params to file, checking any previous file for compatibility.

        If ``param_json`` already exists, it is loaded with this object's own
        class and strictly updated towards these params before being
        overwritten.

        Raises:
            RuntimeError: if the previous param file is incompatible.
        """
        if os.path.exists(param_json):
            # Load through the concrete subclass: loading via the base class
            # would bypass the subclass _attempt_* hooks and reject any change.
            existing_params = type(self).load(param_json)
            if not existing_params._update_params(self, True):
                logging.error("Previous param file does not match!")
                return
        with open(param_json, 'w') as f:
            json.dump(self.params, f)

    def __contains__(self, key):
        """If this param files contains the provided key."""
        return key in self.params

    def __getitem__(self, key):
        """Return item stored under given key."""
        return self.params[key]

    def __setitem__(self, key, value):
        """Set item under provided key."""
        self.params[key] = value

    def _update_params(self, new_vp, enforce_equality):
        """
        Update these (loaded) params in place towards ``new_vp``.

        If a parameter file was created via an older version of this code,
        there might be missing values that we patch here.  The ``'version'``
        key is skipped in the first two passes:

        1. keys only in ``new_vp``: ``_attempt_patch`` must fill them in;
        2. keys in both with differing values: ``_attempt_update`` may accept
           the change; otherwise it is fatal when ``enforce_equality`` is set,
           else the existing value is copied back into ``new_vp``;
        3. keys no longer in ``new_vp``: ``_attempt_deletion`` must accept.

        Args:
            new_vp (VersionedParams): params to update towards (may be
                modified in pass 2).
            enforce_equality (bool): treat refused value changes as fatal.

        Returns:
            bool: True on success; the version is then taken from ``new_vp``.

        Raises:
            RuntimeError: if any key could not be reconciled.
        """
        fatal = False
        if self.params != new_vp.params:
            # Pass 1: keys missing from these params must be patched in.
            for key in set(new_vp.params) | set(self.params):
                if key == 'version':
                    pass
                elif key not in self.params and key in new_vp.params:
                    if self._attempt_patch(key):
                        logging.info(
                            '{} was set to {}'
                            .format(key, self.params[key]))
                    else:
                        logging.error(
                            '{} was not set previously but is now set to {}'
                            .format(key, new_vp.params[key]))
                        fatal = True

            # Pass 2: keys present on both sides whose values differ.
            for key in set(new_vp.params) & set(self.params):
                if key == 'version':
                    pass
                elif self.params[key] != new_vp.params[key]:
                    old = self.params[key]
                    if self._attempt_update(key, new_vp):
                        logging.info(
                            '{} was set to {} but is now set to {}'
                            .format(key, old, new_vp.params[key]))
                    elif enforce_equality:
                        logging.error(
                            '{} was set to {} but is now set to {}'
                            .format(key, old, new_vp.params[key]))
                        fatal = True
                    else:
                        # Non-strict mode: keep the existing value.
                        new_vp.params[key] = self.params[key]

            # Pass 3: keys that new_vp no longer has must be deletable.
            for key in set(new_vp.params) | set(self.params):
                if key not in new_vp.params:
                    if self._attempt_deletion(key, new_vp):
                        logging.info('{} was deleted'.format(key))
                    else:
                        logging.error(
                            '{} was set previously to {} but is no longer set'
                            .format(key, self.params[key]))
                        fatal = True
        if fatal:
            logging.error('Fatal error in update parameters, cancelling run!')
            raise RuntimeError(
                'Fatal error in update parameters, cancelling run!')
        self.params['version'] = new_vp.params['version']
        self.version = new_vp.params['version']
        return True

    @abc.abstractmethod
    def _attempt_deletion(self, key, new_vp):
        """
        Attempt to accept removal of ``key`` (absent from ``new_vp``).

        Returns True to accept. The base implementation refuses.
        """
        return False

    @abc.abstractmethod
    def _attempt_update(self, key, new_vp):
        """
        Check a change of ``key``'s value to ``new_vp.params[key]``.

        Returns True to accept; the override is expected to apply the change
        to ``self.params``. The base implementation refuses.
        """
        return False

    @abc.abstractmethod
    def _attempt_patch(self, key):
        """
        Attempt to fill in the missing ``key`` in params.

        On success the override must set ``self.params[key]`` (it is read for
        logging) and return True. The base implementation refuses.
        """
        return False

    @abc.abstractmethod
    def _get_creation_inputs(self):
        """
        Get arguments used to create the param file from existing params.

        Used by ``load`` to regenerate params at a new version. The base
        implementation returns None.
        """
        pass

    @classmethod
    def _create_params(cls, inputs, version):
        """Initialize and check parameters; return None on failure."""
        return None

    @classmethod
    def _curr_version(cls):
        """Get latest version."""
        return None
