import numpy as np
import scipy as sp
import scipy.special as sc
from scipy.special import gamma
from scipy import integrate
from scipy.integrate import dblquad, quad
from scipy.stats import gaussian_kde, differential_entropy, norm, ortho_group,multivariate_normal
import scipy.stats as sp_stats
import heapq
import matplotlib
import matplotlib.pyplot as plt
import time
import math
from math import ceil, floor, log2
import random





def init_rng(rng):
    """Return a NumPy ``Generator`` built from ``rng``.

    ``rng`` may be anything ``np.random.default_rng`` accepts: ``None`` (fresh
    entropy), an integer seed, a ``SeedSequence``/``BitGenerator``, or an
    existing ``Generator`` (which NumPy passes through unaltered).
    """
    generator = np.random.default_rng(rng)
    return generator


class PoiPrivateRepr:
    """Poisson private representation (PPR) encoder/decoder.

    Candidates are generated as ``Q(rng_c)``, where ``rng_c`` is seeded by one
    draw from ``rng_common`` per :meth:`encode` / :meth:`decode` call. A
    decoder therefore reproduces the encoder's sample from the index alone,
    provided both were constructed with the same ``rng_common`` seed and have
    made the same number of encode/decode calls.
    """

    def __init__(self, encoder=True, decoder=True, Q=None, rng_common=None, rng_private=None, a=3.0):
        """ Initialize encoder/decoder
        encoder: Whether this can be used as encoder (stored only; encode does not check it)
        decoder: Whether this can be used as decoder (stored only; decode does not check it)
        Q: A function taking an RNG as input, and output a sample following distribution Q
        rng_common: RNG (or seed) shared between encoder and decoder; anything init_rng accepts
        rng_private: Private RNG (or seed) only used by this encoder/decoder
        a: Parameter alpha of PPR; encode requires a > 1
        """

        self.encoder = encoder
        self.decoder = decoder
        self.Q = Q
        self.rng_common = init_rng(rng_common)
        self.rng_private = init_rng(rng_private)
        self.a = a


    def encode(self, r, r_bd):
        """ Perform encoding
        r: Function that gives the ratio dP/dQ at a sample z
        r_bd: An upper-bound on the values of r (used to stop the search early)
        Returns: Pair (k, z) where k is the 1-based index of the selected candidate
            and z is that candidate; decode(k) on a synchronized decoder returns z.
        Raises: ValueError if the parameter a (alpha) is not > 1.
        """

        a = self.a
        if not a > 1:
            # Every a <= 1 already failed further down (the gamma shape 1 - 1/a
            # is <= 0, giving a numpy ValueError or a ZeroDivisionError); fail
            # early with a clear message instead. `not a > 1` also rejects NaN.
            raise ValueError(f"PPR parameter a (alpha) must be > 1, got {a!r}")
        Q = self.Q
        # Fresh candidate stream for this call, seeded from the shared RNG so the
        # decoder can regenerate it.
        rng_c = init_rng(self.rng_common.integers(2**60))
        rng_p = self.rng_private
        u = 0  # running sum of Exp(1) draws (arrival times of a rate-1 Poisson process)
        ws = np.inf  # best (smallest) score w found so far
        k = 0  # number of candidates processed so far
        ks = 0  # index of the best candidate
        zs = 0.0  # best candidate
        n = 0  # heap entries pushed with th == 1 and not yet popped
        # Lower incomplete gamma gamma(1 - 1/a, 1): gammainc is the regularized
        # version, so multiply back by Gamma(1 - 1/a).
        g1 = sp.special.gammainc(1 - 1/a, 1) * sp.special.gamma(1 - 1/a)
        h = []  # min-heap of pending points (t, v, th), ordered by t

        # Probability of drawing a point from the first branch below.
        sprob = (1/np.e) / (1/np.e + g1)

        while True:
            u += rng_p.exponential()
            b = (u * a / (1/np.e + g1)) ** a
            bpia = b ** (1/a)  # non-decreasing across iterations since u only grows

            # Every point generated from now on has (t / r_bd)**a * v >= b * r_bd**-a,
            # which lower-bounds its score because r <= r_bd. With no pending point
            # that could still win (n == 0), nothing can beat ws any more.
            if n == 0 and b * r_bd**-a >= ws:
                return (ks, zs)

            if rng_p.random() < sprob:
                t = bpia
                v = rng_p.exponential() + 1  # v >= 1
            else:
                # Gamma(1 - 1/a) truncated to v <= 1 by rejection sampling.
                v = 2
                while v > 1:
                    v = rng_p.gamma(1 - 1/a)

                t = bpia / v**(1/a)  # t >= bpia, and t**a * v == b

            # th == 1 iff this point's score lower bound does not exceed the current best.
            th = 1 if (t / r_bd) ** a * v <= ws else 0
            heapq.heappush(h, (t, v, th))
            n += th

            # Future points all have t >= the (larger) future bpia, so pending points
            # with t <= bpia can be processed now, in increasing order of t.
            while h and h[0][0] <= bpia:
                t, v, th = heapq.heappop(h)
                n -= th
                k += 1
                z = Q(rng_c)
                w = (t / r(z)) ** a * v
                if w < ws:
                    ws = w
                    ks = k
                    zs = z

    def decode(self, k):
        """ Perform decoding
        k: 1-based index returned by encode
        Returns: Sample z, the k-th candidate of this call's stream (None if k <= 0)
        """

        Q = self.Q
        # Consumes rng_common exactly like encode does, keeping the two in sync.
        rng_c = init_rng(self.rng_common.integers(2**60))
        z = None
        for _ in range(k):
            z = Q(rng_c)
        return z


