import torch
from gflownet.tasks.seh_double import SEHDoubleModelTrainer
from gflownet.tasks.qm9.qm9_double import QM9MixtureModelTrainer
from gflownet.tasks.qm9.qm9 import QM9GapTrainer

# Directory that receives logs / checkpoints for this run.
log_root = "../jobs/test"

# Hyperparameters passed to the trainer constructor below.
# Keys mirror the nested config layout the trainer expects; values here
# override whatever defaults the trainer defines (defaults not visible here).
base_hps = {
    "log_dir": log_root,
    # Use the GPU when torch can see one, otherwise fall back to CPU.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "overwrite_existing_exp": True,
    "num_training_steps": 20000,
    "validate_every": 0,
    "num_workers": 0,
    "opt": {
        # Same value as num_training_steps above; keep the two in sync
        # if the run length is changed.
        "lr_decay": 20000,
    },
    "algo": {
        # Exactly one of the three sampling-mode flags is enabled.
        "p_greedy_sample": False,
        "p_of_max_sample": True,
        "p_quantile_sample": False,
        "scale_temp": False,
        "p": 0.8,
        "dqn_n_step": 3,
        "sampling_tau": 0.99,
        "global_batch_size": 64,
        "sampling_ratio": 1.0,
        "ddqn_update_step": 1,
        "rl_train_random_action_prob": 0.01,
        "dqn_tau": 0.9,
    },
    "cond": {
        "temperature": {
            "sample_dist": "uniform",
            # NOTE(review): presumably the [low, high] bounds of the uniform
            # temperature distribution — confirm against the trainer.
            "dist_params": [2.5, 32],
            "num_thermometer_dim": 1,
        },
    },
    "replay": {
        # Replay buffer is disabled; the remaining settings are inert
        # while "use" is False (assumption — verify in the trainer).
        "use": False,
        "capacity": 100,
        "warmup": 0,
        "method": "Random",
    },
    "task": {
        # Placeholder paths; presumably only consumed by the QM9 trainers.
        # Replace them with real locations before running a QM9 task.
        "qm9": {
            "h5_path": "path.to.dataset/qm9.h5",
            "model_path": "path.to.model/mxmnet_gap_model.pt",
        },
    },
}

if __name__ == "__main__":
    # Launch training only when this file is executed as a script, so that
    # importing it (e.g. to reuse ``base_hps``) does not start a training run.
    # The file also imports QM9MixtureModelTrainer and QM9GapTrainer; either
    # could presumably be swapped in here (they likely need the placeholder
    # "task" -> "qm9" paths in base_hps filled in first — verify).
    trial = SEHDoubleModelTrainer(base_hps)
    # Set after construction, overriding whatever the trainer's
    # constructor assigned to print_every.
    trial.print_every = 1
    trial.run()