import argparse
import json
import os

import numpy as np
from opacus.accountants.utils import get_noise_multiplier

from datasets_directory.dataset_loader import Mydatasets
from my_logistic_regression import MyLogisticRegression
from opt_algs  import DoubleNoiseMech, CompareAlgs


def _resolve_batch_size(pb, key, num_samples):
    """Resolve ``pb[key]`` to an integer batch size and write it back into ``pb``.

    ``'full'`` selects the whole dataset (``num_samples``); any other value is
    converted with ``int``. The write-back matters because the resolved value
    appears in the output filename.
    """
    raw = pb[key]
    batch_size = num_samples if raw == 'full' else int(raw)
    pb[key] = batch_size
    return batch_size


def _accumulate(total, new):
    """Extend ``total[name]`` with ``new[name]`` for every algorithm name in ``new``.

    ``total`` always owns fresh lists, so it never aliases the containers
    returned by ``CompareAlgs``. Aliasing could otherwise mutate them, or make
    a later ``extend`` append a list to itself.
    """
    for name, records in new.items():
        total.setdefault(name, []).extend(records)


def _mean_and_sem(records, num_rep):
    """Return ``(mean, std / sqrt(num_rep))`` over axis 0 of ``records``, as lists."""
    arr = np.array(records)
    return (np.mean(arr, axis=0).tolist(),
            (np.std(arr, axis=0) / np.sqrt(num_rep)).tolist())


def helper_fun(datasetname, pb, num_rep):
    """Run the four DoubleNoiseMech variants on a dataset and save averaged metrics.

    Parameters
    ----------
    datasetname : str
        Name of a loader method on ``Mydatasets``; it must return ``(X, y, w_opt)``.
    pb : dict
        Experiment parameters. This function reads ``"total"`` (total privacy
        budget, used as epsilon), ``"num_iteration"``, ``"grad_frac"`` (fraction
        of the budget spent on the gradient), ``"batchsize_grad"`` /
        ``"batchsize_hess"`` (``'full'`` or an integer) and ``"min_eval"``
        (read here only for the output filename; ``pb`` is also handed to
        ``CompareAlgs``). It is mutated in place: both batch sizes become
        resolved integers, and ``"noise_multiplier_grad"`` /
        ``"noise_multiplier_hess"`` are added.
    num_rep : int
        Number of times the optimization algorithms are repeated; the reported
        metrics are averaged over these repetitions.

    Raises
    ------
    ValueError
        If ``num_rep`` is smaller than 1.

    Side effects: prints progress and writes a JSON file under
    ``results-stochastic-new/``, creating that directory if it is missing.
    """
    # Fail fast, before any expensive work (the aggregation below needs >= 1 run).
    if num_rep < 1:
        raise ValueError(f"num_rep must be at least 1, got {num_rep}")

    datasets = Mydatasets()
    X, y, w_opt = getattr(datasets, datasetname)()
    dataset = X, y
    priv_param = pb["total"]
    num_samples = len(y)
    num_iters = pb["num_iteration"]
    # delta = 1/n^2; it is split between gradient and Hessian in the same
    # proportion (grad_frac) as the epsilon budget.
    delta = (1.0 / num_samples) ** 2
    frac_grad = pb['grad_frac']

    batchsize_grad = _resolve_batch_size(pb, "batchsize_grad", num_samples)
    print("batch size gradient is: " + str(batchsize_grad))
    batchsize_hess = _resolve_batch_size(pb, "batchsize_hess", num_samples)
    print("batch size hess is: " + str(batchsize_hess))

    # Calibrate each mechanism's noise multiplier (RDP accountant) to its share
    # of the (epsilon, delta) budget over num_iters subsampled steps.
    std_grad = get_noise_multiplier(
        target_epsilon=frac_grad * priv_param, target_delta=frac_grad * delta,
        sample_rate=batchsize_grad / num_samples, epochs=None, steps=num_iters,
        accountant='rdp', epsilon_tolerance=0.01)
    std_hess = get_noise_multiplier(
        target_epsilon=(1 - frac_grad) * priv_param, target_delta=(1 - frac_grad) * delta,
        sample_rate=batchsize_hess / num_samples, epochs=None, steps=num_iters,
        accountant='rdp', epsilon_tolerance=0.01)
    pb['noise_multiplier_grad'] = std_grad
    pb['noise_multiplier_hess'] = std_hess

    lr = MyLogisticRegression(X, y, reg=1e-8)
    # Four variants: curvature_info in {'hessian', 'ub'} x type_reg in {'add', 'clip'}.
    dnm_hess_add = DoubleNoiseMech(lr, type_reg='add', hyper_tuning=False, curvature_info='hessian').update_rule_stochastic
    dnm_ub_add = DoubleNoiseMech(lr, type_reg='add', hyper_tuning=False, curvature_info='ub').update_rule_stochastic
    dnm_hess_clip = DoubleNoiseMech(lr, type_reg='clip', hyper_tuning=False, curvature_info='hessian').update_rule_stochastic
    dnm_ub_clip = DoubleNoiseMech(lr, type_reg='clip', hyper_tuning=False, curvature_info='ub').update_rule_stochastic
    c = CompareAlgs(lr, dataset, w_opt, iters=num_iters, pb=pb)

    losses_total, gradnorm_total, accuracy_total, wall_clock_total = {}, {}, {}, {}
    for rep in range(num_rep):
        print(str(rep + 1) + " expriment out of " + str(num_rep))
        c.add_algo(dnm_hess_add, "DN-Hess-add")
        c.add_algo(dnm_hess_clip, "DN-Hess-clip")
        c.add_algo(dnm_ub_clip, "DN-UB-clip")
        c.add_algo(dnm_ub_add, "DN-UB-add")

        _accumulate(losses_total, c.loss_vals())
        _accumulate(gradnorm_total, c.gradnorm_vals())
        _accumulate(accuracy_total, c.accuracy_vals())
        _accumulate(wall_clock_total, c.wall_clock_alg())

    result = {}
    accuracy_wopt = c.accuracy_np()
    result['num-samples'] = num_samples
    result['acc-best'] = accuracy_wopt.tolist()
    # NOTE: the "*_std" entries are standard errors (std / sqrt(num_rep)); the
    # key names are kept for compatibility with existing result files.
    for alg in losses_total:
        stats = {}
        stats["loss_avg"], stats["loss_std"] = _mean_and_sem(losses_total[alg], num_rep)
        stats["gradnorm_avg"], stats["gradnorm_std"] = _mean_and_sem(gradnorm_total[alg], num_rep)
        stats["acc_avg"], stats["acc_std"] = _mean_and_sem(accuracy_total[alg], num_rep)
        stats["clock_time_avg"], stats["clock_time_std"] = _mean_and_sem(wall_clock_total[alg], num_rep)
        result[alg] = stats

    out_dir = "results-stochastic-new"
    # Create the directory up front so a long run is not lost on the final write.
    os.makedirs(out_dir, exist_ok=True)
    name_parts = (datasetname, priv_param, pb["num_iteration"], pb["grad_frac"],
                  pb["batchsize_grad"], pb["batchsize_hess"], pb["min_eval"])
    out_path = out_dir + "/so_" + "_".join(str(part) for part in name_parts) + ".txt"
    with open(out_path, 'w') as f:
        json.dump(result, f)



