from util_py.arg_parsing import ArgumentRegistry
from math import ceil
import logging

# Logger shared by every stopping criterion in this module.
_LOGGER = logging.getLogger('optim.stopctrn')

class AbsStoppingCriterion(object):
    '''
    Base class for every stopping criterion; it only fixes the interface.

    Subclasses provide a static ``get_name`` (the identifier used to pick the
    criterion) and ``should_stop(algodata)``, which returns True once the
    algorithm ought to terminate. A criterion assumes a minimization problem
    until ``is_min(False)`` is called.
    '''

    def __init__(self):
        # every criterion starts out assuming minimization
        self.is_minimization = True

    @staticmethod
    def get_name():
        raise NotImplementedError("this is an abstract class defining the interface.")

    def should_stop(self, algodata):
        raise NotImplementedError("this is an abstract class defining the interface.")

    def is_min(self, ismin):
        '''Record whether the objective is minimized (True) or maximized (False).'''
        self.is_minimization = ismin

    def initialize(self):
        '''Reset per-run state; the base class keeps none, so this does nothing.'''
        return None
    

    
class StopMaxIterations(AbsStoppingCriterion):
    '''
    Stop the algorithm once the iteration count exceeds a threshold.

    The threshold is read from the 'maxiters_threshold' argument when both
    ``arg_registry`` and ``arg_dict`` are given; otherwise ``thd`` is used.
    '''
    stctname = 'maxiters'
    argname = "{}_{}".format(stctname, "threshold")
    argdefval = 100

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register the iteration-threshold argument on ``arg_reg``.'''
        arg_reg.register_int_arg(StopMaxIterations.argname,
                                 'max iterations if we are to stop at this iter',
                                 StopMaxIterations.argdefval)

    @staticmethod
    def get_name():
        return StopMaxIterations.stctname

    def __init__(self, arg_registry=None, arg_dict=None, thd=100):
        '''
        arg_registry, arg_dict: registry and parsed-argument dict; used only
            when both are given.
        thd: fallback iteration threshold.
        '''
        if (arg_dict is None) or (arg_registry is None):
            self.max_iters = thd
        else:
            self.max_iters = arg_dict[arg_registry.get_arg_fullname(self.argname)]

        super(StopMaxIterations, self).__init__()

    def should_stop(self, algodata):
        '''Return True when ``algodata.n_itr.value`` is strictly above the threshold.'''
        retval = (algodata.n_itr.value > self.max_iters)
        if retval:
            if _LOGGER.isEnabledFor(logging.INFO):
                # '>10' (not '10d') so a float threshold cannot break the log call
                _LOGGER.info("stopping cuz num iter {:>10} > threshold {:>10}".format(
                    algodata.n_itr.value, self.max_iters))
        else:
            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug("continuing cuz num iter {:>10} <= threshold {:>10}".format(
                    algodata.n_itr.value, self.max_iters))
        return retval
    
    
class StopMaxTime(AbsStoppingCriterion):
    '''
    Stop the algorithm once the cumulative time (``algodata.c_time``) exceeds
    a threshold.

    The threshold is read from the 'maxtime_threshold' argument when both
    ``arg_registry`` and ``arg_dict`` are given; otherwise ``thd`` is used.
    '''
    stctname = 'maxtime'
    argname = "{}_{}".format(stctname, "threshold")
    argdefval = 10000.0

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register the time-threshold argument on ``arg_reg``.'''
        arg_reg.register_float_arg(StopMaxTime.argname, 'max cputime to run for',
                                   StopMaxTime.argdefval)

    # NOTE(review): the keyword default ``thd`` (1e5) differs from the registry
    # default ``argdefval`` (1e4); kept unchanged for backward compatibility.
    def __init__(self, arg_registry=None, arg_dict=None, thd=1.e5):
        '''
        arg_registry, arg_dict: registry and parsed-argument dict; used only
            when both are given (both now default to None, like the sibling
            criteria).
        thd: fallback time threshold.
        '''
        if (arg_dict is None) or (arg_registry is None):
            self.max_cputime = thd
        else:
            self.max_cputime = arg_dict[arg_registry.get_arg_fullname(self.argname)]

        super(StopMaxTime, self).__init__()

    @staticmethod
    def get_name():
        return StopMaxTime.stctname

    def should_stop(self, algodata):
        '''Return True when ``algodata.c_time.value`` is strictly above the threshold.'''
        retval = (algodata.c_time.value > self.max_cputime)
        if retval:
            if _LOGGER.isEnabledFor(logging.INFO):
                _LOGGER.info("stopping cuz cum time {:7.2e} > threshold {:7.2e}".format(
                    algodata.c_time.value, self.max_cputime))
        else:
            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug("continuing cuz cum time {:7.2e} <= threshold {:7.2e}".format(
                    algodata.c_time.value, self.max_cputime))
        return retval

    
