from .abs_optim import AbsOptim
from .sampling import get_sampling_rate_arg_registry
from .momentum import get_momentum_arg_registry
from .learning_rate import get_learning_rate_arg_registry


import logging

# Module-level logger shared by the SGD optimizer; `step` uses it for
# (guarded) per-iteration debug output.
_LOGGER = logging.getLogger('optim.sgd')



class StochasticGradientDescent(AbsOptim):
    '''
    Minibatch stochastic gradient descent built on top of ``AbsOptim``.

    Each call to :meth:`step` asks the sampling-rate object for a minibatch
    size, draws that many samples from the sampler, evaluates the objective
    (value and derivatives) on them, asks the learning-rate object for a
    step length and lets the momentum object take the step.
    '''

    @staticmethod
    def get_arg_registries(stct_xtra=None, lr_xtra=None, smprt_xtra=None,
                           mom_xtra=None):
        '''
        Get a list of all arg registries needed by SGD.

        :param stct_xtra: forwarded to ``AbsOptim.get_arg_registries``.
        :param lr_xtra: forwarded to ``get_learning_rate_arg_registry``.
        :param smprt_xtra: forwarded to ``get_sampling_rate_arg_registry``.
        :param mom_xtra: forwarded to ``get_momentum_arg_registry``.
        :return: the list returned by the base class, extended (in place)
            with the learning-rate, sampling-rate and momentum registries,
            in that order.
        '''
        # Explicit two-argument super: a staticmethod has no implicit
        # class/instance to bind, so zero-argument super() is unavailable.
        arg_list = super(StochasticGradientDescent,
                         StochasticGradientDescent).get_arg_registries(stct_xtra)

        # NOTE(review): this appends to whatever list the base class returns;
        # confirm the base class hands back a fresh list on every call.

        # unified arg reg for learning rates
        arg_list.append(get_learning_rate_arg_registry(lr_xtra))

        # unified arg reg for sampling rates
        arg_list.append(get_sampling_rate_arg_registry(smprt_xtra))

        # unified arg reg for momentum
        arg_list.append(get_momentum_arg_registry(mom_xtra))

        return arg_list

    def __init__(self, objective, stopcrit, algstate,
                 sampler, samrat, lrate, mom,
                 arg_dict=None, ismx=False):
        '''
        :param objective: passed to ``AbsOptim``; ``step`` later calls
            ``set_samples`` and ``evaluate_fn_and_derivatives`` on
            ``self.objective`` (presumably this object — TODO confirm).
        :param stopcrit: passed to ``AbsOptim``.
        :param algstate: passed to ``AbsOptim``.
        :param sampler: draws minibatches via ``next_sample(n)``.
        :param samrat: picks the minibatch size via ``next_sample_size``.
        :param lrate: picks the step length via ``get_stepsize``.
        :param mom: momentum object; performs the step via ``take_step``.
        :param arg_dict: passed to ``AbsOptim``.
        :param ismx: passed to ``AbsOptim`` (presumably "is maximization"
            — TODO confirm).
        '''
        # Stored *before* the base-class constructor on purpose:
        # set_minimization() below dereferences the momentum object, and the
        # base __init__ may call it (NOTE(review): confirm against AbsOptim).
        self.__momentum = mom

        super(StochasticGradientDescent, self).__init__(
            objective, stopcrit, algstate, arg_dict, ismx)

        self.__lrate = lrate
        self.__sampler = sampler
        self.__sampling_rate = samrat

    def set_minimization(self, ismin):
        '''
        Set minimization mode on the base optimizer and on the momentum
        object, so both agree on the direction of optimization.
        '''
        super(StochasticGradientDescent, self).set_minimization(ismin)
        self.__momentum.set_minimization(ismin)

    def initialize(self):
        '''
        Initialize the base optimizer, then the learning-rate, momentum,
        sampler and sampling-rate objects, in that order.
        '''
        super(StochasticGradientDescent, self).initialize()

        self.__lrate.initialize()
        self.__momentum.initialize()
        self.__sampler.initialize()
        self.__sampling_rate.initialize()

    def step(self):
        '''
        Perform one SGD iteration on a freshly drawn minibatch.

        The minibatch is stored in ``self.curr_samples`` and handed to the
        objective, whose value and derivatives are evaluated in a single
        call. The learning-rate object supplies the step length, the
        momentum object takes the step, and the iteration is reported via
        ``self.algo_state.output_current``.

        :return: the objective value on the minibatch (traditionally the
            training loss).
        '''
        # get next sample: first its size, then the samples themselves
        nsamp = self.__sampling_rate.next_sample_size(self.algo_state)
        self.curr_samples = self.__sampler.next_sample(nsamp)

        # Guard avoids building the message when DEBUG logging is off.
        if _LOGGER.isEnabledFor(logging.DEBUG):
            _LOGGER.debug("Itr {:4d} minibatch size {:4d} out of {}".format(
                self.algo_state.n_itr.value, nsamp, self.curr_samples.shape))

        # pass this on to the objective
        self.objective.set_samples(self.curr_samples, nsamp)

        # evaluate the objective value and compute the gradient in one call
        objval = self.objective.evaluate_fn_and_derivatives()

        # calculate the next steplength / learning rate. Note that this gets
        # computed using the previous value of the iterate count!
        steplength = self.__lrate.get_stepsize(self.algo_state)

        # Only the direction returned by take_step is reported; its second
        # return value is not needed here.
        newdirn, _ = self.__momentum.take_step(steplength)

        self.algo_state.output_current(nsamp, objval, steplength, newdirn)

        return objval

  
