"""Argument processing and routing to appropriate subcommand."""
import argparse as ap
import logging
import os
import socket
import subprocess
import sys

sys.path.insert(0, 'lib')


def add_parsers():
    """Call every parser-adding function defined in the caller's namespace.

    Looks at the caller's local variables and invokes, with no arguments,
    each callable whose name starts with ``add`` and whose ``__module__`` is
    this module.  Callables are invoked in the order the caller's
    ``f_locals`` mapping lists them.

    Raises:
        RuntimeError: if the interpreter does not expose stack frames.
    """
    import inspect
    frame = inspect.currentframe()
    try:
        caller = frame.f_back if frame is not None else None
        if caller is None:
            raise RuntimeError(
                'add_parsers() needs interpreter stack frame support')
        # Snapshot the items so the callbacks below cannot change the
        # mapping while we are iterating over it.
        candidates = list(caller.f_locals.items())
        del caller
    finally:
        # Drop the frame reference promptly to avoid a reference cycle.
        del frame
    for key, value in candidates:
        # Not every callable exposes __module__; getattr keeps such locals
        # from aborting parser construction with AttributeError.
        if (callable(value) and key.startswith('add')
                and getattr(value, '__module__', None) == __name__):
            value()


def pair_gen(type1, type2):
    """Build an argparse ``type`` converter for comma-separated pairs.

    Args:
        type1: converter applied to the token before the comma.
        type2: converter applied to the token after the comma.

    Returns:
        A function mapping ``"x,y"`` to ``(type1(x), type2(y))``.  It raises
        ``argparse.ArgumentTypeError`` when the string does not contain
        exactly one comma or when either conversion fails.
    """
    # Like argparse itself, fall back to repr() for converters without a
    # __name__ (functools.partial, callable instances).  Otherwise the error
    # path would crash with AttributeError instead of reporting bad input.
    name1 = getattr(type1, '__name__', repr(type1))
    name2 = getattr(type2, '__name__', repr(type2))

    def pairs(s):
        """Pair argparse type."""
        try:
            tok1, tok2 = s.split(',')
            return type1(tok1), type2(tok2)
        except Exception as err:
            raise ap.ArgumentTypeError(
                "Need to pass comma-separated pairs of type ({:}, {:})!"
                .format(name1, name2)) from err
    return pairs


# Shared log-line format for both the stdout and the log-file configuration.
_LOG_FORMAT = '%(asctime)s %(levelname)s %(process)d: %(message)s'


def _str2bool(s):
    """Parse a command-line boolean.

    ``type=bool`` is an argparse trap: ``bool('False')`` is ``True``, so any
    non-empty value would enable the option.  Accept the usual spellings of
    true/false instead and reject anything else.

    Raises:
        argparse.ArgumentTypeError: if *s* is not a recognised boolean.
    """
    value = s.strip().lower()
    if value in ('true', 't', 'yes', 'y', '1'):
        return True
    if value in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ap.ArgumentTypeError(
        'expected a boolean (true/false), got {!r}'.format(s))


def _key_value(s):
    """Split a ``key=value`` token into ``[key, value]``.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``.

    Raises:
        argparse.ArgumentTypeError: if *s* has no ``=`` or an empty key.
    """
    key, sep, value = s.partition('=')
    if not sep or not key:
        raise ap.ArgumentTypeError(
            'expected key=value, got {!r}'.format(s))
    return [key, value]


def _add_seqmodel_option(parser):
    """Add the optional ``-q/--seqmodel`` param-file argument to *parser*."""
    parser.add_argument('-q', '--seqmodel', metavar='seqmodel.json',
                        dest='seqmodel_json', type=str, default='',
                        help='location of seqmodel param file, if retraining '
                        'with sequence conservation')


