# -*- coding: utf-8 -*-
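"""
Simulation driver: for every number of users n in narray and every number of
samples per user m in marray, repeatedly estimate the mean with the selected
estimator, record the mean and standard deviation of the per-trial error, and
save the resulting table to a CSV file.

Created on Fri Apr  5 14:29:09 2024

@author: ZJ
"""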

import numpy as np
import pandas as pd
from runestimator import run, run_twostage, run_unbalanced,\
    generate, generate_randdiv
from utils import distance
from tqdm import tqdm

# Parameters passed unchanged to every estimator (presumably the
# (epsilon, delta) differential-privacy budget); epsilon also appears in the
# output file name.
epsilon = 1
delta = 1e-5

# Data distribution to simulate:
#     1: 1d uniform in [-1,1]
#     2: 1d Gaussian in [-1,1]
#     3: 2d uniform
#        NOTE(review): the mu configured for distribution 3 has a single
#        component (np.ones(1)) — confirm the dimension.
#     4-8: synthetic distributions produced by runestimator.generate /
#        generate_randdiv (not described here)
#     9: IPUMS data, column INCTOT of ipums.csv (presumably total income)
#     10: IPUMS data, column INCWAGE of ipums.csv (presumably wage income)
distribution = 8

# Estimator:
#     1: Huber loss minimization (current method; run / run_unbalanced)
#     2: two-stage approach (baseline; run_twostage)
method = 1

# How samples are divided among users:
#     False: uniformly — every user gets exactly m samples.
#     True:  randomly — per-user sample counts vary around m.
unbalanced = False

# Number of independent repetitions averaged for each (n, m) setting.
n_trials = 100

# Numbers of users n to sweep over.
narray = [1000,10000]
# Alternative sweeps:
#narray = [10000]
#marray = [100]

# Samples per user m to sweep over.
# If balanced, each value is exactly the number of samples per user.
# If unbalanced, each value is the average number of samples per user,
# i.e. N = n*m samples in total, with some users holding more than m samples
# and others fewer.
marray = [1,2,5,10,20,50,100,200,500,1000]
# IPUMS-based distributions: distribution -> (column of ipums.csv to sample,
# value excluded before sampling).
# NOTE(review): the excluded values look like IPUMS "not applicable" codes —
# confirm against the IPUMS codebook.
_IPUMS_COLUMNS = {9: ("INCTOT", 9999999), 10: ("INCWAGE", 999999)}


def _hyperparams(distribution, n, m):
    """Return the tuned constants ``(T, tau, Rc, mu)`` for one setting.

    ``T`` is passed to ``run``/``run_unbalanced`` (method 1), ``tau`` to
    ``run_twostage`` (method 2), ``Rc`` to every estimator, and ``mu`` is the
    reference value the estimate is scored against (presumably the true mean
    of the distribution). The constants scale with the number of samples per
    user ``m``; for the IPUMS distributions (9, 10) they also depend on the
    number of users ``n``.

    Raises:
        ValueError: if ``distribution`` is not in 1-10, or if it is 9/10 and
            no constants were tuned for ``n`` (only n = 1000 and 10000 are).
    """
    root_m = np.sqrt(m)
    if distribution == 1:
        return 5/root_m, 3/root_m, 1, np.zeros(1)
    if distribution == 2:
        return 10/root_m, 7/root_m, 1, np.zeros(1)
    if distribution == 3:
        return 15/root_m, 6/root_m, 3, np.ones(1)
    if distribution == 4:
        return 10/root_m, 7/root_m, 1, np.ones(3)
    if distribution == 5:
        return 8/root_m, 4/root_m, 1, np.zeros(3)
    if distribution == 6:
        return 15/root_m, 10/root_m, 1, np.zeros(3)
    if distribution == 7:
        return 40 * m**(-3/4), 20/root_m, 1, (1/3)*np.ones(1)
    if distribution == 8:
        # alternative setting: tau = 20/np.sqrt(m)
        return 25 * m**(-3/4), 5/root_m, 1, (1/3)*np.ones(3)
    if distribution in _IPUMS_COLUMNS:
        # (T scale, tau scale) per number of users n.
        # alternative setting: tau = 1e6/np.sqrt(m)
        scales = {1000: (2e6, 1e6), 10000: (4e6, 2e6)}
        if n not in scales:
            # Fail loudly rather than reuse constants tuned for another n.
            raise ValueError(
                "no tuned T/tau for distribution {} with n={}".format(distribution, n))
        T_scale, tau_scale = scales[n]
        mu_value = 51291.25 if distribution == 9 else 36267.59
        return T_scale * m**(-0.75), tau_scale/root_m, 1e5, mu_value*np.ones(1)
    raise ValueError("unknown distribution: {}".format(distribution))


# Validate the estimator choice up front instead of failing mid-sweep.
if method not in (1, 2):
    raise ValueError("unknown method: {}".format(method))

# Load the IPUMS data once, and select and filter its column once, instead of
# on every trial (the sampled values are unchanged).
ipums_pool = None
if distribution in _IPUMS_COLUMNS:
    d_ipums = pd.read_csv("E:/data/ipums.csv") #load the data
    column, excluded = _IPUMS_COLUMNS[distribution]
    income = d_ipums[column]
    ipums_pool = income[income != excluded].to_numpy()

# res collects one row per m: column "m", then for every n an MSE column
# (keyed by the integer n) and a "std_<n>" column.
res = dict()
res["m"] = marray
for n in narray:
    msearray = []
    stdarray = []
    for m in marray:
        errs = []
        T, tau, Rc, mu = _hyperparams(distribution, n, m)
        print("T=", T)
        for i in tqdm(range(n_trials)):
            if unbalanced:
                # NOTE(review): this path is also taken for distributions 9/10,
                # where the IPUMS data loaded above is not used — confirm that
                # generate_randdiv supports them. m_vec presumably holds each
                # user's sample count.
                D, m_vec = generate_randdiv(n, m, distribution)
            else:
                if ipums_pool is None:
                    X = generate(n, m, distribution)
                else:
                    # n users x m samples each, drawn with replacement.
                    X = np.random.choice(ipums_pool, size = n * m).reshape((n,m, 1))
                # Per-user mean over that user's m samples (axis 1).
                D = np.mean(X, axis = 1)
            if method == 1:
                if unbalanced:
                    # T is tuned for m samples per user; rescale it by
                    # sqrt(m / m_vec) for each user's own sample count.
                    Tarray = T * np.sqrt(m)/np.sqrt(m_vec)
                    ans, mu0, randerr = run_unbalanced(D, Tarray, m_vec, Rc, epsilon, delta)
                else:
                    ans, mu0, randerr = run(D, T, Rc, epsilon, delta)
            else:
                # The unbalanced baseline runs with a 3x larger tau.
                tau_used = 3*tau if unbalanced else tau
                ans, mu0, randerr = run_twostage(D, tau_used, Rc, epsilon, delta)
            # Squared distance of the estimate mu0 from mu, plus the randerr
            # term returned by the estimator.
            err = distance(mu0, mu) ** 2 + randerr
            errs.append(err)
        mse = np.mean(np.array(errs))
        std = np.std(np.array(errs))
        print("m={}, mse: {}, std: {}".format(m, mse, std))
        msearray.append(mse)
        stdarray.append(std)
    res[n] = np.array(msearray)
    res["std_{}".format(n)] = np.array(stdarray)

# Turn the collected sweep into a table and write it to
# result_<distribution>_<epsilon>_[unbal_]<new|baseline>.csv.
res = pd.DataFrame(res)
if method == 1:
    method_tag = "new"
elif method == 2:
    method_tag = "baseline"
else:
    # Previously an unknown method silently skipped saving.
    raise ValueError("unknown method: {}".format(method))
division_tag = "unbal_" if unbalanced else ""
# index=False: the "m" column already identifies each row.
res.to_csv("result_{}_{}_{}{}.csv".format(distribution, epsilon, division_tag, method_tag), index=False)
    