
"""Tests for lattice estimators."""
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl
from sklearn.model_selection import KFold

import itertools

# Names of the 13 input feature columns; the same names are used as feature
# keys throughout the script below.
column_names = [
    'age',
    'sex',
    'cp',
    'trestbps',
    'chol',
    'fbs',
    'restecg',
    'thalach',
    'exang',
    'oldpeak',
    'slope',
    'ca',
    'thal',
]

# Per-fold accuracies of the final model, appended once per outer fold.
train_evaluations, test_evaluations = [], []

# Number of pre-split folds; fold i is read from heart/<i>/{train,test}_data.csv.
folds = 3

# Input range (input_min, input_max) over which each feature's calibration
# keypoints are spread uniformly.  Key order defines the feature order used
# for the feature dicts, the feature columns and the hparams.
_FEATURE_INPUT_RANGES = {
    'age': (0.0, 100.0),
    'sex': (0.0, 1.0),
    'cp': (0.0, 5.0),
    'trestbps': (0.0, 200.0),
    'chol': (0.0, 600.0),
    'fbs': (0.0, 1.0),
    'restecg': (0.0, 3.0),
    'thalach': (0.0, 200.0),
    'exang': (0.0, 1.0),
    # Previously input_min=10.0 / input_max=0.0, i.e. an inverted range.
    'oldpeak': (0.0, 10.0),
    'slope': (1.0, 3.0),
    'ca': (0.0, 3.0),
    'thal': (0.0, 4.0),
}
_FEATURE_NAMES = list(_FEATURE_INPUT_RANGES)

# Features constrained to have a monotonically increasing effect.
_MONOTONIC_FEATURES = ('chol', 'trestbps')

# Only the first few grid combinations are evaluated (the script previously
# guarded the trial body with `index<=5`); raise this to search more of the
# grid.
_MAX_GRID_TRIALS = 5


def _fold_features(dataset):
    """Return a dict mapping each feature name to that column as an array."""
    return {name: np.array(dataset[name]) for name in _FEATURE_NAMES}


def _take_rows(features, row_index):
    """Return a feature dict restricted to the rows selected by row_index."""
    return {name: column[row_index] for name, column in features.items()}


def _uniform_keypoints_init(num_keypoints, input_min, input_max):
    """Build a zero-argument keypoint initializer for one feature.

    The returned callable spreads `num_keypoints` keypoints uniformly over
    [input_min, input_max] with outputs in [0.0, 1.0].  The values are bound
    when this function is called, so later rebinding of loop variables cannot
    change what the initializer produces.
    """
    def init_fn():
        return tfl.uniform_keypoints_for_signal(num_keypoints,
                                                input_min=input_min,
                                                input_max=input_max,
                                                output_min=0.0,
                                                output_max=1.0)
    return init_fn


def _make_keypoint_init_fns(num_keypoints):
    """Return the per-feature dict of keypoint initializer callables."""
    return {
        name: _uniform_keypoints_init(num_keypoints, low, high)
        for name, (low, high) in _FEATURE_INPUT_RANGES.items()
    }


def _make_hparams(num_keypoints, learning_rate, **extra_params):
    """Create lattice hparams with the monotonicity constraints applied.

    Any `extra_params` are forwarded unchanged to `CalibratedLatticeHParams`.
    """
    hparams = tfl.CalibratedLatticeHParams(
        feature_names=list(_FEATURE_NAMES),
        num_keypoints=num_keypoints,
        learning_rate=learning_rate,
        **extra_params
    )
    for name in _MONOTONIC_FEATURES:
        hparams.set_feature_param(name, 'monotonicity', +1)
    return hparams


def _numpy_input_fn(features, labels, batch_size, num_epochs):
    """Wrap in-memory arrays in an unshuffled estimator input function."""
    return tf.compat.v1.estimator.inputs.numpy_input_fn(
        x=features,
        y=labels,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=False)


