import torch
import argparse
import time
from parsers.parser import Parser
from parsers.config import get_config

# Maps each supported ``args.method`` value to the module that defines its
# ``Trainer`` class. Modules are imported lazily, so only the selected
# method's module (and its dependencies) is ever loaded.
_TRAINER_MODULES = {
    'dpm': 'method_series.dpm_trainer',
    'gmnn': 'method_series.gmnn_trainer',
    'clgnn': 'method_series.clgnn_trainer',
    'g3nn': 'method_series.g3nn_trainer',
    'lpa': 'method_series.lpa_trainer',
    'base': 'method_series.base_trainer',
}


def _load_trainer_class(method):
    """Import and return the ``Trainer`` class registered for *method*.

    Raises:
        ValueError: if *method* is not a key of ``_TRAINER_MODULES``.
    """
    import importlib

    try:
        module_name = _TRAINER_MODULES[method]
    except KeyError:
        raise ValueError(
            f"Unknown method {method!r}; expected one of "
            f"{sorted(_TRAINER_MODULES)}"
        ) from None
    return importlib.import_module(module_name).Trainer


def main(work_type_args):
    """Parse the CLI, build the config and train the selected method.

    Args:
        work_type_args: Namespace from the entry-point parser. It is
            currently unused; the real options are read by
            ``Parser().parse()``. It is kept for interface compatibility.

    Raises:
        ValueError: if ``args.method`` names no known trainer. Previously
            an unknown method silently did nothing.
    """
    args = Parser().parse()
    # UTC run timestamp, e.g. 'Jan01-12:34:56', handed to the trainer.
    ts = time.strftime('%b%d-%H:%M:%S', time.gmtime())
    config = get_config(args.config, args.seed)
    config.method = args.method

    trainer_cls = _load_trainer_class(args.method)
    trainer = trainer_cls(config)
    trainer.train(ts)

if __name__ == '__main__':
    # add_help=False: otherwise this argument-less parser would consume
    # ``-h/--help`` itself, print an empty help text and exit, so the help
    # from ``Parser().parse()`` inside ``main`` would never be shown.
    work_type_parser = argparse.ArgumentParser(add_help=False)
    # parse_known_args tolerates the options meant for ``Parser``; only the
    # (currently empty) known-args namespace is forwarded to ``main``.
    main(work_type_parser.parse_known_args()[0])
