import sys
import logging

class _AbsArg:
    def __init__(self,nm,hlpstr, defval):
        self.argname = nm
        self.help_str = hlpstr 
        self.default = defval
        
        
class _FloatArg(_AbsArg):
    """Argument spec that the registry exposes to argparse with ``type=float``."""

    def __init__(self,nm, helpstr, defval):
        super().__init__(nm, helpstr, defval)
        
class _IntArg(_AbsArg):
    """Argument spec that the registry exposes to argparse with ``type=int``."""

    def __init__(self,nm, helpstr, defval):
        super().__init__(nm, helpstr, defval)

class _StrArg(_AbsArg):
    """String argument spec, optionally restricted to a list of choices.

    Args:
        nm: argument key.
        helpstr: help text; when ``choices`` is given, ". Choices: [...]" is
            appended to it.
        defval: default value; must be an element of ``choices`` if given.
        choices: optional sequence supporting ``.index`` (list/tuple) of
            allowed values, or None for an unrestricted string.

    Raises:
        ValueError: ``defval`` is not in ``choices`` (a diagnostic is also
            printed to stderr first).
    """

    def __init__(self,nm, helpstr, defval, choices=None):
        super().__init__(nm, helpstr, defval)

        self.choices = None
        if choices is not None:
            # Validate the default up front; list.index raises ValueError when absent.
            try:
                choices.index(self.default)
            except ValueError as e:
                print('arg {}:-> default choice \'{}\' was not in list of valid choices {}: error {}'.format(nm, defval,choices,e), file=sys.stderr)
                # Bare raise keeps the original traceback intact.
                raise
            self.choices = choices
            self.help_str += ". Choices: {}".format(self.choices)


class _FlagArg(_AbsArg):
    """Boolean switch argument handled by an argparse ``action`` string.

    Rejects a truthy default combined with an action containing 'true'
    (e.g. 'store_true'), and a falsy default combined with an action
    containing 'false' (e.g. 'store_false').

    Raises:
        ValueError: the default/action combination is one of the above.
    """

    def __init__(self,nm, helpstr, defval, action):
        super().__init__(nm, helpstr, defval)
        self.action = action
        # The substring that conflicts with this particular default.
        forbidden = 'true' if self.default else 'false'
        if self.action.find(forbidden) != -1:
            raise ValueError('flag argument {} with default of \'{}\' should not also have action \'{}\' '.format(
                    self.argname, self.default, self.action))

class ArgReInsertionError(Exception):
    """Raised when a key is registered twice in the same ArgumentRegistry."""

    def __init__(self, message):
        super().__init__(message)

class ArgumentRegistry(object):
    """A named group of command-line argument specifications.

    Every key registered here is exposed to argparse as ``--<prefix>_<key>``
    (see :meth:`get_arg_fullname`), so arguments from different components
    do not collide.
    """

    def __init__(self, prefix):
        """Create an empty registry whose argument names start with ``prefix``."""
        self.__prefix = prefix
        # key (without prefix) -> _FloatArg / _IntArg / _StrArg / _FlagArg spec
        self.__registry = {}

    def _chk_input_nm(self, nm):
        """Validate ``nm`` before inserting it into the registry.

        Raises:
            ValueError: ``nm`` is None or empty/whitespace-only.
            ArgReInsertionError: ``nm`` is already registered.
        """
        if nm is None or not nm.strip():
            # The message now actually includes the offending key.
            raise ValueError('key \'{}\' must be a non-empty string.'.format(nm))
        if nm in self.__registry:
            raise ArgReInsertionError('registry with prefix \'{}\' already has a key \'{}\', you tried pushing one in again.'.format(
                    self.__prefix, nm))

    def get_classname(self):
        """Return the prefix this registry was created with."""
        return self.__prefix

    def register_float_arg(self, nm, helpstr, defval):
        """Register a float-typed argument ``nm`` with help text and default."""
        self._chk_input_nm(nm)
        self.__registry[nm] = _FloatArg(nm, helpstr, defval)

    def register_int_arg(self, nm, helpstr, defval):
        """Register an int-typed argument ``nm`` with help text and default."""
        self._chk_input_nm(nm)
        self.__registry[nm] = _IntArg(nm, helpstr, defval)

    def register_str_arg(self, nm, helpstr, defval, choices=None):
        """Register a string argument ``nm``, optionally limited to ``choices``."""
        self._chk_input_nm(nm)
        self.__registry[nm] = _StrArg(nm, helpstr, defval, choices)

    def register_flag_arg(self, nm, helpstr, defval, action):
        """Register a flag argument ``nm`` handled by argparse ``action``."""
        self._chk_input_nm(nm)
        self.__registry[nm] = _FlagArg(nm, helpstr, defval, action)

    def get_arg_fullname(self, knm):
        """Return the prefixed name ``<prefix>_<knm>`` (without leading dashes)."""
        return '{:s}_{:s}'.format(self.__prefix, knm)

    def add_args_to_parser(self, parser):
        """Add an ``--<prefix>_<key>`` option to ``parser`` for every registered key.

        Raises:
            ValueError: a registry entry has an unrecognised spec type.
        """
        for knm, vl in self.__registry.items():
            anm = '--' + self.get_arg_fullname(knm)

            if isinstance(vl, _FloatArg):
                parser.add_argument(anm, type=float, default=vl.default,
                                    help=vl.help_str)
            elif isinstance(vl, _IntArg):
                parser.add_argument(anm, type=int, default=vl.default,
                                    help=vl.help_str)
            elif isinstance(vl, _StrArg):
                if vl.choices is not None:
                    parser.add_argument(anm, type=str, default=vl.default,
                                        help=vl.help_str, choices=vl.choices)
                elif vl.default:
                    parser.add_argument(anm, type=str, default=vl.default,
                                        help=vl.help_str)
                else:
                    # NOTE(review): a falsy default (e.g. '') is not passed to
                    # argparse, so the parsed value is None rather than the
                    # registered default reported by get_arg_default_dict().
                    parser.add_argument(anm, type=str, help=vl.help_str)
            elif isinstance(vl, _FlagArg):
                parser.add_argument(anm, default=vl.default, action=vl.action,
                                    help=vl.help_str)
            else:
                # Previously this ValueError was constructed but never raised.
                raise ValueError("didn't understand type {} of key {}".format(type(vl), anm))

    def get_arg_default_dict(self):
        """Return ``{<prefix>_<key>: registered default}`` for every key."""
        return {self.get_arg_fullname(knm): spec.default
                for knm, spec in self.__registry.items()}

    def log_args_description(self, arg_dict, lgr):
        """Log one INFO line per registered argument with its parsed value.

        ``arg_dict`` must map full (prefixed) argument names to values; a
        missing key raises KeyError. Nothing is logged when ``lgr`` is not
        enabled for INFO.
        """
        if lgr.isEnabledFor(logging.INFO):
            for knm, vl in self.__registry.items():
                anm = self.get_arg_fullname(knm)
                lgr.info('INITARGS: --{}={} ({}, with default {})'.format(
                        anm, arg_dict[anm], vl.help_str, vl.default))