def _add_interact_parser(subparsers, parent):
    """Add the ``interact`` subcommand (dispatches to ``train.train``)."""
    import src.learning.interaction.train as train
    ip = subparsers.add_parser('interact',
                               description='Interaction Prediction.',
                               help='predict subgrid interactions',
                               parents=[parent])
    ip.set_defaults(func=train.train)
    ip.add_argument(metavar='model.json', dest='model_json', type=str,
                    help='location of model param file.')
    ip.add_argument(metavar='train.json', dest='train_json', type=str,
                    help='location of training param file.')
    ip.add_argument(metavar='model_dir', dest='model_dir', type=str,
                    help='location to send model files to.')
    _add_seqmodel_option(ip)


def _add_predict_parser(subparsers, parent):
    """Add the ``predict`` subcommand (dispatches to ``test_model_main``)."""
    import src.learning.interaction.test as test
    rp = subparsers.add_parser(
        'predict',
        description='Pair prediction.',
        help='predicts interactions from trained model',
        parents=[parent])
    rp.set_defaults(func=test.test_model_main)
    rp.add_argument(metavar='model.json', dest='model_json', type=str,
                    help='location of model param file.')
    rp.add_argument(metavar='test.json', dest='test_json', type=str,
                    help='location of testing param file.')
    rp.add_argument(metavar='model_dir', dest='model_dir', type=str,
                    help='location model files are at.')
    rp.add_argument(metavar='out_dir', dest='out_dir', type=str,
                    help='location to send output files to.')
    rp.add_argument('-k', metavar='checkpoint', dest='model_chkpt',
                    type=str, default="",
                    help='location of specific checkpoint file to use.  '
                    'Otherwise, defaults to best.')
    _add_seqmodel_option(rp)


def _add_config_model_parser(subparsers, parent):
    """Add ``config interact model`` (dispatches to ``init_params``)."""
    import src.learning.interaction.model_params as ic
    op = subparsers.add_parser(
        'model',
        description='Model Configuration.',
        help='configure interaction model.',
        parents=[parent])
    op.set_defaults(func=ic.init_params)
    op.add_argument(
        metavar='model.json', dest='param_json', type=str,
        help='location to store parameters at.')
    op.add_argument(
        '-s', metavar='conv size', type=int, dest='conv_size',
        help='size of convolutional filter to apply')
    op.add_argument(
        '-t', metavar='conv stride', type=int, dest='conv_stride',
        help='size of stride to apply')
    op.add_argument('-n', action='store_true', dest='batch_norm',
                    help='apply batch normalization after layers')
    op.add_argument('-d', action='store_true', dest='dropout',
                    help='do dropout after layers')
    # store_false: passing the flag turns valid border mode OFF.
    op.add_argument(
        '--no-border-mode-valid', action='store_false',
        dest='border_mode_valid',
        help='use same border mode instead of valid')
    op.add_argument(
        '-de', metavar='dense layers', dest='dense_layers', type=int,
        help='number of dense layers to apply after merge, if not siamese')
    op.add_argument(
        '--rad', metavar='radius', dest='radius_ang', type=int,
        help='radius of grids to generate, in angstroms')
    op.add_argument(
        '-res', metavar='resolution', dest='resolution', type=float,
        help='resolution of each voxel, in angstroms')
    op.add_argument(
        '--num_filters', nargs='*', type=int, dest='num_filters',
        help='number of filters to use at each conv level. Can be either '
        'a single number or one for each conv layer. Default is [32].')
    op.add_argument(
        '--max_pool_positions', nargs='*', type=int, metavar='bool',
        help='boolean array indicating where to put max pool '
        'filters, should be one hot of length num convs, with each'
        ' position indicating whether to put it after the '
        'corresponding conv layer. Default is no max pool.')
    op.add_argument(
        '--max_pool_sizes', nargs='*', type=int, metavar='size',
        help='int array indicating size of max pool '
        'filters, default is len(num_filters) * [2].')
    op.add_argument(
        '--max_pool_strides', nargs='*', type=int, metavar='stride',
        help='int array indicating strides of max pool '
        'filters, default is len(num_filters) * [2].')
    op.add_argument(
        '--tower_fc_nodes', nargs='*', type=int, metavar='size',
        help='int array indicating number and size of hidden nodes'
        ' at each dense layer in one of the two legs of siamese '
        'tower, default is [256]')
    op.add_argument(
        '--top_fc_nodes', nargs='*', type=int, metavar='size',
        help='int array indicating number and size of hidden nodes'
        ' at each dense layer at top of network, after towers are '
        'concatenated, default is de * [512]')