def main():
    """Parse command-line arguments, build the parameter dict and run ``helper_fun``.

    Positional arguments, in order:
      * dataset name
      * total privacy budget
      * number of iterations
      * fraction of the budget for the gradient
      * gradient batch size
      * Hessian batch size
      * min_eval

    Both batch sizes accept ``'full'`` or an integer. The optional
    ``--num-rep`` flag sets how many repetitions are averaged (default 15).
    Invalid numeric input is reported through argparse's usage error, before
    any experiment runs.
    """
    parser = argparse.ArgumentParser(
        description="Run stochastic DoubleNoiseMech experiments and save averaged metrics.")
    parser.add_argument("datasetname", help="name of a loader method on Mydatasets")
    parser.add_argument("total", type=float, help="total privacy budget (epsilon), > 0")
    parser.add_argument("numiter", type=int, help="number of iterations")
    parser.add_argument("grad_frac", type=float,
                        help="fraction of the privacy budget for the gradient vs search direction, in (0, 1)")
    parser.add_argument("batchsize_grad", help="gradient batch size: 'full' or an integer")
    parser.add_argument("batchsize_hess", help="Hessian batch size: 'full' or an integer")
    parser.add_argument("min_eval", type=float, help="min_eval value stored in the parameter dict")
    parser.add_argument("--num-rep", type=int, default=15,
                        help="number of repetitions for averaging over the randomness (default: 15)")
    args = parser.parse_args()

    # Both budget shares must be positive, or the noise calibration fails late.
    if not 0.0 < args.grad_frac < 1.0:
        parser.error("grad_frac must be strictly between 0 and 1, got " + str(args.grad_frac))
    if args.total <= 0:
        parser.error("total must be positive, got " + str(args.total))
    if args.num_rep < 1:
        parser.error("--num-rep must be at least 1, got " + str(args.num_rep))

    datasetname = args.datasetname
    total = args.total  # total privacy budget
    num_iter = args.numiter  # number of iterations
    pb = {
        "total": total,  # Total privacy budget
        "grad_frac": args.grad_frac,  # Fraction of privacy budget for gradient vs search direction
        "batchsize_grad": args.batchsize_grad,  # resolved to an int by helper_fun
        "batchsize_hess": args.batchsize_hess,  # resolved to an int by helper_fun
        "num_iteration": num_iter,
        "min_eval": args.min_eval,
    }
    num_rep = args.num_rep  # the number of repetitions for averaging over the randomness
    print("the dataset is ", str(datasetname))
    print('total is ' + str(total) + ' num_iter ' + str(num_iter))
    helper_fun(datasetname, pb, num_rep=num_rep)


# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
