import sys
import numpy as np
import pandas as pd
from scipy.optimize import fmin_ncg
from scipy import stats
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.externals import joblib
from DataModule import MnistModule, NewsModule, AdultModule
import SGDInfluence
import TrainDNN

if __name__ == '__main__':
    # Script entry point.  Usage: <script> DATA_KEY SEED GPU_IDX
    #
    # Loads the SGD results saved in ./<DATA_KEY>_dnn/sgd<SEED>.dat and writes
    # three arrays of validation-loss differences into the same directory:
    #   loss_diff_true_<SEED>.dat     -- loss(B[key]) - loss(b), evaluated directly
    #                                    for every non-'noskip' key (sorted order)
    #   loss_diff_proposed_<SEED>.dat -- SGDInfluence.infer_linear_influence estimate
    #   loss_diff_icml_<SEED>.dat     -- estimate from a Newton-CG solve against
    #                                    inf2.hess_u (regularization beta), one value
    #                                    per training instance
    # NOTE(review): the 'noskip' key suggests every other key holds weights
    # trained with some instance skipped -- confirm against the producer of
    # sgd*.dat.
    data_key = sys.argv[1]
    seed = int(sys.argv[2])
    gpu_idx = int(sys.argv[3])
    module, (n_tr, n_val, n_test), m, alpha, (lr, decay, num_epoch, batch_size) = TrainDNN.settings(data_key)
    # Regularization strength used for the "icml" baseline (inf2); per the
    # original comment it is meant to be stronger than alpha.
    beta = 0.1

    # fetch data; labels become column vectors to match output_tensor's (None, 1)
    z_tr, z_val, _, _ = module.fetch(n_tr, n_val, n_test, seed)
    (x_tr, y_tr), (x_val, y_val) = z_tr, z_val
    y_tr = y_tr[:, np.newaxis]
    y_val = y_val[:, np.newaxis]

    # load result: B maps each key to its saved parameters ('a');
    # b is the 'noskip' reference parameter set.
    res = joblib.load('./%s_dnn/sgd%03d.dat' % (data_key, seed))
    B = {key: entry['a'] for key, entry in res['sgd'].items()}
    b = B.pop('noskip')

    # model
    tf.reset_default_graph()
    with tf.device('/gpu:%d' % (gpu_idx,)):
        input_tensor = tf.placeholder(tf.float32, shape=(None, x_tr.shape[1]))
        output_tensor = tf.placeholder(tf.float32, shape=(None, 1))
        sigmoid, logit, params = TrainDNN.build_dnn(input_tensor, output_tensor, m=m, seed=seed)
        loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=output_tensor, logits=logit))
        regularizer = 0.5 * alpha * tf.reduce_sum([tf.reduce_sum(p**2) for p in params])
        loss += regularizer
        graph = tf.get_default_graph()
        sess = tf.Session(graph=graph)
        # for the proposed method
        inf = SGDInfluence.SGDInfluence(graph, sess, input_tensor, output_tensor, sigmoid, loss, params, alpha=alpha)
        # for icml (w/ a stronger regularization beta > alpha)
        inf2 = SGDInfluence.SGDInfluence(graph, sess, input_tensor, output_tensor, sigmoid, loss, params, alpha=beta)
    sess.run(tf.global_variables_initializer())

    def feed_with(weights, x, y):
        """Return a feed dict binding inputs ``x``, labels ``y`` and ``params`` to ``weights``."""
        feed_dict = {input_tensor: x, output_tensor: y}
        return SGDInfluence.concat_feed_entry(params, weights, feed_dict=feed_dict)

    # LIE - true: evaluate the validation loss directly under each saved weight set
    loss0 = sess.run(loss, feed_dict=feed_with(b, x_val, y_val))
    keys = sorted(B)
    loss_diff = np.array([
        sess.run(loss, feed_dict=feed_with(B[key], x_val, y_val)) - loss0
        for key in keys
    ])
    joblib.dump(loss_diff, './%s_dnn/loss_diff_true_%03d.dat' % (data_key, seed))

    # LIE - proposed
    # Gradient of the (regularized) validation loss at the reference weights b.
    # The feed must be built from the validation data itself, not from the
    # feed dict left over by the loop above (which was bound to B[keys[-1]]).
    val_dict = feed_with(b, x_val, y_val)
    with tf.device('/gpu:%d' % (gpu_idx,)):
        u = sess.run(tf.gradients(loss, params), feed_dict=val_dict)
    # NOTE(review): the prefix literally contains 'seed0' before the 3-digit
    # seed; it must match whatever files infer_linear_influence reads/writes.
    loss_diff_est = inf.infer_linear_influence(x_tr, y_tr, u,
        num_epoch=num_epoch, batch_size=batch_size, epoch_used=num_epoch,
        prefix='%s_dnn_seed0%03d' % (data_key, seed))
    joblib.dump(loss_diff_est, './%s_dnn/loss_diff_proposed_%03d.dat' % (data_key, seed))

    # LIE - icml
    # Static parameter shapes/sizes, computed once and used to map between
    # the flat vectors fmin_ncg works with and per-parameter arrays.
    param_shapes = [tuple(p.get_shape().as_list()) for p in params]
    param_sizes = [int(np.prod(shape)) for shape in param_shapes]
    total_size = sum(param_sizes)

    def vec_to_list(z):
        """Split flat vector ``z`` into arrays shaped like ``params`` (in order)."""
        if z.size != total_size:
            raise ValueError('expected a vector of size %d, got %d' % (total_size, z.size))
        z_list = []
        offset = 0
        for shape, size in zip(param_shapes, param_sizes):
            z_list.append(z[offset:offset + size].reshape(shape))
            offset += size
        return z_list

    # Training-set feed at the reference weights b; copied per call so the
    # inf2.u entries added below never leak into the shared base dict.
    tr_base = feed_with(b, x_tr, y_tr)

    def hess_vec(v_list):
        """Evaluate ``inf2.hess_u`` on the training set at ``b`` with ``inf2.u`` fed ``v_list``.

        Presumably this is the Hessian-vector product H.v of the beta-regularized
        training loss -- confirm against SGDInfluence.
        """
        feed_dict = SGDInfluence.concat_feed_entry(inf2.u, v_list, feed_dict=dict(tr_base))
        return sess.run(inf2.hess_u, feed_dict=feed_dict)

    def fmin_loss(z):
        """Quadratic objective 0.5 * z'Hz - u'z, minimized where Hz = u."""
        z_list = vec_to_list(z)
        Hz = hess_vec(z_list)
        return (0.5 * np.sum(SGDInfluence.np_dot_product(Hz, z_list))
                - np.sum(SGDInfluence.np_dot_product(u, z_list)))

    def fgrad_loss(z):
        """Gradient of ``fmin_loss``: Hz - u, flattened."""
        Hz = hess_vec(vec_to_list(z))
        return np.concatenate([(hi - ui).flatten() for hi, ui in zip(Hz, u)])

    def fhess_loss(z, p):
        """Hessian-of-objective times ``p`` (independent of ``z``), flattened."""
        Hp = hess_vec(vec_to_list(p))
        return np.concatenate([hi.flatten() for hi in Hp])

    fargmin = fmin_ncg(f=fmin_loss, x0=np.concatenate([ui.flatten() for ui in u]),
                       fprime=fgrad_loss, fhess_p=fhess_loss, callback=None,
                       avextol=1e-8, maxiter=100)
    hvu = vec_to_list(fargmin)

    # Per-training-instance score: <hvu, grad_i> / n_tr, where grad_i is
    # inf.obj_grads evaluated on the single row i at weights b.
    loss_diff_est2 = []
    for i in range(n_tr):
        gi = sess.run(inf.obj_grads, feed_dict=feed_with(b, x_tr[i:i + 1], y_tr[i:i + 1]))
        loss_diff_est2.append(np.sum(SGDInfluence.np_dot_product(hvu, gi)) / n_tr)
    loss_diff_est2 = np.array(loss_diff_est2)
    joblib.dump(loss_diff_est2, './%s_dnn/loss_diff_icml_%03d.dat' % (data_key, seed))
    sess.close()
