import argparse
import pandas as pd
import csv
import numpy as np
import json
import keras
import sys
sys.path.insert(0, '../graph_methods')
sys.path.insert(0, '../src')
from itertools import groupby
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.layers.normalization import BatchNormalization
from keras.optimizers import SGD, Adam
from CallBacks import KeckCallBackOnROC, KeckCallBackOnPrecision
from function import read_merged_data, extract_feature_and_label, reshape_data_into_2_dim
from util import output_regression_result


def rmse(X, Y):
    """Return the root-mean-squared error between ``X`` and ``Y``.

    The mean squared error is printed before the square root is taken.
    ``X`` and ``Y`` are subtracted element-wise with numpy broadcasting, so
    callers must pass arrays of matching shape to get a per-sample error.
    """
    squared_error = np.mean((X - Y) ** 2)
    print('mse: {}'.format(squared_error))
    return np.sqrt(squared_error)


def get_sample_weight(task, y_data):
    """Build one training weight per sample according to ``task.weight_schema``.

    Args:
        task: any object exposing a ``weight_schema`` attribute, either
            ``'no_weight'`` or ``'weighted_sample'``.
        y_data: sequence of labels. Each entry may be a scalar or a row whose
            first element is the label (e.g. an ``(n, 1)`` array).

    Returns:
        1-D numpy array with ``len(y_data)`` float weights (sample weights must
        be 1-D). ``'no_weight'`` gives all ``1.0``. ``'weighted_sample'`` gives
        each sample ``n_samples / count(its label)``, i.e. inverse label
        frequency, where labels are bucketed by ``int()`` truncation.

    Raises:
        ValueError: if ``task.weight_schema`` is not one of the two schemas.
    """
    from collections import Counter

    if task.weight_schema == 'no_weight':
        sw = [1.0 for _ in y_data]
    elif task.weight_schema == 'weighted_sample':
        # np.ravel(...)[0] accepts both scalar labels and length-1 rows.
        labels = [int(np.ravel(row)[0]) for row in y_data]
        counts = Counter(labels)
        total = len(labels)
        # List comprehension (not map) so np.array gets a concrete sequence
        # on Python 3 too; an empty y_data yields an empty weight array.
        sw = [1.0 * total / counts[label] for label in labels]
    else:
        raise ValueError('Weight schema not included. Should be among [{}, {}].'.
                         format('no_weight', 'weighted_sample'))
    return np.array(sw)