def _add_config_seqmodel_parser(subparsers, parent):
    """Add ``config interact seqmodel`` (dispatches to ``init_params``)."""
    import src.learning.interaction.seqmodel_params as ic
    op = subparsers.add_parser(
        'seqmodel',
        description='Sequence Model Configuration.',
        help='configure sequence model to add to standard model.',
        parents=[parent])
    op.set_defaults(func=ic.init_params)
    op.add_argument(metavar='struct_model_dir', dest='struct_model',
                    type=str, help='dir to use for base structural model')
    op.add_argument(
        metavar='seqmodel.json', dest='param_json', type=str,
        help='location to store parameters at.')
    op.add_argument('-n', action='store_true', dest='batch_norm',
                    help='apply batch normalization after layers')
    op.add_argument('-d', action='store_true', dest='dropout',
                    help='do dropout after layers')
    op.add_argument('--src', dest='source', nargs='+',
                    help='source of features for learning, '
                    'options are pssm, psfm, struct')
    op.add_argument(
        '-w', metavar='radius', type=int, dest='cons_window_radius',
        help='window radius to use for pssm lookup')


def _add_config_test_parser(subparsers, parent):
    """Add ``config interact test`` (dispatches to ``init_params``)."""
    import src.learning.interaction.test_params as tp
    op = subparsers.add_parser(
        'test',
        description='Testing Configuration.',
        help='configure interaction test run.', parents=[parent])
    op.set_defaults(func=tp.init_params)
    op.add_argument(
        'dataset_tfrecords', type=str,
        help='tfrecords containing dataset.')
    op.add_argument(
        'num_testing', metavar='num_testing', type=int,
        help='number of examples to test with')
    op.add_argument(
        metavar='test.json', dest='param_json', type=str,
        help='location to store parameters at.')
    op.add_argument(
        '-dir', metavar='num directions', type=int, dest='num_directions',
        help='number of directions to apply for data augmentation')
    op.add_argument(
        '-b', metavar='batch size', type=int, dest='batch_size',
        help='mini batch size')
    op.add_argument(
        '--checknans', action='store_true', dest='check_nans',
        help='check for nans as running')
    op.add_argument('-g', metavar='gpus', type=int, dest='towers',
                    help='number of gpus to use')
    op.add_argument(
        '-r', metavar='num rolls', type=int, dest='num_rolls',
        help='number of rolls to apply for data augmentation')
    group = op.add_mutually_exclusive_group()
    group.add_argument(
        '--prune_file_testing', type=str,
        help='txt file containing records to prune from tfrecords for '
        'testing')
    group.add_argument(
        '--keep_file_testing', type=str,
        help='txt file containing records to keep from tfrecords for '
        'testing')
    op.add_argument('--seq_src', metavar='pssm.h5', type=str,
                    help='hdf5 file with pssms and psfms')