class StopMaxCumSamples(AbsStoppingCriterion):
    '''
    Stop the algorithm once the cumulative sample count (``algodata.c_samp``)
    exceeds a budget.

    The budget is read from the 'maxcumsamples_budget' argument when both
    ``arg_registry`` and ``arg_dict`` are given; otherwise ``budge`` is used.
    '''
    stctname = 'maxcumsamples'

    @staticmethod
    def get_name():
        return StopMaxCumSamples.stctname

    argname = "{}_{}".format(stctname, "budget")
    argdefval = 1e7

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register the sample-budget argument (a float) on ``arg_reg``.'''
        arg_reg.register_float_arg(StopMaxCumSamples.argname, 'max number of total samples to run till',
                                   StopMaxCumSamples.argdefval)

    def __init__(self, arg_registry=None, arg_dict=None, budge=1e7):
        '''
        arg_registry, arg_dict: registry and parsed-argument dict; used only
            when both are given (both now default to None, like the sibling
            criteria).
        budge: fallback sample budget.
        '''
        if (arg_dict is None) or (arg_registry is None):
            self.max_cumulsamples = budge
        else:
            self.max_cumulsamples = arg_dict[arg_registry.get_arg_fullname(self.argname)]

        super(StopMaxCumSamples, self).__init__()

    def should_stop(self, algodata):
        '''Return True when ``algodata.c_samp.value`` is strictly above the budget.'''
        retval = (algodata.c_samp.value > self.max_cumulsamples)
        # The budget is a float (default 1e7, registered as a float arg), so an
        # integer 'd' format would raise ValueError; '.0f' accepts ints and floats.
        if retval:
            if _LOGGER.isEnabledFor(logging.INFO):
                _LOGGER.info("stopping cuz cum samples {:10.0f} > threshold {:10.0f}".format(
                    algodata.c_samp.value, self.max_cumulsamples))
        else:
            if _LOGGER.isEnabledFor(logging.DEBUG):
                _LOGGER.debug("continuing cuz cum samples {:10.0f} <= threshold {:10.0f}".format(
                    algodata.c_samp.value, self.max_cumulsamples))
        return retval

    
    

from collections import deque
        
class StopMinObjRelChange(AbsStoppingCriterion):
    '''
    Stop when the running average of the objective value does not improve as
    fast as we'd wish.

    The most recent ``buffersiz`` objective values are kept in two FIFO
    windows, ordered oldest to newest as { trailing (history) | leading (current) }.
    Once ``buffersiz`` values have been recorded, the rule compares the window
    averages and stops if:
        (minimization)   curr_avg + add_tol > hist_avg - mult_tol*|hist_avg|
        (maximization)   curr_avg - add_tol < hist_avg + mult_tol*|hist_avg|

    For a non-negative hist_avg these read curr_avg + add_tol > (1-mult_tol)*hist_avg
    and curr_avg - add_tol < (1+mult_tol)*hist_avg. 'add_tol' and 'mult_tol'
    are user provided. The idea is that if the current average did not get
    much better than the recent history, it is time to stop.

    Iterations with ``n_itr`` below ``initial_skip`` are ignored entirely.
    '''
    stctname = 'relchange'

    @staticmethod
    def get_name():
        return StopMinObjRelChange.stctname

    argname_histsize = "{}_{}".format(stctname, "histsize")
    argdefval_histsize = 100
    argname_histplit = "{}_{}".format(stctname, "histsplit")
    argdefval_histsplit = 0.8
    argname_multtol = "{}_{}".format(stctname, "multtol")
    argdefval_multtol = 0.001
    argname_addtol = "{}_{}".format(stctname, "addtol")
    argdefval_addtol = 1e-9

    argname_skipinitial, argdefval_skipinitial = "{}_{}".format(stctname, 'skip_initial'), 0

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register history size, tolerances, split fraction and initial skip on ``arg_reg``.'''
        arg_reg.register_int_arg(StopMinObjRelChange.argname_histsize,
                                 'size of history to compare current test loss against',
                                 StopMinObjRelChange.argdefval_histsize)

        arg_reg.register_float_arg(StopMinObjRelChange.argname_multtol,
                                   'threshold of min fraction improvement expected',
                                   StopMinObjRelChange.argdefval_multtol)

        arg_reg.register_float_arg(StopMinObjRelChange.argname_addtol,
                                   'threshold of min additive improvement expected',
                                   StopMinObjRelChange.argdefval_addtol)

        arg_reg.register_float_arg(StopMinObjRelChange.argname_histplit,
                                   'fraction of trailing history that needs to be improved by leading history (0,1)',
                                   StopMinObjRelChange.argdefval_histsplit)

        arg_reg.register_int_arg(StopMinObjRelChange.argname_skipinitial,
                                 'how many initial iterations to skip before applying rule',
                                 StopMinObjRelChange.argdefval_skipinitial)

    def __init__(self, arg_registry=None, arg_dict=None, hsz=100, hspl=0.8, initskip=0,
                 mtol=1e-3, atol=1e-9):
        '''
        arg_registry, arg_dict: registry and parsed-argument dict; used only
            when both are given, otherwise the keyword values below apply.
        hsz: total history size (trailing + leading), must be >= 2.
        hspl: fraction of the history assigned to the trailing window,
            clamped to [0.001, 0.999].
        initskip: iterations with n_itr below this are ignored.
        mtol, atol: multiplicative and additive tolerances.

        Raises ValueError if the history size is below 2. With fewer values
        one window would be empty and should_stop would divide by zero.
        '''
        if (arg_dict is None) or (arg_registry is None):
            self.buffersiz = hsz
            self.trailfrac = hspl
            self.mult_tol_frac = mtol
            self.add_tol = atol
            self.initial_skip = initskip
        else:
            self.buffersiz = arg_dict[arg_registry.get_arg_fullname(
                self.argname_histsize)]
            self.trailfrac = arg_dict[arg_registry.get_arg_fullname(
                self.argname_histplit)]
            self.mult_tol_frac = arg_dict[arg_registry.get_arg_fullname(
                self.argname_multtol)]
            self.add_tol = arg_dict[arg_registry.get_arg_fullname(
                self.argname_addtol)]
            self.initial_skip = arg_dict[arg_registry.get_arg_fullname(
                self.argname_skipinitial)]

        if self.buffersiz < 2:
            raise ValueError(
                "relchange stopping criterion needs a history size of at least 2, got {}".format(
                    self.buffersiz))

        self.trailfrac = min(max(self.trailfrac, 0.001), 0.999)

        # Round the smaller window up; with buffersiz >= 2 and the clamp above,
        # both windows end up holding at least one value.
        if self.trailfrac < 0.5:
            self.trailbuffsiz = int(ceil(self.buffersiz * self.trailfrac))
            self.leadbuffsiz = self.buffersiz - self.trailbuffsiz
        else:
            self.leadbuffsiz = int(ceil(self.buffersiz * (1.0 - self.trailfrac)))
            self.trailbuffsiz = self.buffersiz - self.leadbuffsiz

        # windows start zero-padded; running sums mirror their contents
        self.history_trailing = deque([0.0] * self.trailbuffsiz)
        self.hist_trailing_sum = 0.0

        self.history_leading = deque([0.0] * self.leadbuffsiz)
        self.hist_leading_sum = 0.0

        self.num_appended = 0

        super(StopMinObjRelChange, self).__init__()
        # +1 for minimization, -1 for maximization (see is_min)
        self.optmult = 1.0

        self.add_tol_adj = self.optmult * self.add_tol

        self.shd_initialize = True

    def is_min(self, m):
        '''Set the optimization direction and update the sign-dependent tolerances.'''
        super(StopMinObjRelChange, self).is_min(m)
        # recompute both ways so switching back to minimization restores +1
        self.optmult = 1.0 if self.is_minimization else -1.0
        self.add_tol_adj = self.optmult * self.add_tol

    def initialize(self):
        '''Reset both windows to zero padding and forget all recorded values.'''
        if not self.shd_initialize:
            return

        # mutate in place so the deque objects themselves are kept
        self.history_trailing.clear()
        self.history_trailing.extend([0.0] * self.trailbuffsiz)
        self.hist_trailing_sum = 0.0

        self.history_leading.clear()
        self.history_leading.extend([0.0] * self.leadbuffsiz)
        self.hist_leading_sum = 0.0

        self.num_appended = 0

    def should_stop(self, algodata):
        '''
        Record ``algodata.objval.value`` and apply the relative-change rule.

        Returns False while iterations are being skipped or before
        ``buffersiz`` values have been recorded.
        '''
        if algodata.n_itr.value < self.initial_skip:
            return False

        lossvalu = algodata.objval.value

        # Shift the windows: the oldest 'leading' value moves into 'trailing'
        # (whose oldest value is dropped) and the new value joins 'leading'.
        # __init__ guarantees both windows are non-empty.
        xfr = self.history_leading.popleft()
        self.hist_trailing_sum += xfr - self.history_trailing.popleft()
        self.history_trailing.append(xfr)

        self.hist_leading_sum += lossvalu - xfr
        self.history_leading.append(lossvalu)
        self.num_appended += 1

        # until then the windows still contain zero padding
        if self.num_appended < self.buffersiz:
            return False

        leadavg = self.hist_leading_sum / self.leadbuffsiz
        histavg = self.hist_trailing_sum / self.trailbuffsiz

        # mult_tol * histavg equals histavg -/+ frac*|histavg| (min/max), so
        # "improvement" always means moving in the optimization direction
        if histavg < 0.0:
            mult_tol = (1. + self.optmult * self.mult_tol_frac)
        else:
            mult_tol = (1. - self.optmult * self.mult_tol_frac)

        rhs = leadavg + self.add_tol_adj - mult_tol * histavg
        retval = self.optmult * rhs > 0

        if retval:
            _LOGGER.info("itr {}: ({:2.0f}) * ({:7.4e} = [ (curravg {:7.4e} + addtol {:7.4e} - (multtol {:7.4e} * histavg {:7.4e}) ) ]) > 0".format(
                algodata.n_itr.value, self.optmult, rhs, leadavg, self.add_tol_adj, mult_tol, histavg))
        elif _LOGGER.isEnabledFor(logging.DEBUG):
            _LOGGER.debug("itr {}: ({:2.0f}) * ({:7.4e} = [ (curravg {:7.4e} + addtol {:7.4e} - (multtol {:7.4e} * histavg {:7.4e}) ) ]) <= 0,  curravg {} histavg {}".format(
                algodata.n_itr.value, self.optmult, rhs, leadavg, self.add_tol_adj, mult_tol, histavg, leadavg, histavg))

        return retval