class SingleRegression:
    """Single-target, fully-connected Keras regressor configured by a nested dict.

    Keys read from ``conf``:

    - ``layers``: per-layer settings.
    - ``compile``: ``loss`` and ``optimizer``.
    - ``fitting``: ``nb_epoch``, ``batch_size``, ``verbose``, ``early_stopping``.
    - ``batch``: batch-normalization settings.
    - ``sample_weight_option``.
    - Optional ``hit_ratio`` (default 0.01) and ``input_layer_dimension``
      (default 512).
    """

    def __init__(self, conf):
        self.conf = conf
        # Width of the input feature vector; presumably the fingerprint length
        # of the dataset -- confirm against the data files.
        self.input_layer_dimension = conf.get('input_layer_dimension', 512)
        self.output_layer_dimension = 1

        # NOTE(review): early-stopping settings are required by the config but
        # are not used by train_and_predict.
        self.early_stopping_patience = conf['fitting']['early_stopping']['patience']
        self.early_stopping_option = conf['fitting']['early_stopping']['option']

        self.fit_nb_epoch = conf['fitting']['nb_epoch']
        self.fit_batch_size = conf['fitting']['batch_size']
        self.fit_verbose = conf['fitting']['verbose']

        self.compile_loss = conf['compile']['loss']
        optimizer_conf = conf['compile']['optimizer']
        self.compile_optimizer_option = optimizer_conf['option']
        if self.compile_optimizer_option == 'sgd':
            sgd_conf = optimizer_conf['sgd']
            self.compile_optimizer = SGD(lr=sgd_conf['lr'],
                                         momentum=sgd_conf['momentum'],
                                         decay=sgd_conf['decay'],
                                         nesterov=sgd_conf['nestrov'])
        else:
            # Any option other than 'sgd' falls back to Adam.
            adam_conf = optimizer_conf['adam']
            self.compile_optimizer = Adam(lr=adam_conf['lr'],
                                          beta_1=adam_conf['beta_1'],
                                          beta_2=adam_conf['beta_2'],
                                          epsilon=adam_conf['epsilon'])

        self.batch_is_use = conf['batch']['is_use']
        if self.batch_is_use:
            batch_conf = conf['batch']
            # Keep the constructor arguments rather than a single layer, so
            # every model built by setup_model gets its own BatchNormalization
            # layer. A shared instance would tie the BN weights of all those
            # models together.
            self._batch_normalizer_kwargs = dict(
                epsilon=batch_conf['epsilon'],
                mode=batch_conf['mode'],
                axis=batch_conf['axis'],
                momentum=batch_conf['momentum'],
                weights=batch_conf['weights'],
                beta_init=batch_conf['beta_init'],
                gamma_init=batch_conf['gamma_init'])
            self.batch_normalizer = self._new_batch_normalizer()
        self.weight_schema = conf['sample_weight_option']

        self.hit_ratio = conf.get('hit_ratio', 0.01)

    def _new_batch_normalizer(self):
        """Return a fresh BatchNormalization layer built from conf['batch']."""
        return BatchNormalization(**self._batch_normalizer_kwargs)

    @staticmethod
    def _as_2_dim(y):
        """Return ``y`` as an ``(n_samples, n_targets)`` array.

        Without this, 1-D targets of length n would broadcast against (n, 1)
        predictions to an (n, n) difference and silently give a wrong RMSE.
        """
        y = np.asarray(y)
        return y.reshape(len(y), -1)

    def setup_model(self):
        """Build (without compiling) the Sequential network from conf['layers'].

        - The first entry becomes a Dense layer with ``input_layer_dimension``
          inputs.
        - The last entry becomes the output Dense layer, preceded by a fresh
          batch-normalization layer when enabled. Its 'hidden_units' is
          ignored.
        - Entries in between become hidden Dense layers.
        """
        model = Sequential()
        layers = self.conf['layers']
        last_index = len(layers) - 1
        for i, layer_conf in enumerate(layers):
            init = layer_conf['init']
            activation = layer_conf['activation']
            if i == 0:
                # Checked before the last-layer case, so a one-entry config
                # yields a single hidden-sized Dense layer and no output layer.
                model.add(Dense(int(layer_conf['hidden_units']),
                                input_dim=self.input_layer_dimension,
                                init=init, activation=activation))
            elif i == last_index:
                if self.batch_is_use:
                    self.batch_normalizer = self._new_batch_normalizer()
                    model.add(self.batch_normalizer)
                model.add(Dense(self.output_layer_dimension, init=init, activation=activation))
            else:
                model.add(Dense(int(layer_conf['hidden_units']), init=init, activation=activation))
        # NOTE: per-layer 'dropout' values are currently not applied; no
        # Dropout layers are added.
        return model

    def train_and_predict(self, X_train, y_train, X_test, y_test, weight_file):
        """Train a fresh model and print its RMSE on the train and test sets.

        Args:
            X_train, y_train: training features and targets passed to ``fit``.
            X_test, y_test: evaluation set; skipped when ``X_test`` is None.
            weight_file: path the trained weights are written to after
                evaluation; a falsy value (None / '') skips saving.
        """
        model = self.setup_model()
        # summary() prints the table itself and returns None.
        model.summary()

        model.compile(loss=self.compile_loss, optimizer=self.compile_optimizer)
        model.fit(x=X_train, y=y_train,
                  nb_epoch=self.fit_nb_epoch,
                  batch_size=self.fit_batch_size,
                  verbose=self.fit_verbose,
                  shuffle=True)

        y_pred_on_train = reshape_data_into_2_dim(model.predict(X_train))
        rmse_train = rmse(y_pred_on_train, self._as_2_dim(y_train))
        print('RMSE on train set: {}'.format(rmse_train))
        if X_test is not None:
            y_pred_on_test = reshape_data_into_2_dim(model.predict(X_test))
            rmse_test = rmse(y_pred_on_test, self._as_2_dim(y_test))
            print('RMSE on test set: {}'.format(rmse_test))

        if weight_file:
            self.save_model(model, weight_file)
        return

    def save_model(self, model, weight_file):
        """Write ``model``'s weights to ``weight_file``."""
        model.save_weights(weight_file)
        return

    def load_model(self, weight_file):
        """Build a fresh model via setup_model and load weights from ``weight_file``."""
        model = self.setup_model()
        model.load_weights(weight_file)
        return model


