import numpy as np
setting_name = 'bias'
root = ''

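'''
Example hyperparameter ('hp') grids for each model.
Each key maps to a list of candidate values.

RF - {'n_estimators': [200,400],
    'min_samples_split': [2,10],
    'max_depth': [5,10,20],
    'max_features': ['auto']}
    (NOTE: if these grids are handed to scikit-learn, 'auto' is no longer
    accepted for max_features from scikit-learn 1.3 on; use 1.0 or 'sqrt'.)

Lasso - {'alpha': [1.,0.1, 0.01, 0.001]}

LR (spelled 'LinearRegression' in the settings below) - {}

MLP -  {'hidden_layer_sizes': [(100,),(50,50)],
        'activation': ['relu'],
        'solver': ['adam'],
        'alpha': [1.,.01,.001,.0001],
        'learning_rate': ['adaptive'],
        'learning_rate_init': [1e-3],
        'max_iter': [500]}
'''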

if setting_name == 'demo':
    # 'demo' setting: a single scenario, run once.
    save_folder_name, num_iters, alpha = 'demo_updated', 1, 0.05
    strata_mod = strata_metadata_mod = ''

    # params_mod is a list of scenarios; each scenario is a list of
    # (name, settings-dict) pairs. Here the per-entry lists under
    # 'obs_dict' all have num_obs == 5 elements.
    params_mod = [
        [
            ('obs_dict', {
                'num_obs': 5,
                'sizes': [5.] * 5,
                'confounder_concealment': [0, 0, 2, 4, 6],
                'missing_bias': [False] * 5,
            }),
            ('response_surface', {
                'ctr': 'non_linear',
                'trt': 'linear',
                'model': 'MLP',
                # Single-value grid: every hyperparameter list has one entry.
                'hp': {
                    'hidden_layer_sizes': [(50, 50)],
                    'activation': ['relu'],
                    'solver': ['adam'],
                    'alpha': [.0001],
                    'learning_rate': ['adaptive'],
                    'learning_rate_init': [1e-3],
                    'max_iter': [200],
                },
            }),
        ],
    ]
    
#     [ ('obs_dict',{
#                     'num_obs': 1,
#                     'sizes': [1.],
#                     'confounder_concealment': [3],
#                     'missing_bias': [False]
#                  }),
#                   ('response_surface', {
#                             'ctr': 'non_linear', 
#                             'trt': 'linear',
#                             'model': 'LinearRegression',
#                             'hp': {}
#                   }) ]

if setting_name == 'upsize':
    # 'upsize' setting: four scenarios that differ only in the value used
    # for every element of 'sizes'; all other settings are held fixed.
    save_folder_name = 'upsizemlp_full_reweight_correction'
    num_iters, alpha = 50, 0.05
    strata_mod = strata_metadata_mod = ''

    # NOTE(review): the last scale is the int 10 while the others are
    # floats; kept exactly as in the original values. Confirm whether a
    # float (10.) was intended.
    params_mod = [
        [
            ('obs_dict', {
                'num_obs': 5,
                'sizes': [scale] * 5,
                'confounder_concealment': [0, 0, 2, 4, 6],
                'missing_bias': [False] * 5,
            }),
            ('response_surface', {
                'ctr': 'non_linear',
                'trt': 'linear',
                'model': 'MLP',
                # Single-value grid: every hyperparameter list has one entry.
                'hp': {
                    'hidden_layer_sizes': [(25, 25)],
                    'activation': ['relu'],
                    'solver': ['adam'],
                    'alpha': [.0001],
                    'learning_rate': ['adaptive'],
                    'learning_rate_init': [1e-3],
                    'max_iter': [250],
                },
            }),
        ]
        # The dict literals above are re-evaluated per iteration, so each
        # scenario gets its own independent dicts and lists.
        for scale in (1., 3., 5., 10)
    ]
    
if setting_name == 'bias':
    # 'bias' setting: four scenarios that differ only in the
    # 'confounder_concealment' pattern (zero, one, two, then four entries
    # set to 3); all other settings are held fixed.
    save_folder_name = 'biasmlp_full_reweighted_correction'
    num_iters, alpha = 100, 0.05
    strata_mod = strata_metadata_mod = ''

    params_mod = [
        [
            ('obs_dict', {
                'num_obs': 5,
                'sizes': [5.] * 5,
                'confounder_concealment': concealment,
                'missing_bias': [False] * 5,
            }),
            ('response_surface', {
                'ctr': 'non_linear',
                'trt': 'linear',
                'model': 'MLP',
                # Single-value grid: every hyperparameter list has one entry.
                'hp': {
                    'hidden_layer_sizes': [(25, 25)],
                    'activation': ['relu'],
                    'solver': ['adam'],
                    'alpha': [.0001],
                    'learning_rate': ['adaptive'],
                    'learning_rate_init': [1e-3],
                    'max_iter': [250],
                },
            }),
        ]
        # Each pattern below is a distinct list object, and the dict
        # literals above are re-evaluated per iteration, so scenarios do
        # not share mutable state.
        for concealment in (
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 3],
            [0, 0, 0, 3, 3],
            [0, 3, 3, 3, 3],
        )
    ]
    

if setting_name == 'test':
    # 'test' setting: four small two-entry scenarios using
    # 'LinearRegression' with an empty hyperparameter grid, run once.
    save_folder_name, num_iters, alpha = 'biastest', 1, 0.05
    strata_mod = strata_metadata_mod = ''

    # NOTE(review): the last three concealment patterns are identical
    # ([0, 3]); reproduced as in the original values. Confirm whether
    # distinct patterns were intended.
    params_mod = [
        [
            ('obs_dict', {
                'num_obs': 2,
                'sizes': [1.] * 2,
                'confounder_concealment': concealment,
                'missing_bias': [False] * 2,
            }),
            ('response_surface', {
                'ctr': 'non_linear',
                'trt': 'linear',
                'model': 'LinearRegression',
                'hp': {},
            }),
        ]
        # Separate list literals keep each scenario's pattern an
        # independent object.
        for concealment in ([0, 0], [0, 3], [0, 3], [0, 3])
    ]
    
    

if setting_name == 'confound':
    # 'confound' setting: five scenarios that scale the same base
    # 'gamma_coefs' vector by 0, 0.01, 0.1, 0.5 and 1.
    save_folder_name = 'confound'
    num_iters, alpha = 100, 0.05
    strata_mod = strata_metadata_mod = ''

    # NOTE(review): entries here are tuples of flat (name, value) pairs,
    # unlike the ('obs_dict', {...}) / ('response_surface', {...}) pairs
    # used by the other settings - confirm the consumer still accepts
    # this shape.
    params_mod = [
        (
            ('beta_seed', 4),
            # Multiplying by 1. leaves the values unchanged, so the final
            # scenario matches the unscaled base vector.
            ('gamma_coefs', scale * np.array([0.1, 0.2, 0.5, 0.75, 1])),
            ('sizes', [2.5] * 5),
        )
        for scale in (0., 0.01, 0.1, 0.5, 1.)
    ]