for fold in range(folds):
    # Load the pre-split train/test data of this outer fold.
    train_data_path = 'heart/'+str(fold)+'/train_data.csv'
    test_data_path = 'heart/'+str(fold)+'/test_data.csv'
    train_dataset = pd.read_csv(train_data_path, index_col=0)
    test_dataset = pd.read_csv(test_data_path, index_col=0)

    train_features = _fold_features(train_dataset)
    test_features = _fold_features(test_dataset)
    train_labels = np.array(train_dataset['target'])
    test_labels = np.array(test_dataset['target'])

    # Feature definition: every feature is fed as a plain numeric column.
    feature_columns = [
        tf.feature_column.numeric_column(name) for name in _FEATURE_NAMES
    ]

    # Grid search
    key_points = [10, 50, 100]
    rates = [0.1, 0.01, 0.001]
    batch_sizes = [16, 32, 64, 128, 256, 512, 1024]
    epochs = [50, 100, 400, 500, 800, 1000]

    kf = KFold(n_splits=5, shuffle=True, random_state=2)
    best_parameters = None
    best_evaluations = float('inf')
    grid = itertools.product(key_points, rates, batch_sizes, epochs)
    for k, r, b, e in itertools.islice(grid, _MAX_GRID_TRIALS):
        evaluations = []
        # Split on the label array so the folds index training *samples*.
        # Splitting the feature dict would make scikit-learn treat its 13
        # keys as the samples and only ever use rows 0..12.
        for train_index, test_index in kf.split(train_labels):
            train_X = _take_rows(train_features, train_index)
            train_Y = train_labels[train_index]
            test_X = _take_rows(train_features, test_index)
            test_Y = train_labels[test_index]

            # NOTE(review): the search tunes a lattice *regressor* while the
            # final model below is a calibrated linear *classifier*; confirm
            # that this mismatch is intended.
            lattice_estimator = tfl.calibrated_lattice_regressor(
                feature_columns=feature_columns,
                hparams=_make_hparams(k, r),
                keypoints_initializers_fn=_make_keypoint_init_fns(k))

            # Train-grid
            lattice_estimator.train(
                input_fn=_numpy_input_fn(train_X, train_Y, b, e))

            # Test-grid: a single pass suffices, repeating the same data for
            # `e` epochs only re-averages identical examples.
            evaluation = lattice_estimator.evaluate(
                input_fn=_numpy_input_fn(test_X, test_Y, b, 1))
            evaluations.append(evaluation['average_loss'])

        result = np.mean(evaluations)
        if result < best_evaluations:
            best_evaluations = result
            best_parameters = [k, r, b, e]

    if best_parameters is None:
        # Happens only when no trial produced a comparable (non-NaN) loss.
        raise RuntimeError(
            'Grid search did not select any hyperparameters for fold '
            + str(fold))

    print(best_parameters)
    k, r, b, e = best_parameters

    # Final model for this fold, trained on the full training split with the
    # selected hyperparameters.
    num_keypoints = k
    hparams = _make_hparams(
        num_keypoints,
        r,
        non_monotonic_num_lattices=1,
        non_monotonic_lattice_rank=1,
        non_monotonic_lattice_size=2,
    )
    keypoints_init_fns = _make_keypoint_init_fns(num_keypoints)

    # NOTE(review): a calibrated *linear* classifier is built from lattice
    # hparams here; confirm whether a lattice classifier was intended.
    lattice_estimator = tfl.calibrated_linear_classifier(
        feature_columns=feature_columns,
        hparams=hparams,
        keypoints_initializers_fn=keypoints_init_fns)

    # Train
    train_input_fn = _numpy_input_fn(train_features, train_labels, b, e)
    lattice_estimator.train(input_fn=train_input_fn)

    # Estimate: evaluate with one pass over each split.
    train_eval_input_fn = _numpy_input_fn(train_features, train_labels, b, 1)
    test_input_fn = _numpy_input_fn(test_features, test_labels, b, 1)
    train_estimate = lattice_estimator.evaluate(input_fn=train_eval_input_fn)
    test_estimate = lattice_estimator.evaluate(input_fn=test_input_fn)
    train_evaluations.append(train_estimate['accuracy'])
    test_evaluations.append(test_estimate['accuracy'])


# Aggregate the per-fold accuracies into mean / standard deviation.
mean_train, std_train = np.mean(train_evaluations), np.std(train_evaluations)
mean_test, std_test = np.mean(test_evaluations), np.std(test_evaluations)

# Print one "<name>: <value>" line per statistic, train first, then test.
_summary = (
    ('mean_train', mean_train),
    ('std_train', std_train),
    ('mean_test', mean_test),
    ('std_test', std_test),
)
print('\n'.join(caption + ': ' + str(value) for caption, value in _summary))