def _add_config_train_parser(subparsers, parent):
    """Add ``config interact train`` (dispatches to ``init_params``)."""
    import src.learning.interaction.train_params as tp
    op = subparsers.add_parser(
        'train',
        description='Training Configuration.',
        help='configure interaction train run.',
        parents=[parent])
    op.set_defaults(func=tp.init_params)
    op.add_argument(
        'dataset_tfrecords', type=str,
        help='tfrecords containing training dataset.')
    op.add_argument(
        'num_training', type=int,
        help='number of examples to train with')
    op.add_argument(
        'num_validation', type=int,
        help='number of examples to validate with')
    op.add_argument(
        metavar='train.json', dest='param_json', type=str,
        help='location to store parameters at.')
    op.add_argument(
        '-dir', metavar='num directions', type=int, dest='num_directions',
        help='number of directions to apply for data augmentation')
    op.add_argument(
        '--val_dataset_tfrecords',
        help='specify alternative source for validation set.')
    op.add_argument(
        '--checknans', action='store_true', dest='check_nans',
        help='check for nans as running')
    op.add_argument(
        '-lr', metavar='learning rate', type=float,
        dest='learning_rate', help='learning rate to use')
    op.add_argument(
        '-o', metavar='optimizer', type=str, dest='optimizer',
        help='optimizer to use for learning')
    op.add_argument('-ne', action='store_true', dest='nesterov',
                    help='use nesterov momentum for SGD')
    op.add_argument(
        '-b', metavar='batch size', type=int, dest='batch_size',
        help='mini batch size')
    op.add_argument(
        '--rolls-per-pass', metavar='rolls_per_pass', type=int,
        dest='rolls_per_pass',
        help='how many rolls to use for each example in each pass.'
        ' e.g. if num_rolls is 20, rolls_per_pass is 5, and '
        'batch_size is 32 then we will perform 4 passes in each '
        'epoch and each batch will be of size 32 * 5 = 160.')
    op.add_argument('-g', metavar='gpus', type=int, dest='towers',
                    help='number of gpus to use')
    op.add_argument(
        '-r', metavar='num rolls', type=int, dest='num_rolls',
        help='number of rolls to apply for data augmentation')
    op.add_argument(
        '--shuffle_buffer', type=int,
        help='how many examples to have in buffer as we stream '
        'through dataset.  We only shuffle within this buffer.')
    op.add_argument(
        '--num_interleaved', type=int,
        help='how many tf record files to read at once.  Higher '
        'value increases mixing of examples.')
    group = op.add_mutually_exclusive_group()
    group.add_argument(
        '--prune_file_training', type=str,
        help='txt file containing records to prune from tfrecords for '
        'training')
    group.add_argument(
        '--keep_file_training', type=str,
        help='txt file containing records to keep from tfrecords for '
        'training')
    group = op.add_mutually_exclusive_group()
    group.add_argument(
        '--prune_file_validation', type=str,
        help='txt file containing records to prune from tfrecords for '
        'validation')
    group.add_argument(
        '--keep_file_validation', type=str,
        help='txt file containing records to keep from tfrecords for '
        'validation')
    op.add_argument(
        '--keep_file_pairs', type=str,
        help='txt file containing pairs to keep from tfrecords for '
        'training')
    op.add_argument('-e', metavar='num_epochs', type=int,
                    dest='max_epochs',
                    help='max number of epochs to run')
    # Each occurrence is parsed into a [key, value] list and appended.
    op.add_argument(
        '--stop_criteria', action='append',
        default=[],
        type=_key_value,
        help='stop criteria for learning, may be given multiple times. '
        'Valid options are: '
        '1) converge=bool (stop when converged), '
        '2) train_threshold=float '
        '(stop when train loss below threshold), '
        '3) val_threshold=float '
        '(stop when val loss below threshold), '
        '4) lr=float '
        '(decay LR when val loss converges until it is below lr). '
        'Default is converge=True.')
    op.add_argument('--lr_decay_type',
                    metavar='type',
                    dest='lr_decay_type',
                    choices=['none',
                             'exponential',
                             'inverse_time',
                             'natural_exp',
                             'cosine'],
                    default='none',
                    help='what type of learning rate decay to use')
    op.add_argument('--lr_decay_steps',
                    metavar='steps',
                    type=int,
                    help='every how many steps to decay learning '
                    'rate')
    op.add_argument('--lr_decay_rate',
                    metavar='rate',
                    type=float,
                    help='what rate to decay learning rate')
    op.add_argument('--lr_decay_staircase',
                    metavar='staircase',
                    type=_str2bool,
                    help='if True decay the learning rate at '
                    'discrete intervals',
                    default=False)
    op.add_argument('--loose', action='store_true',
                    help='allow different examples across dataset '
                    'repeats.  Faster.')
    op.add_argument('--seq_src', metavar='pssm.h5', type=str,
                    help='hdf5 file with pssms and psfms')


