
class AbsDirnModifier(object):
    '''
    Interface for any strategy that modifies the input descent direction (i.e. the
    negative gradient) using Newton-type information, e.g. plain Newton's method or
    a quasi-Newton method.

    Concrete subclasses supply ``get_name`` (their registry name) and override
    ``modify_direction``.
    '''

    def __init__(self, obj, ismx=False):
        '''
        :param obj:  the objective being optimized; stored as ``self.objective``.
        :param ismx: truthy when the objective is maximized instead of minimized.
        '''
        self.objective = obj
        self.is_min = not ismx
        # sign multiplier for the objective: +1.0 when minimizing, -1.0 when maximizing
        self.objmult = 1.0 if self.is_min else -1.0

    @staticmethod
    def get_name():
        '''Return the registry name of the modifier; concrete subclasses must provide it.'''
        raise NotImplementedError("this is an abstract class defining the interface.")

    def initialize(self):
        '''Optional setup hook; the base implementation does nothing.'''
        return None

    def modify_direction(self, descent_dirn):
        '''
        Modify ``descent_dirn``; must be overridden by subclasses.

        The concrete subclasses in this module return None and (for the Newton
        variant) hand ``descent_dirn`` to the objective as the output argument,
        i.e. the update is presumably done in place.
        '''
        raise NotImplementedError("this is an abstract class defining the interface.")


    
class NoModification(AbsDirnModifier):
    '''
    Direction modifier that leaves the descent direction untouched.
    '''

    _name = 'none'

    # staticmethod (matching the base class) so the name can be queried both on the
    # class -- also under Python 2, where an undecorated method is "unbound" -- and
    # on an instance.
    @staticmethod
    def get_name():
        '''Return the registry name of this modifier, ``'none'``.'''
        return NoModification._name

    def __init__(self, objective, ismx=False, argreg=None, argdict=None):
        '''
        :param objective: the objective being optimized.
        :param ismx:      truthy when the objective is maximized.
        :param argreg:    argument registry; unused, this modifier has no options.
        :param argdict:   parsed argument values; unused for the same reason.
        '''
        super(NoModification, self).__init__(objective, ismx)

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register this modifier's options in ``arg_reg``; it has none.'''
        pass

    def modify_direction(self, descent_dirn):
        '''Leave ``descent_dirn`` unchanged.'''
        pass


class NewtonDirection(AbsDirnModifier):
    '''
    Direction modifier that turns the descent direction into a Newton direction by
    applying the objective's inverse Hessian to it.
    '''

    _name = 'newton'

    # staticmethod (matching the base class) so the name can be queried both on the
    # class -- also under Python 2, where an undecorated method is "unbound" -- and
    # on an instance.
    @staticmethod
    def get_name():
        '''Return the registry name of this modifier, ``'newton'``.'''
        return NewtonDirection._name

    def __init__(self, objective, ismx=False, argreg=None, argdict=None):
        '''
        :param objective: the objective being optimized; must provide
                          ``get_hessian_inverse_vector_prod``.
        :param ismx:      truthy when the objective is maximized.
        :param argreg:    argument registry; unused, this modifier has no options.
        :param argdict:   parsed argument values; unused for the same reason.
        '''
        super(NewtonDirection, self).__init__(objective, ismx)

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register this modifier's options in ``arg_reg``; it has none.'''
        pass

    def modify_direction(self, descent_dirn):
        r'''
        Replace ``descent_dirn`` by the inverse Hessian applied to the sign-adjusted
        direction ``objmult * descent_dirn``.

        Note: no matter what the direction of optimization is, Newton's method
        iterations are ALWAYS x_{k+1} = x_k - (f'(x_k) / f''(x_k)) to find the
        root of f'. So the sign of the gradient (which the descent direction is
        presumed to be) is changed when maximizing.

        ``descent_dirn`` is passed as the second argument of
        ``get_hessian_inverse_vector_prod``, presumably the output buffer that is
        overwritten in place -- confirm against the objective's implementation.
        '''
        # objmult * descent_dirn builds a separate input, so the output may safely
        # alias the original direction.
        self.objective.get_hessian_inverse_vector_prod(self.objmult * descent_dirn, descent_dirn)
    


# Direction-modifier classes that are always available; the registry and factory
# functions below search these first, then any caller-supplied extra classes.
ModClassList = [NoModification, NewtonDirection]
# base name handed to ArgumentRegistry for this module's arguments
_arg_reg_base = 'dirnmodifier'
# name of the argument that selects which modifier class to instantiate
_argname_type='type'

from util_py.arg_parsing  import ArgumentRegistry

def get_dirnmodifier_arg_registry(extra_classes=None):
    '''
    Build the ArgumentRegistry describing the direction-modifier options.

    Registers a string argument ``type`` (under the ``dirnmodifier`` registry base)
    whose allowed values are the names of every class in ModClassList followed by
    those in ``extra_classes``; its default is the first of these (currently
    ``'none'``). Each class then registers its own options via
    ``fill_args_registry``.

    :param extra_classes: optional iterable of additional modifier classes, each
                          exposing ``get_name()`` and ``fill_args_registry(reg)``.
    :returns: the populated ArgumentRegistry.
    '''
    # Materialize once: the classes are walked twice below, which would silently
    # drop the extras if a one-shot iterator (e.g. a generator) were passed.
    all_classes = list(ModClassList)
    if extra_classes is not None:
        all_classes.extend(extra_classes)

    type_choices = [c.get_name() for c in all_classes]

    arg_reg = ArgumentRegistry(_arg_reg_base)
    arg_reg.register_str_arg(_argname_type,
                             'which direction modifier type to use',
                             type_choices[0],
                             type_choices)

    for cl in all_classes:
        cl.fill_args_registry(arg_reg)

    return arg_reg


def instantiate_dirnmodifier(objective, arg_dict, addlcls=None, ismx=False):
    '''
    Create the direction modifier selected in ``arg_dict``.

    :param objective: the objective passed to the modifier's constructor.
    :param arg_dict:  parsed argument values keyed by full argument name; the
                      selected modifier name is read from the registry's
                      ``type`` argument.
    :param addlcls:   optional iterable of extra modifier classes, searched after
                      those in ModClassList.
    :param ismx:      truthy when the objective is maximized (defaults to
                      minimization).
    :returns: an instance of the first class whose ``get_name()`` matches.
    :raises ValueError: if no known class has the selected name.
    '''
    extra = [] if addlcls is None else list(addlcls)
    arg_registry = get_dirnmodifier_arg_registry(extra)

    # read in the args
    modnm = arg_dict[arg_registry.get_arg_fullname(_argname_type)]

    for c in list(ModClassList) + extra:
        if modnm == c.get_name():
            # Pass by keyword: the constructors' second positional parameter is
            # ``ismx``, so passing the registry positionally would (for a truthy
            # registry object) mark every problem as a maximization and flip the
            # sign of the Newton step.
            return c(objective, ismx=ismx, argreg=arg_registry, argdict=arg_dict)

    raise ValueError('have not implemented direction modifier of type \'{}\' yet.'.format(modnm))
        
