import numpy as np
import torch

from nnlib.nnlib import utils
from methods import LangevinDynamics

# Import added
import scipy.optimize as opt

def discrete_mi_est(xs, ys, nx=2, ny=2):
    """Plug-in estimate of the mutual information (in nats) of paired samples.

    The empirical joint distribution of the pairs ``(xs[i], ys[i])`` is
    accumulated into an ``nx`` by ``ny`` table, and the mutual information is
    computed from that table and its two marginals.

    Args:
        xs: sized sequence of values used as row indices (0 .. nx-1).
        ys: iterable of values used as column indices (0 .. ny-1).
        nx: number of rows of the joint table.
        ny: number of columns of the joint table.

    Returns:
        The estimate, clamped below at 0.0.
    """
    joint = np.zeros((nx, ny))
    for x_val, y_val in zip(xs, ys):
        # Every pair carries the same empirical mass 1/len(xs).
        joint[x_val, y_val] += 1.0/len(xs)

    marg_x = joint.sum(axis=1)
    marg_y = joint.sum(axis=0)

    mi = 0
    for cell in np.ndindex(nx, ny):
        p_xy = joint[cell]
        # (Near-)empty cells are skipped: they contribute nothing and would
        # otherwise evaluate log(0).
        if p_xy >= 1e-9:
            mi += p_xy * np.log(p_xy/(marg_x[cell[0]]*marg_y[cell[1]]))

    # Rounding can push the sum marginally below zero; MI is never negative.
    return max(0.0, mi)


def estimate_fcmi_bound_classification(masks, preds, num_examples, num_classes,
                                       verbose=False, return_list_of_mis=False):
    """Average ``sqrt(2 * MI)`` over examples, with MI estimated per example.

    For every example index ``idx`` the mutual information between the values
    ``m[idx]`` (for ``m`` in ``masks``) and the encoded pair of predicted
    classes taken from rows ``2*idx`` and ``2*idx + 1`` of each entry of
    ``preds`` is estimated with ``discrete_mi_est``.

    Args:
        masks: iterable of per-run containers indexable by example index;
            the values are used as row indices of a 2-row table
            (presumably 0/1 selection bits -- confirm with the caller).
        preds: iterable of per-run score tensors; rows ``2*idx:2*idx+2`` are
            arg-maxed over ``dim=1`` to get the two predicted classes.
        num_examples: number of example indices to process.
        num_classes: number of classes, used to encode the prediction pair as
            ``num_classes * first + second``.
        verbose: if True, print the inputs and the estimate for the first 10
            examples.
        return_list_of_mis: if True, also return the per-example estimates.

    Returns:
        ``bound``, or ``(bound, list_of_mis)`` when ``return_list_of_mis``.
    """
    list_of_mis = []
    total = 0.0
    num_pair_labels = num_classes**2

    for idx in range(num_examples):
        first_row = 2*idx
        ms = [run_mask[idx] for run_mask in masks]

        ps = []
        for run_preds in preds:
            labels = torch.argmax(run_preds[first_row:first_row+2], dim=1)
            # Fold the two predicted classes into a single integer symbol.
            ps.append((num_classes * labels[0] + labels[1]).item())

        cur_mi = discrete_mi_est(ms, ps, nx=2, ny=num_pair_labels)
        list_of_mis.append(cur_mi)
        total += np.sqrt(2 * cur_mi)

        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    bound = total * (1/num_examples)

    if not return_list_of_mis:
        return bound
    return bound, list_of_mis

def kl(q,p):
    """Binary KL divergence between Bernoulli(q) and Bernoulli(p), in nats.

    The endpoint values of ``q`` use the convention ``0 * log(0) = 0``:
    ``q <= 0`` yields ``log(1 / (1 - p))`` and ``q == 1`` yields
    ``log(1 / p)``.

    Args:
        q: scalar probability of the first distribution.
        p: probability of the second distribution (a scalar or a NumPy array).

    Returns:
        The divergence value.
    """
    if q>0:
        if q == 1:
            # The (1-q)*log((1-q)/(1-p)) term vanishes in the limit, but
            # evaluating it literally gives 0 * log(0) = nan.
            return np.log( 1/p )
        return q*np.log(q/p) + (1-q)*np.log( (1-q)/(1-p) )
    else:
        # The q*log(q/p) term vanishes for q = 0.
        return np.log( 1/(1-p) )

