import numpy as np
import json
import argparse
from datasets_directory.dataset_loader import Mydatasets
from my_logistic_regression import MyLogisticRegression
from opt_algs  import DoubleNoiseMech, CompareAlgs
from myutils import eps_to_zcdp


def helper_fun(datasetname,pb,num_rep):
    """Run the double-noise-mechanism variants and save averaged metrics to disk.

    ``pb["total"]`` is treated as the epsilon of an (epsilon, delta)-DP budget
    with ``delta = 1 / num_samples**2``. It is converted to an equivalent
    rho-zCDP budget with ``eps_to_zcdp`` before being handed to ``CompareAlgs``.

    Args:
        datasetname: name of a loader method on ``Mydatasets``; the method must
            return ``(X, y, w_opt)``.
        pb: dictionary with the parameters. Keys read here: "total",
            "num_iteration", "grad_frac", "trace_frac", "trace_coeff".
            The caller's dict is NOT modified. A copy with "total" replaced by
            the zCDP rho is what ``CompareAlgs`` receives.
        num_rep: number of times the optimization algorithms are repeated;
            reported values are averages over these repetitions. Must be >= 1.

    Raises:
        ValueError: if ``num_rep < 1``.

    Side effects:
        Writes a JSON object to
        ``results/so_<dataset>_<eps>_DP_<iters>_<grad_frac>_<trace_frac>_<trace_coeff>.txt``.
        The object holds "num-samples" and "acc-best" (the value of
        ``c.accuracy_np()``). For every algorithm name it also holds the
        mean ("*_avg") and the standard error ("*_std" = std / sqrt(num_rep))
        of loss, gradient norm, accuracy and wall-clock time.
    """
    if num_rep < 1:
        # With zero repetitions there is nothing to average; fail before the
        # (potentially expensive) dataset load instead of with a NameError later.
        raise ValueError("num_rep must be at least 1, got " + str(num_rep))

    datasets = Mydatasets()
    X,y,w_opt = getattr(datasets,datasetname)()
    dataset = X,y

    # Work on a copy so the caller's epsilon is not silently replaced by rho.
    pb = dict(pb)
    privacy_dp = pb["total"]
    num_samples = len(y)
    delta = (1.0/num_samples)**2
    pb["total"] = eps_to_zcdp(privacy_dp,delta)
    print("equaivalent rho-zCDP: " + str(pb["total"]))

    lr = MyLogisticRegression(X,y,reg=1e-8)
    dnm_hess_add = DoubleNoiseMech(lr,type_reg='add',hyper_tuning=False,curvature_info='hessian').update_rule
    dnm_ub_add = DoubleNoiseMech(lr,type_reg='add',hyper_tuning=False,curvature_info='ub').update_rule
    dnm_hess_clip = DoubleNoiseMech(lr,type_reg='clip',hyper_tuning=False,curvature_info='hessian').update_rule
    dnm_ub_clip = DoubleNoiseMech(lr,type_reg='clip',hyper_tuning=False,curvature_info='ub').update_rule
    # (update rule, display name) in the order they are registered each repetition.
    algorithms = (
        (dnm_hess_add, "DN-Hess-add"),
        (dnm_hess_clip, "DN-Hess-clip"),
        (dnm_ub_clip, "DN-UB-clip"),
        (dnm_ub_add, "DN-UB-add"),
    )
    c = CompareAlgs(lr,dataset,w_opt,iters=pb["num_iteration"],pb=pb)

    losses_total, gradnorm_total, accuracy_total, wall_clock_total = {}, {}, {}, {}
    for rep in range(num_rep):
        print(str(rep+1)+" expriment out of "+ str(num_rep))
        for update_rule, alg_name in algorithms:
            c.add_algo(update_rule, alg_name)
        # Tuple elements are evaluated left to right, so the metric getters are
        # called in the same order as before.
        per_rep = (
            (losses_total, c.loss_vals()),
            (gradnorm_total, c.gradnorm_vals()),
            (accuracy_total, c.accuracy_vals()),
            (wall_clock_total, c.wall_clock_alg()),
        )
        for totals, current in per_rep:
            for alg_name, values in current.items():
                # setdefault + extend copies into lists we own, so we never
                # alias (and then mutate) lists that may belong to ``c``.
                totals.setdefault(alg_name, []).extend(values)

    def _mean_and_sem(values):
        """Return (mean, std / sqrt(num_rep)) along axis 0 as JSON-ready lists."""
        arr = np.array(values)
        return (np.mean(arr, axis=0).tolist(),
                (np.std(arr, axis=0) / np.sqrt(num_rep)).tolist())

    result = {}
    accuracy_wopt = c.accuracy_np()
    result['num-samples'] = num_samples
    result['acc-best'] = accuracy_wopt.tolist()
    # Key prefix -> accumulated values; order fixes the key order in the JSON.
    metrics = (
        ("loss", losses_total),
        ("gradnorm", gradnorm_total),
        ("acc", accuracy_total),
        ("clock_time", wall_clock_total),
    )
    for alg in losses_total:
        result[alg] = {}
        for prefix, totals in metrics:
            avg, sem = _mean_and_sem(totals[alg])
            result[alg][prefix + "_avg"] = avg
            # Named "_std" for compatibility, but this is the standard error.
            result[alg][prefix + "_std"] = sem

    out_name = "_".join(str(part) for part in (
        "so", datasetname, privacy_dp, "DP", pb["num_iteration"],
        pb["grad_frac"], pb["trace_frac"], pb["trace_coeff"]))
    # Use a context manager so the file is flushed and closed deterministically.
    with open("results/" + out_name + ".txt", 'w') as out_file:
        json.dump(result, out_file)



def main():
    """Parse command-line arguments and run ``helper_fun`` with them.

    All flags except ``--num_rep`` are required. argparse converts them to
    their numeric types and prints a usage error if one is missing or
    malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--datasetname", required=True,
                        help="name of a loader method on Mydatasets")
    parser.add_argument("--total", type=float, required=True,
                        help="total privacy budget (epsilon of (epsilon, delta)-DP)")
    parser.add_argument("--numiter", type=int, required=True,
                        help="number of iterations")
    parser.add_argument("--grad_frac", type=float, required=True,
                        help="fraction of privacy budget for gradient vs search direction")
    parser.add_argument("--trace_frac", type=float, required=True)
    parser.add_argument("--trace_coeff", type=float, required=True)
    parser.add_argument("--num_rep", type=int, default=10,
                        help="number of repetitions for averaging over the randomness (default: 10)")
    args = parser.parse_args()
    pb = {
        "total": args.total,  # Total privacy budget
        "grad_frac": args.grad_frac,  # Fraction of privacy budget for gradient vs search direction
        "trace_frac": args.trace_frac,
        "trace_coeff": args.trace_coeff,
        "num_iteration": args.numiter,
    }
    helper_fun(args.datasetname,pb,num_rep=args.num_rep)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