def _add_config_parsers(subparsers, parent):
    """Add ``config`` and its ``interact {model,seqmodel,test,train}`` tree."""
    cp = subparsers.add_parser(
        'config', description='config commands',
        help='interaction configuration', parents=[parent])
    cp_subparsers = cp.add_subparsers(
        title='config subcommands', metavar='SUBCOMMAND',
        help='DESCRIPTION')
    ip = cp_subparsers.add_parser(
        'interact', description='interact config commands',
        help='interaction configuration', parents=[parent])
    ip_subparsers = ip.add_subparsers(
        title='interaction config subcommands', metavar='SUBCOMMAND',
        help='DESCRIPTION')
    _add_config_model_parser(ip_subparsers, parent)
    _add_config_seqmodel_parser(ip_subparsers, parent)
    _add_config_test_parser(ip_subparsers, parent)
    _add_config_train_parser(ip_subparsers, parent)


def _build_parser():
    """Return the top-level argument parser with every subcommand attached."""
    # Parent parser carrying the options shared by every (sub)command level.
    # default=SUPPRESS keeps an unspecified -l out of the namespace, so a
    # nested subparser can no longer overwrite a -l given at an outer level
    # with its own default of None.
    pp = ap.ArgumentParser(add_help=False)
    pp.add_argument('-l', metavar='log', type=str, default=ap.SUPPRESS,
                    help='log file to output to (default: stdout)')

    p = ap.ArgumentParser(description='Protein Complex Prediction.',
                          parents=[pp])
    subparsers = p.add_subparsers(title='subcommands',
                                  metavar='SUBCOMMAND   ',
                                  help='DESCRIPTION')
    # Explicit calls (rather than frame introspection) fix the subcommand
    # order and make it impossible to invoke an unrelated local by accident.
    _add_interact_parser(subparsers, pp)
    _add_predict_parser(subparsers, pp)
    _add_config_parsers(subparsers, pp)
    return p


def _configure_logging(log_path):
    """Configure root logging at INFO level.

    Logs go to *log_path* when it is given, creating its parent directory if
    needed.  Otherwise they go to stdout.
    """
    if log_path is None:
        logging.basicConfig(stream=sys.stdout, format=_LOG_FORMAT,
                            level=logging.INFO)
        return
    log_dir = os.path.dirname(log_path)
    if log_dir:
        # exist_ok avoids failing if the directory appears concurrently.
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(filename=log_path, format=_LOG_FORMAT,
                        level=logging.INFO)


def main():
    """Process all arguments and run the selected subcommand.

    Parses ``sys.argv`` and configures logging (to the ``-l`` file if given,
    otherwise stdout).  It then logs the host and full command line and calls
    the handler registered by the chosen subcommand with the parsed namespace.
    """
    p = _build_parser()
    args = p.parse_args()

    # argparse treats subcommands as optional on Python 3, so e.g. a bare
    # ``config interact`` parses fine but registers no handler.
    if not hasattr(args, 'func'):
        p.error('a complete subcommand is required')

    # -l uses default=SUPPRESS; restore the attribute so handlers can keep
    # reading args.l.
    args.l = getattr(args, 'l', None)
    _configure_logging(args.l)

    logging.info('=================== CALL ===================')
    logging.info('Host is %s', socket.gethostname())
    logging.info('%s', ' '.join(sys.argv))
    args.func(args)
    logging.info('================= END CALL =================')


if __name__ == '__main__':
    # Script entry point: parse sys.argv and dispatch the chosen subcommand.
    main()