# Function added
def estimate_interp_bound_classification(masks, preds, num_examples, num_classes, train_acc,
                                       verbose=False, return_list_of_mis=False):
    """Bound for the zero-training-error case: mean MI divided by ``log(2)``.

    The per-example mutual information between the mask values and the
    encoded pair of predicted classes is estimated with ``discrete_mi_est``
    and averaged over examples. If ``1 - train_acc`` is exactly 0 the result
    is that mean divided by ``log(2)``; otherwise the constant ``1`` is
    returned.

    Args:
        masks: iterable of per-run containers indexable by example index.
        preds: iterable of per-run score tensors; rows ``2*idx:2*idx+2`` are
            arg-maxed over ``dim=1``.
        num_examples: number of example indices to process.
        num_classes: number of classes, used to encode the prediction pair.
        train_acc: training accuracy; the training error is ``1 - train_acc``.
        verbose: if True, print details for the first 10 examples.
        return_list_of_mis: if True, also return the per-example estimates.

    Returns:
        ``bound``, or ``(bound, list_of_mis)`` when ``return_list_of_mis``.
    """
    list_of_mis = []
    mi_sum = 0.0
    num_pair_labels = num_classes**2

    for idx in range(num_examples):
        first_row = 2*idx
        ms = [run_mask[idx] for run_mask in masks]

        ps = []
        for run_preds in preds:
            labels = torch.argmax(run_preds[first_row:first_row+2], dim=1)
            # Fold the two predicted classes into a single integer symbol.
            ps.append((num_classes * labels[0] + labels[1]).item())

        cur_mi = discrete_mi_est(ms, ps, nx=2, ny=num_pair_labels)
        list_of_mis.append(cur_mi)
        mi_sum += cur_mi

        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    RHS = mi_sum * (1/num_examples)

    Rhat = 1-train_acc
    # Only an exactly zero training error gets the MI-based value; any other
    # training error falls back to the constant 1.
    bound = RHS/np.log(2) if Rhat == 0 else 1

    if not return_list_of_mis:
        return bound
    return bound, list_of_mis

# Function added
def estimate_kl_bound_classification(masks, preds, num_examples, num_classes, train_acc,
                                       verbose=False, return_list_of_mis=False):
    """Largest risk ``R`` in [0, 1] that satisfies the binary-KL constraint.

    The per-example mutual information between the mask values and the
    encoded pair of predicted classes is estimated with ``discrete_mi_est``;
    ``RHS`` is the mean of these estimates. With ``Rhat = 1 - train_acc`` the
    result is the largest ``R`` in ``[0, 1]`` for which
    ``kl(Rhat, (Rhat + R) / 2) <= RHS``.

    Args:
        masks: iterable of per-run containers indexable by example index.
        preds: iterable of per-run score tensors; rows ``2*idx:2*idx+2`` are
            arg-maxed over ``dim=1``.
        num_examples: number of example indices to process.
        num_classes: number of classes, used to encode the prediction pair.
        train_acc: training accuracy; the training error is ``1 - train_acc``.
        verbose: if True, print details for the first 10 examples.
        return_list_of_mis: if True, also return the per-example estimates.

    Returns:
        ``bound``, or ``(bound, list_of_mis)`` when ``return_list_of_mis``.
    """
    RHS = 0.0
    list_of_mis = []
    for idx in range(num_examples):
        ms = [p[idx] for p in masks]
        ps = []
        for run_preds in preds:
            labels = torch.argmax(run_preds[2*idx:2*idx+2], dim=1)
            # Fold the two predicted classes into a single integer symbol.
            ps.append((num_classes * labels[0] + labels[1]).item())
        cur_mi = discrete_mi_est(ms, ps, nx=2, ny=num_classes**2)
        list_of_mis.append(cur_mi)
        RHS += cur_mi
        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)
    RHS *= 1/num_examples

    Rhat = 1-train_acc

    def slack(R):
        # Non-negative exactly when R satisfies the KL constraint.
        return RHS - kl(Rhat, Rhat/2 + R/2)

    # The constraint used to be handed to a generic constrained minimiser as
    # slack(R) * R * (1 - R) >= 0. That product is also positive for R > 1
    # whenever slack(R) < 0, so it did not confine R to [0, 1]. The KL term is
    # zero at R = Rhat and increases with R on [Rhat, 1], so the largest
    # feasible R is found directly by bracketing the sign change of slack.
    if Rhat >= 1 or slack(1.0) >= 0:
        # Either the training error is already maximal or the constraint
        # holds on all of [Rhat, 1]: the bound is vacuous.
        bound = 1.0
    elif RHS <= 0:
        # Zero mutual information forces kl(...) = 0, i.e. R = Rhat.
        bound = Rhat
    else:
        bound = opt.brentq(slack, Rhat, 1.0)

    # The previous implementation returned a NumPy scalar; keep that type.
    bound = np.float64(bound)
    if return_list_of_mis:
        return bound, list_of_mis

    return bound