class StopSet(AbsStoppingCriterion):
    '''
    Combine several stopping criteria.

    With type "any" the run stops when at least one member says so; with
    type "all" it stops only when every member agrees. Every member is
    evaluated on each call (no short-circuit), so stateful criteria see every
    iteration whatever the combination type. An example is the
    relative-change rule, which updates its history on each call.
    '''
    stctname = 'set'

    @staticmethod
    def get_name():
        return StopSet.stctname

    argname = "{}_{}".format(stctname, "criteria")
    argdefval = ""
    sep = '|'

    argnametype = "{}_{}".format(stctname, "type")
    argvaltypes = ("all", "any")
    argdefvaltype = argvaltypes[1]

    @staticmethod
    def fill_args_registry(arg_reg):
        '''Register the member-list and combination-type arguments on ``arg_reg``.'''
        arg_reg.register_str_arg(StopSet.argname,
                                 'any or all of the listed stopping criteria, separated by "{}"'.format(
                                     StopSet.sep),
                                 StopSet.argdefval)

        arg_reg.register_str_arg(StopSet.argnametype,
                                 'shd we satisfy one or all',
                                 StopSet.argdefvaltype,
                                 StopSet.argvaltypes)

    def __init__(self, arg_registry=None, arg_dict=None, clslst=None, crtslst=None, typ="any"):
        '''
        arg_registry, arg_dict: when both are given, members are built from the
            separator-delimited names in arg_dict and ``typ`` is read from it.
        clslst: candidate criterion classes to match names against; defaults
            to the module's StopCtrnClassList.
        crtslst: pre-built criteria, used when arguments are not given.
        typ: "any" to stop when one member agrees; anything else means "all".

        Raises ValueError if arguments are given but name no criteria, contain
        a nested set, or name an unknown criterion.
        '''
        arggivn = (arg_dict is not None) and (arg_registry is not None)
        if arggivn:
            crtslst = self._criteria_from_args(arg_registry, arg_dict, clslst)
            typ = arg_dict[arg_registry.get_arg_fullname(self.argnametype)]

        # keep the caller's list object, but never leave the attribute as None
        self.criteria = crtslst if crtslst is not None else []
        if arggivn and len(self.criteria) <= 0:
            raise ValueError("stopping criteria set did not get any criteria!")

        self.type_any = (typ == "any")

        super(StopSet, self).__init__()

    def _criteria_from_args(self, arg_registry, arg_dict, clslst):
        '''Instantiate the criteria named in the set's argument string.'''
        lstr = arg_dict[arg_registry.get_arg_fullname(self.argname)]
        if lstr is None or len(lstr.strip()) == 0:
            raise ValueError("set stopping criterion did not get any criteria!")

        if clslst is None:
            clslst = StopCtrnClassList

        crtslst = []
        for raw in lstr.strip().split(self.sep):
            name = raw.strip()
            if not name:
                continue
            if name == StopSet.get_name():
                raise ValueError("StopSet may not have another StopSet in it.")
            matches = [c for c in clslst if c.get_name() == name]
            if not matches:
                # a typo would otherwise silently drop a criterion from the set
                raise ValueError("unknown stopping criterion '{}' in set".format(name))
            crtslst.append(matches[0](arg_registry, arg_dict))
        return crtslst

    def is_min(self, m):
        '''Propagate the optimization direction to every member.'''
        super(StopSet, self).is_min(m)
        for c in self.criteria:
            c.is_min(m)

    def initialize(self):
        '''Reset every member criterion.'''
        for c in self.criteria:
            c.initialize()

    def add_criteria(self, crls):
        '''Append the criteria in ``crls`` to this set.'''
        _LOGGER.debug("Crit list is befor: {}".format(self.criteria))
        self.criteria.extend(crls)
        _LOGGER.debug("Crit list is after: {}".format(self.criteria))

    def should_stop(self, algodata):
        '''Evaluate every member, then combine the votes with any/all.'''
        votes = [c.should_stop(algodata) for c in self.criteria]
        return any(votes) if self.type_any else all(votes)