import argparse
# Valid distribution names for the '<idv>_rand_type' string argument that
# add_arg_for_random registers; the first entry is used as its default.
_arg_dist_choices = ['uniform']

def _parse_and_return_arg_dict(input_args, all_regs):
    r'''
    this function trawls through all the functions applying the 
    arg-parser to them (as the one and only arg) and then goes through
    all the input_args or comdline args (if input_args is None),

    The list 'input_args; contains entries formatted exactly as one would 
    inputs in the commandline.
    
    Finally, it returns a dictionary version of the args with argument 
    key-value pairs.
    '''    
    parser = argparse.ArgumentParser()
    
    if all_regs:
        # we maintain a list of all args seen before
        prev_ = {}
        for reg in all_regs:
            regc= reg.get_classname()
            if regc not in prev_:
                reg.add_args_to_parser(parser)
                prev_[regc]=regc
                
    args = parser.parse_args(input_args)
    
    arg_dict = vars(args)
    
    return arg_dict


def parse_for_arguments(reglist, classlist=None, input_args_override=None, input_args=None):
    ''' Collect argument registries, parse arguments, and apply overrides.

    Parameters:
      - reglist : non-empty sequence (list, tuple, ...) of registries, for instance
                  the logging registry from logging_get_args_registry(flvl, clvl)
      - classlist : iterable of classes, each polled via <cls>.get_args_registries()
                    for its associated registries. May be empty or None.
      - input_args_override : dict of {full arg name: value} that supersedes any
                    commandline or default entry; precedence is
                    default < commandline < input_args_override
      - input_args : optional list of commandline-style tokens to parse instead
                    of the real command line (None, the default, keeps the
                    original behavior of parsing the actual command line)

    Registries are concatenated here without de-duplication. Note that the
    parser stage only keeps the first registry for any given get_classname()
    value, so a later registry sharing a prefix contributes no options.

    Returns:
      dict mapping full argument names to their final values.

    Raises:
      ValueError: reglist is None or empty.
      RuntimeError: a key in input_args_override is not a known argument.
    '''

    if not reglist:
        raise ValueError("need to pass in at least one arg registry, this can't be None.")

    # list() instead of .copy(): works for tuples/any sequence and never
    # mutates the caller's container.
    regfnl = list(reglist)

    # getting registries from classes
    if classlist is not None:
        for clss in classlist:
            regfnl.extend(clss.get_args_registries())

    arg_dict = _parse_and_return_arg_dict(input_args, regfnl)

    # apply user overrides; unknown keys are rejected rather than silently added
    if input_args_override is not None:
        for ar in input_args_override:
            if ar not in arg_dict:
                raise RuntimeError('Did not find key \'{}\' in system\'s args'.format(ar))
            arg_dict[ar] = input_args_override[ar]

    return arg_dict
        


def add_arg_for_random(arg_reg, idv):
    """Register the three arguments that configure random output for ``idv``.

    Adds to ``arg_reg``:
      - ``<idv>_israndom``    : 'store_true' flag, default False
      - ``<idv>_rand_type``   : distribution name restricted to
                                ``_arg_dist_choices`` (default: its first entry)
      - ``<idv>_rand_params`` : comma-separated floats parameterising the
                                chosen distribution, default '0.0,1.0'
    """
    flag_key = idv + '_israndom'
    flag_help = idv + ', shd output be random?'
    arg_reg.register_flag_arg(flag_key, flag_help, False, 'store_true')

    dist_key = idv + '_rand_type'
    dist_help = idv + ', choice of distn, if needed'
    arg_reg.register_str_arg(dist_key, dist_help,
                             _arg_dist_choices[0], _arg_dist_choices)

    params_key = idv + '_rand_params'
    params_help = idv + ', comma sep. floats of parameters of chosen distn, eg "uniform" needs "lo,high"'
    arg_reg.register_str_arg(params_key, params_help, '0.0,1.0')

# Add constants
def add_arg_for_constant(arg_reg, idv, defval):
    """Register the float argument ``<idv>_constant`` (default ``defval``) on ``arg_reg``."""
    key = idv + '_constant'
    help_text = idv + ', value to output constantly'
    arg_reg.register_float_arg(key, help_text, defval)