# Function added
def estimate_lg_bound_classification(masks, preds, num_examples, num_classes, train_acc,
                                       verbose=False, return_list_of_mis=False):
    """Minimise ``x[1] * Rhat + RHS / x[0]`` under a two-parameter constraint.

    ``RHS`` is the mean over examples of the mutual information (estimated
    with ``discrete_mi_est``) between the mask values and the encoded pair of
    predicted classes, and ``Rhat = 1 - train_acc``. The pair ``x`` is chosen
    by ``scipy.optimize.minimize`` subject to ``0 <= x[0] <= 0.37``,
    ``x[1] >= 1`` and

        -x[0]*(1 - x[1]) - (exp(x[0]) - 1 - x[0]) * (1 + x[1]**2) >= 0.

    Args:
        masks: iterable of per-run containers indexable by example index.
        preds: iterable of per-run score tensors; rows ``2*idx:2*idx+2`` are
            arg-maxed over ``dim=1``.
        num_examples: number of example indices to process.
        num_classes: number of classes, used to encode the prediction pair.
        train_acc: training accuracy; the training error is ``1 - train_acc``.
        verbose: if True, print details for the first 10 examples.
        return_list_of_mis: if True, also return the per-example estimates.

    Returns:
        ``bound``, or ``(bound, list_of_mis)`` when ``return_list_of_mis``.
    """
    list_of_mis = []
    mi_sum = 0.0
    num_pair_labels = num_classes**2

    for idx in range(num_examples):
        first_row = 2*idx
        ms = [run_mask[idx] for run_mask in masks]

        ps = []
        for run_preds in preds:
            labels = torch.argmax(run_preds[first_row:first_row+2], dim=1)
            # Fold the two predicted classes into a single integer symbol.
            ps.append((num_classes * labels[0] + labels[1]).item())

        cur_mi = discrete_mi_est(ms, ps, nx=2, ny=num_pair_labels)
        list_of_mis.append(cur_mi)
        mi_sum += cur_mi

        if verbose and idx < 10:
            print("ms:", ms)
            print("ps:", ps)
            print("mi:", cur_mi)

    RHS = mi_sum * (1/num_examples)
    Rhat = 1-train_acc

    def feasibility(x):
        # 'ineq' constraints must evaluate to a non-negative number.
        return (-x[0]*(1-x[1]) - (np.exp(x[0]) - 1 - x[0] ) * ( 1 + x[1]**2 ))

    def objective(x):
        return x[1]*Rhat + RHS/x[0]

    # NOTE(review): the starting point x0[0] = 3 lies outside its bound
    # (0, 0.37), and 'disp' is always on regardless of `verbose` -- confirm
    # both are intended.
    results = opt.minimize(objective, x0=[3,2],
                           constraints={'type': 'ineq', 'fun' : feasibility},
                           bounds=((0, 0.37),(1,np.inf)),
                           options={'disp':True})

    solution = results.x
    bound = solution[1]*Rhat + RHS/solution[0]

    if not return_list_of_mis:
        return bound
    return bound, list_of_mis





def estimate_sgld_bound(n, batch_size, model):
    """ Computes the bound of Negrea et al. "Information-Theoretic Generalization Bounds for
    SGLD via Data-Dependent Estimates". Eq (6) of https://arxiv.org/pdf/1911.02151.pdf.

    Args:
        n: value entering the final scale factor
            ``sqrt(n / 4 / batch_size / (n-1)**2)`` (presumably the number of
            training examples -- confirm with the caller).
        batch_size: batch size entering the same scale factor.
        model: a ``LangevinDynamics`` instance with gradient-variance tracking
            enabled; its ``_lr_hist`` and ``_beta_hist`` must each hold one
            more entry than ``_grad_variance_hist``.

    Returns:
        The square root of the accumulated
        ``lr[t] * beta[t] / 4 * grad_variance[t-1]`` terms, times the scale
        factor above.
    """
    assert isinstance(model, LangevinDynamics)
    assert model.track_grad_variance
    num_tracked = len(model._grad_variance_hist)
    assert len(model._lr_hist) == num_tracked + 1
    assert len(model._beta_hist) == num_tracked + 1

    weighted_variance = 0.0
    # skipping the first iteration as grad_variance was not tracked for it
    # NOTE(review): the loop stops at num_tracked - 1, so the last entry of
    # _grad_variance_hist is never read -- confirm this is intended.
    for step in range(1, num_tracked):
        step_weight = model._lr_hist[step] * model._beta_hist[step] / 4.0
        weighted_variance += step_weight * model._grad_variance_hist[step-1]

    bound = np.sqrt(utils.to_numpy(weighted_variance))
    bound *= np.sqrt(n / 4.0 / batch_size / (n-1) / (n-1))
    return bound
