
"""Tests for lattice estimators."""
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl
from sklearn.model_selection import KFold
import itertools

# Outer cross-validation results: one average_loss per pre-split fold.
train_evaluations = []
test_evaluations = []
folds = 3

# Input features, in the order they are declared to the lattice model.
FEATURE_NAMES = [
    'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE',
    'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT',
]
LABEL_NAME = 'CMEDV'

# Upper end of the input range over which each feature's initial calibration
# keypoints are spread uniformly; every range starts at 0.0 and is mapped
# onto the output range [0.0, 1.0].
KEYPOINT_INPUT_MAX = {
    'CRIM': 100.0,
    'ZN': 100.0,
    'INDUS': 50.0,
    'CHAS': 1.0,
    'NOX': 1.0,
    'RM': 10.0,
    'AGE': 110.0,
    'DIS': 50.0,
    'RAD': 30.0,
    'TAX': 800.0,
    'PTRATIO': 50.0,
    'B': 400.0,
    'LSTAT': 50.0,
}

# Per-feature 'monotonicity' hparam (+1 increasing, -1 decreasing).
MONOTONICITY = {'CRIM': -1, 'RM': +1, 'LSTAT': -1}

# Hyperparameter grid, enumerated in itertools.product order.
KEY_POINTS = [10, 50, 100]
RATES = [0.1, 0.01, 0.001]
BATCH_SIZES = [16, 32, 64, 128, 256, 512, 1024]
EPOCHS = [50, 100, 400, 500, 800, 1000]
# The full grid has 378 combinations; only the first MAX_GRID_TRIALS are
# evaluated (presumably to bound run time).
MAX_GRID_TRIALS = 5
# Number of inner KFold splits used for hyperparameter selection.
INNER_SPLITS = 5

feature_columns = [tf.feature_column.numeric_column(name) for name in FEATURE_NAMES]


def _load_split(path):
    """Read one CSV split and return ``(features, labels)``.

    ``features`` maps each name in FEATURE_NAMES to a 1-D numpy array;
    ``labels`` is the LABEL_NAME column as a numpy array.
    """
    dataset = pd.read_csv(path, index_col=0)
    features = {name: np.array(dataset[name]) for name in FEATURE_NAMES}
    labels = np.array(dataset[LABEL_NAME])
    return features, labels


def _subset(features, labels, index):
    """Return the rows selected by ``index`` from a features dict and label array."""
    selected = {name: values[index] for name, values in features.items()}
    return selected, labels[index]


def _keypoints_init_fns(num_keypoints):
    """Build the per-feature keypoint initializers for ``num_keypoints`` keypoints.

    Each initializer is a zero-argument callable; ``num_keypoints`` and the
    feature's input range are bound now, so later rebinding of module-level
    names cannot change them.
    """
    def make_init(input_max):
        return lambda: tfl.uniform_keypoints_for_signal(
            num_keypoints,
            input_min=0.0,
            input_max=input_max,
            output_min=0.0,
            output_max=1.0)

    return {name: make_init(hi) for name, hi in KEYPOINT_INPUT_MAX.items()}


def _train_estimator(num_keypoints, learning_rate, batch_size, num_epochs,
                     features, labels):
    """Build a fresh calibrated lattice regressor and train it on the given data."""
    hparams = tfl.CalibratedLatticeHParams(
        feature_names=list(FEATURE_NAMES),
        num_keypoints=num_keypoints,
        learning_rate=learning_rate,
    )
    for name, direction in MONOTONICITY.items():
        hparams.set_feature_param(name, 'monotonicity', direction)

    estimator = tfl.calibrated_lattice_regressor(
        feature_columns=feature_columns,
        hparams=hparams,
        keypoints_initializers_fn=_keypoints_init_fns(num_keypoints))

    train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        x=features,
        y=labels,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=False)
    estimator.train(input_fn=train_input_fn)
    return estimator


def _evaluate(estimator, features, labels, batch_size):
    """Evaluate ``estimator`` on the given data and return its metrics dict."""
    # A single pass is enough: average_loss is a mean over examples, so
    # repeating the data for several epochs only multiplies the cost.
    eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        x=features,
        y=labels,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False)
    return estimator.evaluate(input_fn=eval_input_fn)


def _grid_search(features, labels):
    """Select hyperparameters by inner K-fold cross-validation.

    Evaluates the first MAX_GRID_TRIALS grid points and returns
    ``[num_keypoints, learning_rate, batch_size, num_epochs]`` with the lowest
    mean validation average_loss.

    Raises:
        RuntimeError: if no grid point yielded a mean loss below infinity
            (e.g. every run produced NaN).
    """
    kf = KFold(n_splits=INNER_SPLITS, shuffle=True, random_state=2)
    # Split on the label array, not the features dict: KFold sizes a dict by
    # its number of keys (13), which would restrict CV to the first 13 rows.
    # random_state is fixed, so the splits can be computed once and reused.
    splits = list(kf.split(labels))

    best_parameters = None
    best_loss = np.inf
    grid = itertools.product(KEY_POINTS, RATES, BATCH_SIZES, EPOCHS)
    for k, r, b, e in itertools.islice(grid, MAX_GRID_TRIALS):
        losses = []
        for train_index, val_index in splits:
            train_X, train_Y = _subset(features, labels, train_index)
            val_X, val_Y = _subset(features, labels, val_index)
            estimator = _train_estimator(k, r, b, e, train_X, train_Y)
            evaluation = _evaluate(estimator, val_X, val_Y, b)
            print(evaluation)
            losses.append(evaluation['average_loss'])
        result = np.mean(losses)
        if result < best_loss:
            best_loss = result
            best_parameters = [k, r, b, e]

    if best_parameters is None:
        raise RuntimeError('grid search found no finite validation loss')
    return best_parameters


for fold in range(folds):
    train_features, train_labels = _load_split(
        'boston/' + str(fold) + '/train_data.csv')
    test_features, test_labels = _load_split(
        'boston/' + str(fold) + '/test_data.csv')

    best_parameters = _grid_search(train_features, train_labels)
    print(best_parameters)
    k, r, b, e = best_parameters

    # Refit on the whole outer training split with the selected hyperparameters.
    lattice_estimator = _train_estimator(
        k, r, b, e, train_features, train_labels)
    train_estimate = _evaluate(
        lattice_estimator, train_features, train_labels, b)
    test_estimate = _evaluate(
        lattice_estimator, test_features, test_labels, b)
    train_evaluations.append(train_estimate['average_loss'])
    test_evaluations.append(test_estimate['average_loss'])


mean_train = np.mean(train_evaluations)
print('mean_train'+str(mean_train))
std_train = np.std(train_evaluations)
print('std_train'+str(std_train))
mean_test = np.mean(test_evaluations)
print('mean_test'+str(mean_test))
std_test = np.std(test_evaluations)
print('std_test'+str(std_test))