# Built-in stopping criteria, in the order they are offered as choices; the
# first entry's name is the default criterion type.
StopCtrnClassList = [
    StopMaxIterations,
    StopMaxTime,
    StopMinObjRelChange,
    StopSet,
    StopMaxCumSamples,
]

# Registry prefix and the name of the criterion-type argument.
absstctname = 'stopctrn'
argname_stct_type = 'type'
    

def get_stop_criterion_arg_registry(extra_classes=None):
    '''
    Build an ArgumentRegistry (prefix 'stopctrn') containing the
    criterion-type choice and the arguments of every criterion class.

    extra_classes: optional iterable of additional criterion classes; their
        names become valid type choices and their arguments are registered
        after the built-in ones.
    '''
    choices = [cls.get_name() for cls in StopCtrnClassList]
    if extra_classes is not None:
        choices.extend(cls.get_name() for cls in extra_classes)

    registry = ArgumentRegistry(absstctname)

    # the first built-in criterion is the default type
    registry.register_str_arg(argname_stct_type, 'which criterion to use',
                              choices[0], choices)

    for cls in StopCtrnClassList:
        cls.fill_args_registry(registry)

    if extra_classes is not None:
        for cls in extra_classes:
            cls.fill_args_registry(registry)

    return registry


    
def instantiate_stopping_criterion(arg_dict, addlcls=None):
    '''
    Create the stopping criterion selected by the criterion-type argument.

    arg_dict: parsed arguments keyed by the registry's full argument names.
    addlcls: optional additional criterion classes, tried after the built-ins.

    Raises ValueError if no known criterion has the requested name.
    '''
    arg_registry = get_stop_criterion_arg_registry(addlcls)

    # the type argument is registered under the module-level name
    # 'argname_stct_type' (AbsStoppingCriterion has no such attribute)
    stopnm = arg_dict[arg_registry.get_arg_fullname(argname_stct_type)]

    clslist = list(StopCtrnClassList)
    if addlcls is not None:
        clslist.extend(addlcls)

    if stopnm == StopSet.get_name():
        # a set builds its members from the same candidate classes
        return StopSet(arg_registry, arg_dict, clslist)

    for cls in clslist:
        if stopnm == cls.get_name():
            return cls(arg_registry, arg_dict)

    raise ValueError('have not implemented stop criterion \'{}\' yet.'.format(stopnm))
    

