"""
Tarnet
This is copied directly from ACIC_ATT, with minor modifications.
"""

import os
import glob
import argparse
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

from keras.engine.topology import Layer
import keras.backend as K
from keras.optimizers import rmsprop, SGD, Adam
from keras.layers import Input, Dense, Concatenate, BatchNormalization, Dropout
from keras.models import Model
from keras import regularizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau, TerminateOnNaN
from keras.metrics import binary_accuracy
from ACIC_ATE.idhp_data import *


def binary_classification_loss(concat_true, concat_pred):
    """Summed binary cross-entropy between observed and predicted treatment.

    :param concat_true: tensor whose column 1 is the observed treatment
    :param concat_pred: tensor whose column 2 is the predicted treatment
        probability
    :return: scalar tensor, the cross-entropy summed over all samples
    """
    observed_treatment = concat_true[:, 1]
    # Affine squeeze: a prediction in [0, 1] lands in roughly [0.001, 0.999],
    # keeping it strictly away from the endpoints.
    smoothed_propensity = (concat_pred[:, 2] + 0.001) / 1.002
    per_sample = K.binary_crossentropy(observed_treatment, smoothed_propensity)
    return tf.reduce_sum(per_sample)


def regression_loss(concat_true, concat_pred):
    """Summed squared error of each outcome head, masked by treatment.

    Column 0 of ``concat_pred`` is scored only where the treatment indicator
    is 0, column 1 only where it is 1.

    :param concat_true: tensor with the outcome in column 0 and the treatment
        indicator in column 1
    :param concat_pred: tensor with the two outcome predictions in columns
        0 and 1
    :return: scalar tensor, the sum of both masked squared-error terms
    """
    outcome, treatment = concat_true[:, 0], concat_true[:, 1]

    control_sq_err = tf.square(outcome - concat_pred[:, 0])
    treated_sq_err = tf.square(outcome - concat_pred[:, 1])

    control_term = tf.reduce_sum((1. - treatment) * control_sq_err)
    treated_term = tf.reduce_sum(treatment * treated_sq_err)
    return control_term + treated_term


def dragonnet_loss_binarycross(concat_true, concat_pred):
    """Default loss: outcome regression loss plus treatment cross-entropy.

    Prints a marker line each time it is called.

    :param concat_true: tensor of targets, passed through to both component
        losses
    :param concat_pred: tensor of predictions, passed through to both
        component losses
    :return: scalar tensor, the unweighted sum of the two losses
    """
    print(" this is default dragonloss")
    outcome_term = regression_loss(concat_true, concat_pred)
    treatment_term = binary_classification_loss(concat_true, concat_pred)
    return outcome_term + treatment_term



def treatment_accuracy(concat_true, concat_pred):
    """Binary accuracy of the predicted treatment against the observed one.

    :param concat_true: tensor whose column 1 is the observed treatment
    :param concat_pred: tensor whose column 2 is the predicted treatment
        probability
    :return: the result of ``binary_accuracy`` on those two columns
    """
    return binary_accuracy(concat_true[:, 1], concat_pred[:, 2])


def monitor_epsilon_hack(concat_true, concat_pred):
    """Metric reporting the mean of the epsilon column (column 3).

    :param concat_true: unused; present to match the metric signature
    :param concat_pred: tensor whose column 3 carries epsilon
    :return: scalar tensor, the signed mean of column 3
    """
    return tf.reduce_mean(concat_pred[:, 3])


def track_epsilon(concat_true, concat_pred):
    """Metric reporting the absolute value of the mean epsilon (column 3).

    :param concat_true: unused; present to match the metric signature
    :param concat_pred: tensor whose column 3 carries epsilon
    :return: scalar tensor, ``|mean(column 3)|``
    """
    mean_epsilon = tf.reduce_mean(concat_pred[:, 3])
    return tf.abs(mean_epsilon)