def bsearch(f, y, a, b, n=60):
    """Bisect the interval [a, b] n times around the point where f crosses y.

    At each step the midpoint becomes the new upper end when f(mid) > y and the
    new lower end otherwise (ties move the lower end). Returns the midpoint of
    the final bracket. Presumably f is expected to be non-decreasing on [a, b]
    -- confirm at call sites.
    """
    lo, hi = a, b
    for _step in range(n):
        mid = (lo + hi) * 0.5
        lo, hi = (lo, mid) if f(mid) > y else (mid, hi)
    return (lo + hi) * 0.5


if __name__ == "__main__":
    # Experiment settings.
    # NOTE(review): delta_target, n and eps are not used below; presumably they
    # are the settings under which local_sigma was obtained -- confirm.
    delta_target = 10 ** -6
    n = 500  # number of users
    eps = 0.05  # privacy parameter
    a = 2  # parameter 'alpha' for PPR
    bern_bias = 0.5  # probability that an input coordinate is +1 (otherwise -1)
    num_trial = 2000  # number of trials for recording running time of each chunk
    d_list = [40, 50, 60, 70, 80, 90, 100, 110]  # list of dimensions of chunks

    # CSGM paper (reference [18] in our paper) "Privacy amplification via compression: Achieving the optimal privacy-accuracy-communication trade-off in distributed mean estimation."
    local_sigma = 24.410434445098872  # from the method used in CSGM paper (reference [18] in our paper)

    # Per-coordinate densities: P is N(1, sigma^2); Q is the equal mixture of
    # N(1, sigma^2) and N(-1, sigma^2).
    def Pf1(x):
        return sp.stats.norm.pdf(x, loc=1, scale=local_sigma)

    def Qf1(x):
        return (sp.stats.norm.pdf(x, loc=1.0, scale=local_sigma) + sp.stats.norm.pdf(x, loc=-1.0, scale=local_sigma)) / 2

    def r1(x):
        """Per-coordinate ratio dP/dQ, set to 0 where P itself is below 1e-11."""
        P = Pf1(x)
        if P < 1e-11:
            return 0.0
        return P / Qf1(x)

    # Numerical maximum of r1 plus a small slack. It does not depend on d, so it
    # is computed once; the bound for a d-dimensional chunk is its d-th power.
    # NOTE(review): from the densities above, r1(x) = 2 / (1 + exp(-2x / sigma^2)),
    # which only approaches its supremum 2 as x -> +inf; with the 1e-11 cutoff this
    # is a numerical bound rather than a guaranteed one -- confirm it is adequate.
    r1_max = -sp.optimize.minimize(lambda x: -r1(x), 0.0).fun + 1e-8

    d_time_mean_chunk = []  # list for mean of running time
    d_time_std_chunk = []  # list for sample standard deviation (ddof=1) of running time

    for d in d_list:  # d: dimension of the chunk
        r_bd = r1_max ** d

        time_list = []
        for _ in range(num_trial):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            time0 = time.perf_counter()

            x_bern = [1 if random.random() < bern_bias else -1 for _ in range(d)]

            def Pf(x):
                return np.prod([sp.stats.norm.pdf(xi, loc=bi, scale=local_sigma) for xi, bi in zip(x, x_bern)])

            def Qf(x):
                return np.prod([Qf1(xi) for xi in x])

            def Q(rng):
                # Each coordinate: N(0, sigma^2) noise plus a uniformly random sign,
                # i.e. one draw from the mixture described by Qf1.
                return [rng.normal(loc=0.0, scale=local_sigma) + (1 if rng.random() < 0.5 else -1) for _ in range(d)]

            def r(x):
                return Pf(x) / Qf(x)

            ppr = PoiPrivateRepr(Q=Q, a=a)
            ppr.encode(r, r_bd)  # only the running time is recorded

            time_list.append(time.perf_counter() - time0)

        m = np.mean(time_list)
        s = np.std(time_list, ddof=1)
        half_width = 2 * (s / np.sqrt(num_trial))  # two standard errors of the mean
        print("For d =", d, " average time is ", m, " and std = ", s)
        print(f"Error bar: for d = {d}", (m - half_width, m + half_width))
        d_time_mean_chunk.append(m)
        d_time_std_chunk.append(s)

    d_time_std_error_chunk = np.array(d_time_std_chunk) / np.sqrt(num_trial)

    # Set before any artists are created so all text uses the same renderer.
    plt.rcParams['text.usetex'] = True
    plt.figure(figsize=(12, 9))
    # semilogy already puts the y axis on a log scale.
    plt.semilogy(d_list, d_time_mean_chunk, marker='*', color='blue', linestyle='solid', label='mean running time')
    plt.bar(d_list, d_time_mean_chunk)
    plt.errorbar(d_list, d_time_mean_chunk, yerr=2 * d_time_std_error_chunk, fmt="o", color="r", capsize=10, elinewidth=2, markeredgewidth=5, label=r'$\pm 2$ standard errors')

    plt.grid(which='major', axis='both')
    plt.ylabel('running time (seconds)', fontsize=26)
    plt.xlabel(r'chunk dimension $d_{\mathrm{chunk}}$', fontsize=26)
    plt.legend(fontsize=18.0, loc="lower left")
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.show()