def demo_single_regression(file_list=None, weight_file=None):
    """Train and evaluate SingleRegression on the 'delaney' label.

    ``file_list[1:5]`` is merged into the training set and ``file_list[0]`` is
    the test set. Missing values in both are filled with 0.

    Args:
        file_list: dataset file paths. Defaults to the module-level
            ``file_list`` created by the ``__main__`` block, preserving the
            original zero-argument call.
        weight_file: forwarded to ``train_and_predict``. Defaults to the
            module-level ``weight_file`` if one exists, else None.

    Raises:
        ValueError: if no ``file_list`` is given and none is defined at module
            level (e.g. the module was imported rather than run).
    """
    module_globals = globals()
    if file_list is None:
        if 'file_list' not in module_globals:
            raise ValueError('file_list must be given when not running as a script')
        file_list = module_globals['file_list']
    if weight_file is None:
        weight_file = module_globals.get('weight_file')

    conf = {
        'layers': [
            {
                'hidden_units': 2000,
                'init': 'glorot_normal',
                'activation': 'relu',
                'dropout': 0.25
            }, {
                'hidden_units': 2000,
                'init': 'glorot_normal',
                'activation': 'relu',
                'dropout': 0.25
            }, {
                'init': 'glorot_normal',
                'activation': 'linear'
            }
        ],
        'compile': {
            'loss': 'mse',
            'optimizer': {
                'option': 'adam',
                'sgd': {
                    'lr': 0.003,
                    'momentum': 0.9,
                    'decay': 0.9,
                    'nestrov': True
                },
                'adam': {
                    'lr': 0.001,
                    'beta_1': 0.9,
                    'beta_2': 0.999,
                    'epsilon': 1e-8
                }
            }
        },
        'fitting': {
            'nb_epoch': 100,
            'batch_size': 100,
            'verbose': 1,
            'early_stopping': {
                'option': 'auc',
                'patience': 50
            }
        },
        'batch': {
            'is_use': True,
            'epsilon': 2e-5,
            'mode': 0,
            'axis': -1,
            'momentum': 0.9,
            'weights': None,
            'beta_init': 'zero',
            'gamma_init': 'one'
        },
        'sample_weight_option': 'no_weight',
        'label_name_list': ['delaney']
    }
    label_name_list = conf['label_name_list']
    # Same output as the former Python 2 `print 'label_name_list ', x`
    # (two spaces), but valid on Python 2 and 3.
    print('label_name_list  {}'.format(label_name_list))

    train_pd = read_merged_data(file_list[1:5])
    train_pd.fillna(0, inplace=True)
    test_pd = read_merged_data(file_list[0:1])
    test_pd.fillna(0, inplace=True)

    # Extract fingerprint features and labels for the train and test splits.
    X_train, y_train = extract_feature_and_label(train_pd,
                                                 feature_name='Fingerprints',
                                                 label_name_list=label_name_list)
    X_test, y_test = extract_feature_and_label(test_pd,
                                               feature_name='Fingerprints',
                                               label_name_list=label_name_list)
    print('done data preparation')

    task = SingleRegression(conf=conf)
    task.train_and_predict(X_train, y_train, X_test, y_test, weight_file)


if __name__ == '__main__':
    # Command line: --weight_file is mandatory; --mode is parsed but unused.
    parser = argparse.ArgumentParser()
    parser.add_argument('--weight_file', action='store', dest='weight_file', required=True)
    parser.add_argument('--mode', action='store', dest='mode', required=False, default='single_classification')
    given_args = parser.parse_args()
    weight_file, mode = given_args.weight_file, given_args.mode

    # Dataset: K fold files; demo_single_regression reads the module-level
    # `file_list` and `weight_file` defined here.
    K = 5
    directory = '../datasets/delaney/{}.csv.gz'
    file_list = [directory.format(fold) for fold in range(K)]

    demo_single_regression()