class EpsilonLayer(Layer):
    """Layer holding one trainable scalar, emitted once per input row.

    The values of the input are ignored; only its leading (row) dimension is
    used, so the output has one column whose every entry is the same weight.
    """

    def __init__(self, **kwargs):
        """
        :param kwargs: standard ``Layer`` constructor options (e.g. ``name``),
            forwarded unchanged so the layer can be named and re-created from
            a saved configuration
        """
        super(EpsilonLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the single trainable weight.

        :param input_shape: shape of the incoming tensor (not used to size
            the weight, which is always ``[1, 1]``)
        """
        self.epsilon = self.add_weight(name='epsilon',
                                       shape=[1, 1],
                                       initializer='RandomNormal',
                                       trainable=True)
        super(EpsilonLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, inputs, **kwargs):
        """Broadcast the weight to one value per input row.

        :param inputs: tensor with at least two dimensions; only its row
            count matters
        :param kwargs: accepted and ignored
        :return: tensor with one column, every entry equal to ``epsilon``
        """
        # ones_like(inputs)[:, 0:1] is a column of ones with the batch's row
        # count; multiplying by the [1, 1] weight broadcasts it down the rows.
        return self.epsilon * tf.ones_like(inputs)[:, 0:1]

    def compute_output_shape(self, input_shape):
        """Report the single-column output shape regardless of input width.

        :param input_shape: shape of the incoming tensor
        :return: ``(rows, 1)``
        """
        return (input_shape[0], 1)


def make_tarreg_loss(ratio=1., dragonnet_loss=dragonnet_loss_binarycross):
    """Build a loss that adds a weighted targeted-regularization term.

    :param ratio: multiplier applied to the targeted-regularization term
    :param dragonnet_loss: base loss ``f(concat_true, concat_pred)`` that the
        extra term is added to
    :return: a loss function ``f(concat_true, concat_pred)`` returning a
        scalar tensor
    """
    def tarreg_ATE_unbounded_domain_loss(concat_true, concat_pred):
        base_loss = dragonnet_loss(concat_true, concat_pred)

        outcome, treatment = concat_true[:, 0], concat_true[:, 1]
        control_pred, treated_pred = concat_pred[:, 0], concat_pred[:, 1]
        # Same affine squeeze as the classification loss: keeps both
        # denominators below away from zero.
        propensity = (concat_pred[:, 2] + 0.001) / 1.002
        epsilon = concat_pred[:, 3]

        # Prediction for the arm each sample actually received.
        factual_pred = treatment * treated_pred + (1 - treatment) * control_pred
        # +1/p for treated samples, -1/(1-p) for untreated ones.
        inverse_weight = treatment / propensity - (1 - treatment) / (1 - propensity)
        perturbed_pred = factual_pred + epsilon * inverse_weight

        penalty = tf.reduce_sum(tf.square(outcome - perturbed_pred))
        return base_loss + ratio * penalty

    return tarreg_ATE_unbounded_domain_loss


def _build_three_headed_model(input_dim, reg_l2, propensity_from_representation):
    """Assemble the network shared by make_dragonnet0 and make_dragonnet1.

    Layers are instantiated in a fixed order (representation, treatment head,
    then the y0/y1 heads interleaved layer by layer) so both public builders
    keep creating their layers in the same sequence as before.

    :param input_dim: number of input covariates
    :param reg_l2: L2 penalty applied to the kernels of the outcome heads
    :param propensity_from_representation: if true the treatment head reads
        the learned representation, otherwise it reads the raw inputs
    :return: Keras ``Model`` whose output concatenates, column by column,
        [y0 prediction, y1 prediction, treatment prediction, epsilon]
    """
    inputs = Input(shape=(input_dim,), name='input')

    # representation: three stacked 200-unit ELU layers
    x = inputs
    for _ in range(3):
        x = Dense(units=200, activation='elu', kernel_initializer='RandomNormal')(x)

    propensity_source = x if propensity_from_representation else inputs
    t_predictions = Dense(units=1, activation='sigmoid')(propensity_source)

    # HYPOTHESIS: two hidden layers per outcome head
    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2))(x)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2))(x)

    y0_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2))(y0_hidden)
    y1_hidden = Dense(units=100, activation='elu', kernel_regularizer=regularizers.l2(reg_l2))(y1_hidden)

    # linear output of each outcome head
    y0_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2),
                           name='y0_predictions')(y0_hidden)
    y1_predictions = Dense(units=1, activation=None, kernel_regularizer=regularizers.l2(reg_l2),
                           name='y1_predictions')(y1_hidden)

    # NOTE(review): `name` is passed to the layer *call*, not its constructor,
    # so it presumably does not name the layer -- kept as-is to preserve
    # behaviour; confirm before relying on a layer called 'epsilon'.
    epsilons = EpsilonLayer()(t_predictions, name='epsilon')
    concat_pred = Concatenate(1)([y0_predictions, y1_predictions, t_predictions, epsilons])
    return Model(inputs=inputs, outputs=concat_pred)


def make_dragonnet1(input_dim, reg_l2):
    """
    Neural net predictive model. The dragon has three heads.
    The treatment head is fed by the shared representation.
    :param input_dim: number of input covariates
    :param reg_l2: L2 penalty applied to the kernels of the outcome heads
    :return: Keras model with output columns [y0, y1, t, epsilon]
    """
    return _build_three_headed_model(input_dim, reg_l2, propensity_from_representation=True)


def make_dragonnet0(input_dim, reg_l2):
    """
    Neural net predictive model. The dragon has three heads.
    The treatment head is fed by the raw inputs, not the shared representation.
    :param input_dim: number of input covariates
    :param reg_l2: L2 penalty applied to the kernels of the outcome heads
    :return: Keras model with output columns [y0, y1, t, epsilon]
    """
    return _build_three_headed_model(input_dim, reg_l2, propensity_from_representation=False)


def _split_output(yt_hat, t, y, y_scaler, x, index):
    q_t0 = y_scaler.inverse_transform(yt_hat[:, 0].copy())
    q_t1 = y_scaler.inverse_transform(yt_hat[:, 1].copy())
    g = yt_hat[:, 2].copy()
    eps = yt_hat[:, 3][0]

    y = y_scaler.inverse_transform(y.copy())

    var = "average propensity for treated: {} and untreated: {}".format(g[t.squeeze() == 1.].mean(),
                                                                        g[t.squeeze() == 0.].mean())
    print(var)

    return {'q_t0': q_t0, 'q_t1': q_t1, 'g': g, 't': t, 'y': y, 'x': x, 'index': index, 'eps': eps}


def train_and_predict(t, y_unscaled, x, targeted_regularization=True, output_dir='',
                      knob_loss=dragonnet_loss_binarycross, ratio=1., dragon=1):
    """Fit the network on a 70/30 train/test split and predict on both parts.

    Training runs in two stages on the same model: first Adam, then SGD with
    Nesterov momentum, each with its own early-stopping and learning-rate
    callbacks.

    :param t: treatment indicators; concatenated with the scaled outcome
        along axis 1, so it must be 2-D
    :param y_unscaled: outcomes on their original scale (standardised here
        before training)
    :param x: covariates, one row per sample
    :param targeted_regularization: if true, wrap ``knob_loss`` with
        ``make_tarreg_loss``; otherwise train on ``knob_loss`` directly
    :param output_dir: accepted for interface compatibility; not used here
    :param knob_loss: base loss function
    :param ratio: weight of the targeted-regularization term
    :param dragon: 0 selects ``make_dragonnet0``, 1 selects ``make_dragonnet1``
    :return: ``(test_outputs, train_outputs)``, two lists holding one
        ``_split_output`` dict per replication
    :raises ValueError: if ``dragon`` is neither 0 nor 1
    """
    import time

    # Fail fast: any other value used to leave the model unbound and only
    # surfaced later as an UnboundLocalError at compile time.
    if dragon not in (0, 1):
        raise ValueError("dragon must be 0 or 1, got {!r}".format(dragon))
    build_model = make_dragonnet0 if dragon == 0 else make_dragonnet1

    y_scaler = StandardScaler().fit(y_unscaled)
    y = y_scaler.transform(y_unscaled)

    metrics = [regression_loss, binary_classification_loss, treatment_accuracy, track_epsilon]
    if targeted_regularization:
        loss = make_tarreg_loss(ratio=ratio, dragonnet_loss=knob_loss)
    else:
        loss = knob_loss

    train_outputs = []
    test_outputs = []

    # A single replication; the loop index doubles as the random seed.
    for i in range(1):
        dragonnet = build_model(x.shape[1], 0.01)

        # NOTE(review): the seeds are set after the network is constructed, so
        # they may not govern weight initialisation -- confirm this is intended.
        tf.random.set_random_seed(i)
        np.random.seed(i)
        train_index, test_index = train_test_split(np.arange(x.shape[0]), test_size=0.3)
        x_train, x_test = x[train_index], x[test_index]
        y_train, y_test = y[train_index], y[test_index]
        t_train, t_test = t[train_index], t[test_index]
        # The losses read the outcome from column 0 and the treatment from column 1.
        yt_train = np.concatenate([y_train, t_train], 1)

        start_time = time.time()

        # Stage 1: Adam with a short early-stopping patience.
        dragonnet.compile(
            optimizer=Adam(lr=1e-3),
            loss=loss, metrics=metrics)

        adam_callbacks = [
            TerminateOnNaN(),
            EarlyStopping(monitor='val_loss', patience=2, min_delta=0.),
            ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=1, mode='auto',
                              min_delta=1e-8, cooldown=0, min_lr=0)
        ]

        dragonnet.fit(x_train, yt_train, callbacks=adam_callbacks,
                      validation_split=0.2,
                      epochs=100,
                      batch_size=512, verbose=0)

        # Stage 2: continue from the Adam weights with SGD + Nesterov momentum.
        sgd_callbacks = [
            TerminateOnNaN(),
            EarlyStopping(monitor='val_loss', patience=40, min_delta=0.),
            ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, verbose=1, mode='auto',
                              min_delta=0., cooldown=0, min_lr=0)]

        # should pick something better!
        sgd_lr = 1e-5
        momentum = 0.9
        dragonnet.compile(optimizer=SGD(lr=sgd_lr, momentum=momentum, nesterov=True), loss=loss,
                          metrics=metrics)
        dragonnet.fit(x_train, yt_train, callbacks=sgd_callbacks,
                      validation_split=0.2,
                      epochs=300,
                      batch_size=512, verbose=0)

        elapsed_time = time.time() - start_time
        print("***************************** elapsed_time is: ", elapsed_time)

        yt_hat_test = dragonnet.predict(x_test)
        yt_hat_train = dragonnet.predict(x_train)

        test_outputs += [_split_output(yt_hat_test, t_test, y_test, y_scaler, x_test, test_index)]
        train_outputs += [_split_output(yt_hat_train, t_train, y_train, y_scaler, x_train, train_index)]
        K.clear_session()

    return test_outputs, train_outputs


def run_experiment(data_base_dir='~/data/ihdp_csv', output_base_dir='~/result/ihdp/',
                   knob_loss=dragonnet_loss_binarycross,
                   ratio=1., dragon=1, max_simulations=1000):
    """Train and evaluate on every simulation CSV found in ``data_base_dir``.

    For each file (in sorted order) the raw simulation arrays are saved, then
    the model is trained once with and once without targeted regularization,
    and the per-replication train/test predictions are written as ``.npz``
    files under ``<output_base_dir>/<idx>/{targeted_regularization,baseline}``.

    :param data_base_dir: directory containing the ``*.csv`` simulation files;
        a leading ``~`` is expanded
    :param output_base_dir: directory that receives one sub-directory per
        simulation; a leading ``~`` is expanded
    :param knob_loss: base loss handed to ``train_and_predict``
    :param ratio: targeted-regularization weight handed to ``train_and_predict``
    :param dragon: architecture selector handed to ``train_and_predict``
    :param max_simulations: process at most this many files (``None`` for all)
    """
    print(" I am here")
    # glob and os.makedirs treat "~" literally, so the default paths would
    # otherwise match nothing / create a directory named "~".
    data_base_dir = os.path.expanduser(data_base_dir)
    output_base_dir = os.path.expanduser(output_base_dir)

    simulation_files = sorted(glob.glob("{}/*.csv".format(data_base_dir)))
    if not simulation_files:
        print("No simulation files (*.csv) found in {}".format(data_base_dir))

    for idx, simulation_file in enumerate(simulation_files[:max_simulations]):
        print("******************************************")
        print("Index: {}".format(idx))

        print(simulation_file)
        print("******************************************")
        simulation_output_dir = os.path.join(output_base_dir, str(idx))
        os.makedirs(simulation_output_dir, exist_ok=True)

        # Loader helpers are not defined in this file; presumably provided by
        # the star import from ACIC_ATE.idhp_data.
        x = load_and_format_covariates(simulation_file)
        t, y, y_cf, mu_0, mu_1 = load_all_other_crap(simulation_file)
        np.savez_compressed(os.path.join(simulation_output_dir, "simulation_outputs.npz"),
                            t=t, y=y, y_cf=y_cf, mu_0=mu_0, mu_1=mu_1)

        for is_targeted_regularization in [True, False]:
            print("Is targeted regularization: {}".format(is_targeted_regularization))
            # fit the model, make predictions on test and train sets, and return
            # one output dict per replication
            test_outputs, train_output = train_and_predict(t, y, x,
                                                           targeted_regularization=is_targeted_regularization,
                                                           output_dir=simulation_output_dir,
                                                           knob_loss=knob_loss, ratio=ratio, dragon=dragon)

            sub_dir = "targeted_regularization" if is_targeted_regularization else "baseline"
            train_output_dir = os.path.join(simulation_output_dir, sub_dir)
            os.makedirs(train_output_dir, exist_ok=True)

            # save the outputs for each replication (1 per npz file)
            for num, output in enumerate(test_outputs):
                np.savez_compressed(os.path.join(train_output_dir, "{}_replication_test.npz".format(num)),
                                    **output)

            for num, output in enumerate(train_output):
                np.savez_compressed(os.path.join(train_output_dir, "{}_replication_train.npz".format(num)),
                                    **output)


def turn_knob(data_base_dir='~/data/ihdp_csv', knob='default',
              output_base_dir='~/result/ihdp/'):
    """Run the experiment variant selected by ``knob``.

    Results are written beneath ``<output_base_dir>/<knob>``.

    :param data_base_dir: directory with the simulation CSV files
    :param knob: one of ``'default'``, ``'ratio_0.01'`` (targeted
        regularization weight 0.01) or ``'tarnet'`` (``dragon=0``)
    :param output_base_dir: root directory for results
    :raises ValueError: if ``knob`` is not one of the supported names
    """
    # Extra keyword arguments each knob passes on to run_experiment.
    knob_overrides = {
        'ratio_0.01': {'ratio': 0.01},
        'default': {},
        'tarnet': {'dragon': 0},
    }
    if knob not in knob_overrides:
        # Previously an unknown knob fell through every branch and the call
        # silently did nothing.
        raise ValueError("unknown knob {!r}; expected one of {}".format(knob, sorted(knob_overrides)))

    simulation_output_dir = os.path.join(output_base_dir, knob)
    run_experiment(data_base_dir=data_base_dir, output_base_dir=simulation_output_dir,
                   **knob_overrides[knob])


def main():
    """Command-line entry point: seed the RNGs, parse arguments, run a knob.

    Both directory options are required because they are passed straight to
    ``turn_knob``, where a missing value (``None``) cannot be joined into a
    path. ``--knob`` is restricted to the variants ``turn_knob`` implements
    and defaults to ``'default'``.
    """
    np.random.seed(42)
    tf.set_random_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_base_dir', type=str, required=True,
                        help="path to directory ihdp")
    # The old default ('early_stopping') matched no branch in turn_knob, so a
    # run without --knob did nothing.
    parser.add_argument('--knob', type=str, default='default',
                        choices=['default', 'ratio_0.01', 'tarnet'],
                        help="experiment variant to run: default, ratio_0.01 or tarnet")
    parser.add_argument('--output_base_dir', type=str, required=True,
                        help="directory to save the output")

    args = parser.parse_args()

    turn_knob(args.data_base_dir, args.knob, args.output_base_dir)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()
