import numpy as np
import json
import opacus
from opacus.accountants.utils import get_noise_multiplier
import argparse
from datasets_directory.dataset_loader import Mydatasets
from approx_op import ApproximateMinimaPerturbationLR
from my_logistic_regression import MyLogisticRegression
from opt_algs  import DoubleNoiseMech, CompareAlgs
from myutils import zcdp_to_eps, eps_to_zcdp


def helper_fun(datasetname, pb, num_rep):
    """Run objective perturbation on a dataset and save summary statistics.

    The dataset comes from ``Mydatasets().<datasetname>()``, which is unpacked
    as ``(X, y, w_opt)``. ``ApproximateMinimaPerturbationLR().run_classification``
    is then called ``num_rep`` times with ``eps = pb["total"]`` and
    ``delta = (1 / num_samples) ** 2``. For each run, two values are recorded:
    the accuracy of the returned parameters, and their loss minus the loss of
    ``w_opt`` (both via ``MyLogisticRegression``). The mean and standard error
    of these values are written as JSON to
    ``results/op_<datasetname>_<eps>_DP.txt``. The ``results`` directory is
    created if it is missing.

    Args:
        datasetname: name of a loader method on ``Mydatasets``.
        pb: dictionary of parameters; only ``pb["total"]`` (the privacy
            budget eps) is read here.
        num_rep: number of times the optimization is repeated to report the
            average; must be at least 1.

    Raises:
        ValueError: if ``num_rep`` is less than 1 (the statistics would
            otherwise silently be NaN).
    """
    import os

    if num_rep < 1:
        raise ValueError("num_rep must be at least 1, got " + str(num_rep))

    datasets = Mydatasets()
    X, y, w_opt = getattr(datasets, datasetname)()
    priv_param = pb["total"]
    num_samples = len(y)
    delta = (1.0 / num_samples) ** 2
    lr = MyLogisticRegression(X, y, reg=1e-9)
    amp = ApproximateMinimaPerturbationLR().run_classification
    # The reference loss does not depend on the repetition, so compute it once
    # (assumes loss_wor is deterministic for a fixed argument).
    loss_opt = lr.loss_wor(w_opt)

    acc_objpert = []
    loss_objpert = []
    print("privacy constraint is DP!" + ' eps: ' + str(priv_param))
    for rep in range(num_rep):
        print(str(rep + 1) + " expriment out of " + str(num_rep))
        theta_objpert, _ = amp(X, y, priv_param, delta, lambda_param=None, l2_constraint=None)
        acc_objpert.append(lr.accuracy(theta_objpert))
        # Excess loss relative to the reference solution w_opt.
        loss_objpert.append(lr.loss_wor(theta_objpert) - loss_opt)

    accuracy_wopt = lr.accuracy(w_opt)
    # "_std" entries are standard errors: np.std (default ddof=0) / sqrt(num_rep).
    sqrt_rep = np.sqrt(num_rep)
    result = {
        'num-samples': num_samples,
        'acc-best': accuracy_wopt.tolist(),
        'obj-perturb': {
            "loss_avg": np.mean(loss_objpert).tolist(),
            "loss_std": (np.std(loss_objpert) / sqrt_rep).tolist(),
            "acc_avg": np.mean(acc_objpert).tolist(),
            "acc_std": (np.std(acc_objpert) / sqrt_rep).tolist(),
        },
    }

    os.makedirs("results", exist_ok=True)
    out_path = "results/op_" + datasetname + "_" + str(priv_param) + "_" + 'DP' + ".txt"
    # Use a context manager so the file is flushed and closed deterministically.
    with open(out_path, 'w') as out_file:
        json.dump(result, out_file)



def main(argv=None):
    """Parse command-line options and run the objective-perturbation experiment.

    Command-line options:
        --datasetname: name of the dataset loader method (required).
        --total: total privacy budget eps, parsed as a float (required).

    Args:
        argv: optional list of argument strings. When None, argparse reads
            ``sys.argv[1:]``, matching the original script behavior.

    Missing or malformed options cause argparse to print a usage error and
    exit, instead of failing later with an opaque TypeError/ValueError.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--datasetname", required=True,
                        help="name of the dataset loader method to use")
    parser.add_argument("--total", type=float, required=True,
                        help="total privacy budget (eps)")
    args = parser.parse_args(argv)
    datasetname = args.datasetname
    total = args.total
    pb = {
      "total": total,  # Total privacy budget
    }
    num_rep = 5  # the number of repetitions for averaging over the randomness
    helper_fun(datasetname, pb, num_rep=num_rep)

if __name__ == "__main__":
    # Only launch the experiment when run as a script, not when imported.
    main()
