from operator import matmul
import numpy as np
import math
from scipy.optimize import fsolve
import numpy.matlib
import scipy.stats
import scipy.linalg as scilin
from scipy.stats import multivariate_normal
from scipy.special import logsumexp
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import time
import random
from matplotlib.patches import Ellipse
from sklearn import mixture
import matplotlib.transforms as transforms
from functools import partial
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_config
from scipy import optimize
np_config.enable_numpy_behavior()
# import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
###################################################################################
###########################Outline#################################################
# 1) Generate a Gaussian Mixture Model Parameters
# 2) Numerically Integrate MI
#       a) Break into separate Entropy terms
# 3) Create General Variational approx. computations
#       a) Marginal and Conditional
# 4) Compute Moment Matching
# 5) Compute Gradient Ascent
# 6) Compute Barber and Agakov
# 7) Plot Samples and approximation
###################################################################################

def GaussianMixtureParams(M,Dx,Dy):
    ###############################################################################
    # Outline: Draws random parameters of an M-component GMM over the joint
    #          space of X (Dx dims) and Y (Dy dims).
    #
    # Inputs:
    #       M - Number of components
    #       Dx - Number of dimensions for Latent Variable, X
    #       Dy - Number of dimensions for Observation Variable, Y
    #
    # Outputs:
    #       w - (M,) component weights drawn from a flat Dirichlet
    #       mu - list of M (Dx+Dy, 1) means with entries uniform on [-5, 5)
    #       Sigma - list of M (Dx+Dy, Dx+Dy) covariances A A^T, A uniform on [0, 1)
    ###############################################################################
    total_dim = Dx + Dy
    weights = np.random.dirichlet(np.ones(M))
    means, covariances = [], []
    for _ in range(M):
        # Draw order (mean, then covariance factor) fixes the sequence consumed
        # from the global NumPy random state, so results are reproducible by seed.
        means.append(np.random.uniform(-5, 5, (total_dim, 1)))
        factor = np.random.rand(total_dim, total_dim)
        covariances.append(np.dot(factor, factor.T))
    return weights, means, covariances

def SampleGMM(N,w,mu,sigma):
    ###############################################################################
    # Outline: Samples Points from a GMM
    #
    # Inputs:
    #       N - Number of points to sample
    #       w - weights of GMM components (normalized internally)
    #       mu - means of GMM components (column vectors or flat arrays)
    #       sigma - Covariances of GMM components
    #
    # Outputs:
    #       samples - (N, D) array of sampled coordinates
    #
    # Raises:
    #       ValueError - if the weights do not have a positive sum
    ###############################################################################
    # The cumulative weights depend only on w, so build them once. Dividing by
    # the total makes the last threshold exactly 1, so a draw r in [0, 1) can
    # never fall past it because of round-off in sum(w) (which previously made
    # the sampler silently pick component 0).
    thresholds = np.cumsum(w, dtype=float)
    total = thresholds[-1]
    if not total > 0:
        raise ValueError("GMM weights must have a positive sum")
    thresholds = thresholds / total
    last = len(thresholds) - 1

    samples = np.zeros((N,len(mu[0])))
    for j in range(N):
        r = np.random.uniform(0, 1)
        # index of the first component whose cumulative weight exceeds r
        k = min(int(np.searchsorted(thresholds, r, side='right')), last)
        samples[j,:] = np.random.multivariate_normal(np.ravel(mu[k]), sigma[k])
    return samples

def MargEntGMM(N,L,Dx,w,mu,Sigma):
    ###############################################################################
    # Outline: Numerically integrates the entropy of the X-marginal of a GMM,
    #          H(p) = -integral p(x) log p(x) dx, as a Riemann sum over a uniform
    #          grid of N points per dimension on [-L, L]^Dx.
    #
    # Inputs:
    #       N - Number of grid points per dimension (>= 2)
    #       L - Half-width of the integration window
    #       Dx - Dimension of Latent Variable, X; the first Dx coordinates of
    #            each mean/covariance are used (the grid has N**Dx points)
    #       w - weights of components
    #       mu - means of components (column vectors or flat arrays)
    #       Sigma - Covariances of components
    #
    # Outputs:
    #       MargEnt - Marginal Entropy
    #
    # Raises:
    #       ValueError - if N < 2
    ###############################################################################
    if N < 2:
        raise ValueError("N must be at least 2 grid points per dimension")
    x = np.linspace(-L,L,N)
    # linspace includes both end points, so the cell width is 2L/(N-1), not 2L/N
    dx = 2*L/(N-1)
    # (N**Dx, Dx) array of grid points; for Dx == 2 this is the same point order
    # as meshgrid(x, x) flattened
    X = np.vstack([axis.ravel() for axis in np.meshgrid(*([x]*Dx))]).T

    # Row d holds log(w_d) + log N(X; mu_d[:Dx], Sigma_d[:Dx, :Dx])
    MargEntPart = np.vstack([
        multivariate_normal.logpdf(X,np.ravel(mu[d])[:Dx],Sigma[d][:Dx,:Dx])+np.log(w[d])
        for d in range(len(w))])
    density = np.sum(np.exp(MargEntPart),axis=0)
    log_density = logsumexp(MargEntPart,axis=0)
    # p log p -> 0 as p -> 0; skip cells where p underflowed to avoid 0 * (-inf)
    integrand = np.multiply(density,log_density,out=np.zeros_like(density),where=density > 0)
    MargEnt = -np.sum(integrand)*dx**Dx
    return MargEnt

def CrossEntGMM(N,L,Dx,wout,muout,Sigmaout,w,mu,Sigma):
    ###############################################################################
    # Outline: Numerically integrates the cross entropy of the X-marginals of
    #          two GMMs, H(p, q) = -integral p(x) log q(x) dx, as a Riemann sum
    #          over a uniform grid of N points per dimension on [-L, L]^Dx.
    #          p = (wout, muout, Sigmaout), q = (w, mu, Sigma).
    #
    # Inputs:
    #       N - Number of grid points per dimension (>= 2)
    #       L - Half-width of the integration window
    #       Dx - Dimension of Latent Variable, X; the first Dx coordinates of
    #            each mean/covariance are used (the grid has N**Dx points)
    #       wout, muout, Sigmaout - weights, means, covariances of p
    #       w, mu, Sigma - weights, means, covariances of q
    #
    # Outputs:
    #       MargEnt - Cross entropy H(p, q)
    #
    # Raises:
    #       ValueError - if N < 2
    ###############################################################################
    if N < 2:
        raise ValueError("N must be at least 2 grid points per dimension")
    x = np.linspace(-L,L,N)
    # linspace includes both end points, so the cell width is 2L/(N-1), not 2L/N
    dx = 2*L/(N-1)
    # (N**Dx, Dx) array of grid points
    X = np.vstack([axis.ravel() for axis in np.meshgrid(*([x]*Dx))]).T

    def log_weighted_components(weights, means, covs):
        # Row d holds log(w_d) + log N(X; mu_d[:Dx], Sigma_d[:Dx, :Dx])
        return np.vstack([
            multivariate_normal.logpdf(X,np.ravel(means[d])[:Dx],covs[d][:Dx,:Dx])+np.log(weights[d])
            for d in range(len(weights))])

    p = np.sum(np.exp(log_weighted_components(wout,muout,Sigmaout)),axis=0)
    log_q = logsumexp(log_weighted_components(w,mu,Sigma),axis=0)
    # cells where p underflowed contribute 0; avoids 0 * (-inf) = nan
    integrand = np.multiply(p,log_q,out=np.zeros_like(p),where=p > 0)
    MargEnt = -np.sum(integrand)*dx**Dx
    return MargEnt


def MargEntGMMLimit(N,Dx,w,mu,Sigma):
    ###############################################################################
    # Outline: Approximates the entropy H(p) of the GMM p = sum_j w_j N(mu_j, Sigma_j)
    #          with a Taylor series of log p(x) about C = MaxConst (the sum of
    #          the component peak heights), then extrapolates the partial sums
    #          with Aitken's delta-squared process.
    #
    #          log p = log C + sum_n (-1)^(n-1)/(n C^n) (p - C)^n and
    #          (p - C)^n = sum_k comb(n,k) p^(n-k) (-C)^k, so every term only
    #          needs the moments E_{N_i}[p^m], m = 0..N. Each moment is a sum of
    #          closed-form Gaussian-product integrals over the multinomial
    #          expansion of p^m.
    #
    # Inputs:
    #       N - Highest Taylor order; partial sums of order 0..N are returned
    #       Dx - Dimension of Latent Variable, X (not used by the computation)
    #       w - weights of components
    #       mu - means of components (column vectors)
    #       Sigma - Covariances of components
    #
    # Outputs:
    #       TaylorEnt - (1, N+1) array; TaylorEnt[0,n] is the order-n estimate
    #       TaylorLimit - Aitken extrapolation of the last three partial sums;
    #                     the last partial sum when N < 2 or when the second
    #                     difference of the partial sums is exactly zero
    ###############################################################################
    M = len(w)
    TaylorEnt = np.zeros((1,N+1))
    Scale = np.zeros((M,1))
    Sigma_inv = []
    for j in range(M):
        # peak height of the weighted component w_j * N(x; mu_j, Sigma_j)
        Scale[j] = w[j]*np.linalg.det(2*np.pi*Sigma[j])**(-1/2)
        Sigma_inv.append(np.linalg.inv(Sigma[j]))
    MaxConst = np.sum(Scale)
    LogMax = np.log(MaxConst)
    # Loop-invariant pieces of every Gaussian-product integral
    PrecMu = [np.matmul(Sigma_inv[j],mu[j]) for j in range(M)]
    Quad = [np.matmul(mu[j].T,PrecMu[j]) for j in range(M)]
    # The exponent/coefficient table of p^m depends only on m, so build it once
    # instead of once per (component, order, k) triple
    Expansions = [multinomial_expand(m,M) for m in range(N+1)]

    def power_moment(base_prec, base_prec_mu, base_quad, base_norm, m):
        # E[p(x)^m] under the Gaussian with precision base_prec, precision-weighted
        # mean base_prec_mu, quadratic form base_quad and normalizer base_norm
        Nmatrix, NCoef = Expansions[m]
        total = 0.0
        for t in range(len(NCoef)):
            # copies: the accumulators below are updated in place
            SumSigmaInv = base_prec.copy()
            SumMu = base_prec_mu.copy()
            SumInd = base_quad.copy()
            for j in range(M):
                SumSigmaInv += Nmatrix[t,j]*Sigma_inv[j]
                SumMu += Nmatrix[t,j]*PrecMu[j]
                SumInd += Nmatrix[t,j]*Quad[j]
            SumSigma = np.linalg.inv(SumSigmaInv)
            SumMu = np.matmul(SumSigma,SumMu)
            # .item(): keep a Python scalar rather than a (1,1) array
            expTerm = np.exp(-.5*(-1*np.matmul(SumMu.T,np.matmul(SumSigmaInv,SumMu))+SumInd)).item()
            total += NCoef[t]*np.prod((Scale.T)**Nmatrix[t])*expTerm*(np.linalg.det(2*np.pi*SumSigma)**(1/2))*base_norm
        return total

    TaylorEnt[0,0] += -LogMax
    for i in range(M):
        if w[i] == 0:
            # contributes -0 * (...) to every partial sum; skipping also avoids
            # the nan that dividing by w[i] used to produce
            continue
        # normalizer of N_i (equals Scale[i]/w[i] without dividing by w[i])
        base_norm = np.linalg.det(2*np.pi*Sigma[i])**(-1/2)
        Moments = [power_moment(Sigma_inv[i],PrecMu[i],Quad[i],base_norm,m) for m in range(N+1)]
        outer = 0.0
        for n in range(1,N+1):
            middle = 0.0
            for k in range(n+1):
                middle += math.comb(n,k)*Moments[n-k]*(-MaxConst)**k
            outer += (-1)**(n-1)/(n*MaxConst**n)*middle
            TaylorEnt[0,n] += -w[i]*(LogMax+outer)

    # Aitken's delta-squared on the last three partial sums
    if N < 2:
        TaylorLimit = TaylorEnt[0,-1]
    else:
        s0, s1, s2 = TaylorEnt[0,-3], TaylorEnt[0,-2], TaylorEnt[0,-1]
        denom = s2-2*s1+s0
        TaylorLimit = s2 if denom == 0 else s0-(s1-s0)**2/denom
    return TaylorEnt, TaylorLimit

def CrossEntGMMLimit(N,Dx,wout,muout,Sigmaout,w,mu,Sigma):
    ###############################################################################
    # Outline: Approximates the cross entropy H(p, q) = -E_p[log q(x)] between
    #          p = sum_i wout_i N(muout_i, Sigmaout_i) and the GMM
    #          q = sum_j w_j N(mu_j, Sigma_j). It uses a Taylor series of log q
    #          about C = MaxConst (the sum of q's component peak heights) and
    #          extrapolates the partial sums with Aitken's delta-squared process.
    #
    #          Every term only needs the moments E_{N_out,i}[q^m], m = 0..N.
    #          Each moment is a sum of closed-form Gaussian-product integrals
    #          over the multinomial expansion of q^m.
    #
    # Inputs:
    #       N - Highest Taylor order; partial sums of order 0..N are returned
    #       Dx - Dimension of Latent Variable, X (not used by the computation)
    #       wout, muout, Sigmaout - weights, means (column vectors), covariances of p
    #       w, mu, Sigma - weights, means (column vectors), covariances of q
    #
    # Outputs:
    #       TaylorEnt - (1, N+1) array; TaylorEnt[0,n] is the order-n estimate
    #       TaylorLimit - Aitken extrapolation of the last three partial sums;
    #                     the last partial sum when N < 2 or when the second
    #                     difference of the partial sums is exactly zero
    ###############################################################################
    M1 = len(wout)
    M = len(w)
    TaylorEnt = np.zeros((1,N+1))
    Scale = np.zeros((M,1))
    Sigma_inv = []
    for j in range(M):
        # peak height of the weighted component w_j * N(x; mu_j, Sigma_j) of q
        Scale[j] = w[j]*np.linalg.det(2*np.pi*Sigma[j])**(-1/2)
        Sigma_inv.append(np.linalg.inv(Sigma[j]))
    MaxConst = np.sum(Scale)
    LogMax = np.log(MaxConst)
    # Loop-invariant pieces of every Gaussian-product integral
    PrecMu = [np.matmul(Sigma_inv[j],mu[j]) for j in range(M)]
    Quad = [np.matmul(mu[j].T,PrecMu[j]) for j in range(M)]
    # The exponent/coefficient table of q^m depends only on m, so build it once
    # instead of once per (component, order, k) triple
    Expansions = [multinomial_expand(m,M) for m in range(N+1)]

    def power_moment(base_prec, base_prec_mu, base_quad, base_norm, m):
        # E[q(x)^m] under the Gaussian with precision base_prec, precision-weighted
        # mean base_prec_mu, quadratic form base_quad and normalizer base_norm
        Nmatrix, NCoef = Expansions[m]
        total = 0.0
        for t in range(len(NCoef)):
            # copies: the accumulators below are updated in place
            SumSigmaInv = base_prec.copy()
            SumMu = base_prec_mu.copy()
            SumInd = base_quad.copy()
            for j in range(M):
                SumSigmaInv += Nmatrix[t,j]*Sigma_inv[j]
                SumMu += Nmatrix[t,j]*PrecMu[j]
                SumInd += Nmatrix[t,j]*Quad[j]
            SumSigma = np.linalg.inv(SumSigmaInv)
            SumMu = np.matmul(SumSigma,SumMu)
            # .item(): keep a Python scalar rather than a (1,1) array
            expTerm = np.exp(-.5*(-1*np.matmul(SumMu.T,np.matmul(SumSigmaInv,SumMu))+SumInd)).item()
            total += NCoef[t]*np.prod((Scale.T)**Nmatrix[t])*expTerm*(np.linalg.det(2*np.pi*SumSigma)**(1/2))*base_norm
        return total

    TaylorEnt[0,0] += -LogMax
    for i in range(M1):
        if wout[i] == 0:
            continue  # contributes -0 * (...) to every partial sum
        Sigma_inv_out = np.linalg.inv(Sigmaout[i])
        scale_out = np.linalg.det(2*np.pi*Sigmaout[i])**(-1/2)
        prec_mu_out = np.matmul(Sigma_inv_out,muout[i])
        quad_out = np.matmul(muout[i].T,prec_mu_out)
        Moments = [power_moment(Sigma_inv_out,prec_mu_out,quad_out,scale_out,m) for m in range(N+1)]
        outer = 0.0
        for n in range(1,N+1):
            middle = 0.0
            for k in range(n+1):
                middle += math.comb(n,k)*Moments[n-k]*(-MaxConst)**k
            outer += (-1)**(n-1)/(n*MaxConst**n)*middle
            TaylorEnt[0,n] += -wout[i]*(LogMax+outer)

    # Aitken's delta-squared on the last three partial sums
    if N < 2:
        TaylorLimit = TaylorEnt[0,-1]
    else:
        s0, s1, s2 = TaylorEnt[0,-3], TaylorEnt[0,-2], TaylorEnt[0,-1]
        denom = s2-2*s1+s0
        TaylorLimit = s2 if denom == 0 else s0-(s1-s0)**2/denom
    return TaylorEnt, TaylorLimit


def multinomial_expand(pow,dim):
    ############ https://www.mathworks.com/matlabcentral/fileexchange/48215-multinomial-expansion
    # Outline: Exponents and coefficients of the multinomial expansion of
    #          (x_1 + ... + x_dim)^pow.
    #
    # Inputs:
    #       pow - total power (non-negative int)
    #       dim - number of summands
    #
    # Outputs:
    #       NMatrix - (#terms, dim) array; each row holds the exponents of one term
    #       NCoef - (#terms,) float array; NCoef[t] = pow! / prod_j NMatrix[t,j]!
    NMatrix = multinomial_powers_recursive(pow,dim)
    # Exact integer arithmetic: avoids numpy.matlib (matrix API, no longer
    # recommended) and the round-off of floor(exp(gammaln(.)) + 0.5) for large powers
    total = math.factorial(pow)
    NCoef = np.array(
        [total // math.prod(math.factorial(int(e)) for e in row) for row in NMatrix],
        dtype=float)
    return NMatrix, NCoef

def multinomial_powers_recursive(pow,dim):
    # Outline: Lists every exponent vector (n_1, ..., n_dim) of non-negative
    #          integers with n_1 + ... + n_dim = pow, one per row, ordered by
    #          n_1 ascending and then recursively by the remaining entries.
    #
    # Inputs:
    #       pow - total power (non-negative int)
    #       dim - number of variables (>= 1)
    #
    # Outputs:
    #       Nmatrix - (comb(pow+dim-1, dim-1), dim) array of exponents
    #
    # Raises:
    #       ValueError - if dim < 1 (previously infinite recursion) or pow < 0
    if dim < 1:
        raise ValueError("dim must be at least 1")
    if pow < 0:
        raise ValueError("pow must be non-negative")
    if dim == 1:
        return np.array([[pow]])
    blocks = []
    for pow_on_x1 in range(pow+1):
        newsubterms = multinomial_powers_recursive(pow-pow_on_x1,dim-1)
        # prefix every remaining-exponent row with the exponent of x_1
        blocks.append(np.hstack((pow_on_x1*np.ones((np.shape(newsubterms)[0],1)),newsubterms)))
    # one vstack instead of re-copying all accumulated rows every iteration
    return np.vstack(blocks)

def KLDivergeExample(K,Ns):
    ###############################################################################
    # Outline: Sweeps c over linspace(-3, 3, K) and, for each c, compares
    #          grid-integrated and Taylor-series estimates of the entropy H(p)
    #          and the cross entropy H(p, q) of two 2-D GMMs:
    #            p: 5 equally weighted components; the 5th has mean (c, c)
    #            q: 2 components; the 2nd has covariance c1*I, c1 = (c+3)/6 + 0.01
    #
    # Inputs:
    #       K - Number of sweep values of c
    #       Ns - Taylor orders to report; column j of the Taylor outputs holds the
    #            order-Ns[j] partial sum (orders up to max(Ns) are computed)
    #
    # Outputs:
    #       TrueEnt, TrueCrossEnt - (K,1) grid-integrated H(p) and H(p, q)
    #       TaylorEnt, TaylorCrossEnt - (K,len(Ns)) Taylor partial sums
    #       TaylorLimit, TaylorCrossLimit - (K,1) Aitken-extrapolated limits
    #       fig_list - list of density frames (always empty: no frames are rendered)
    ###############################################################################
    Dx = 2
    # grid used for the numerically integrated ("true") values
    GridN, GridL = 1000, 50
    NMax = max(Ns)
    NsIdx = list(Ns)
    TrueEnt = np.zeros((K,1))
    TrueCrossEnt = np.zeros((K,1))
    TaylorEnt = np.zeros((K,len(Ns)))
    TaylorLimit = np.zeros((K,1))
    TaylorCrossEnt = np.zeros((K,len(Ns)))
    TaylorCrossLimit = np.zeros((K,1))
    # Density frames (see plotGMMdensity) are not rendered in this sweep, so the
    # returned fig_list stays empty.
    fig_list=[]

    # Parameters that do not depend on c
    ws = np.array([0.35, 0.65])
    mus = [np.array([[.5],[0]]), np.array([[2],[1]])]
    wsout = np.array([0.2, 0.2, 0.2, 0.2 , 0.2])
    sigmasout = [np.diag((.16,1)),np.diag((1,.16)),np.diag((.5,.5)),np.diag((.5,.5)),np.diag((.5,.5))]

    for i, c in enumerate(np.linspace(-3,3,K)):
        c1 = ((c+3)/6)+.01  # maps c in [-3, 3] to [0.01, 1.01]
        sigmas = [2*np.eye(2),c1*np.eye(2)]
        musout = [np.array([[0],[0]]), np.array([[3],[2]]), np.array([[1],[-.5]]), np.array([[2.5],[1.5]]),np.array([[c],[c]])]

        TrueEnt[i] = MargEntGMM(GridN,GridL,Dx,wsout,musout,sigmasout)
        TrueCrossEnt[i] = CrossEntGMM(GridN,GridL,Dx,wsout,musout,sigmasout,ws,mus,sigmas)

        # partial sums of every order 0..NMax come back; keep the requested orders
        EntSums, TaylorLimit[i] = MargEntGMMLimit(NMax,Dx,wsout,musout,sigmasout)
        TaylorEnt[i,:] = EntSums[0,NsIdx]
        CrossSums, TaylorCrossLimit[i] = CrossEntGMMLimit(NMax,Dx,wsout,musout,sigmasout,ws,mus,sigmas)
        TaylorCrossEnt[i,:] = CrossSums[0,NsIdx]
    return TrueEnt, TrueCrossEnt, TaylorEnt, TaylorLimit, TaylorCrossEnt, TaylorCrossLimit, fig_list

####################################################################################
############################## Plotting Functions ##################################
def plotGMMdensity(ws1, mus1, Sigmas1,ws2, mus2, Sigmas2):
    ###############################################################################
    # Outline: Contour plots of two 2-D GMM densities on one axes over the window
    #          [-4, 6] x [-4, 4]; the first GMM is drawn in black and labelled
    #          p(x), the second in red and labelled q(x).
    #
    # Inputs:
    #       ws1, mus1, Sigmas1 - weights, means, covariances of p
    #       ws2, mus2, Sigmas2 - weights, means, covariances of q
    #       (weights may be lists or arrays; means are flattened before use)
    #
    # Outputs:
    #       fig2 - matplotlib Figure holding the contour plot
    ###############################################################################
    def gmm_pdf(points, weights, means, covs):
        """Gaussian Mixture Model PDF evaluated at every row of points."""
        # One vectorized pdf call per component instead of re-inverting every
        # covariance for every grid point
        density = np.zeros(len(points))
        for weight, mean, cov in zip(weights, means, covs):
            density += weight*multivariate_normal.pdf(points, np.ravel(mean), cov)
        return density

    # Define a grid to evaluate the density on
    x_min, x_max = -4, 6
    y_min, y_max = -4, 4
    xx, yy = np.mgrid[x_min:x_max:100j, y_min:y_max:100j]
    grid = np.c_[xx.ravel(), yy.ravel()]

    # Evaluate both densities on the grid (q now uses its own means, mus2)
    gmm1_z = gmm_pdf(grid, ws1, mus1, Sigmas1).reshape(xx.shape)
    gmm2_z = gmm_pdf(grid, ws2, mus2, Sigmas2).reshape(xx.shape)

    # Plot the contour plot of the GMMs
    levels = np.array([.001,.005,.01,.02,.03,.05,.08])
    fig2, ax2 = plt.subplots()
    cntr1 = ax2.contour(xx, yy, gmm1_z, levels=levels, colors='k')
    cntr2 = ax2.contour(xx, yy, gmm2_z, levels=levels, colors='red')
    h1,_ = cntr1.legend_elements()
    h2,_ = cntr2.legend_elements()
    ax2.legend([h1[0], h2[0]], ['p(x)', 'q(x)'],fontsize=30,loc='upper left')
    plt.axis("tight")
    plt.show()
    return fig2

if __name__ == "__main__":
    import imageio
    import plotly.io as pio

    ################################## Huber Counter Example ###########################
    K = 50
    Ns = [0,1,2,3,4]#,5,6,7,8,9,10
    # Rendering one PNG per sweep value needs plotly's static image export and is
    # slow, so the KL-divergence animation is opt-in.
    SAVE_KL_GIF = False
    TrueEnt, TrueCrossEnt, TaylorEnt, TaylorLimit, TaylorCrossEnt, TaylorCrossLimit, fig_list = KLDivergeExample(K,Ns)

    # Only write the density animation when frames were actually produced
    if fig_list:
        imageio.mimsave('PDF.gif', fig_list, fps=2)

    xs = np.linspace(-3,3,K)
    last_col = len(Ns)-1
    # Every other Taylor order, always ending with the highest one, so the
    # "Our Method" and "Approx. Limit" traces appear even when len(Ns) is even
    plot_cols = list(range(0,last_col,2))+[last_col]

    def _style_figure(fig, y_title, y_range=None):
        # Shared axis titles, fonts and framing for every figure in this script
        fig.update_xaxes(title_text="c")
        fig.update_yaxes(title_text=y_title)
        fig.update_layout(font=dict(size=25), plot_bgcolor='white')
        fig.update_xaxes(mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')
        y_axis = dict(mirror=True, ticks='outside', showline=True, linecolor='black', gridcolor='lightgrey')
        if y_range is not None:
            y_axis['range'] = y_range
        fig.update_yaxes(**y_axis)
        return fig

    def _comparison_figure(true_vals, taylor_vals, limit_vals, y_title, y_range=None, marker_x=None):
        # True curve (red), selected Taylor orders (shades of green), the highest
        # order with its Aitken limit, and an optional vertical marker at marker_x
        # spanning y_range
        fig = go.Figure([
                go.Scatter(
                    x=xs,
                    y=true_vals,
                    line=dict(color='rgb(255,0,0)', width=3),
                    mode='lines',
                    name='True Entropy'
                )])
        for i in plot_cols:
            D = 'rgb(0,%d,0)'%(i*155/max(last_col,1)+100)
            if i == last_col:
                fig.add_trace(
                        go.Scatter(
                            x=xs,
                            y=taylor_vals[:,i].flatten(),
                            line=dict(color=D, width=3),
                            mode='lines',
                            name='Our Method'))
                fig.add_trace(
                        go.Scatter(
                            x=xs,
                            y=limit_vals,
                            line=dict(color='rgb(255,0,255)', width=3),
                            mode='lines',
                            name='Approx. Limit'))
                if marker_x is not None and y_range is not None:
                    fig.add_trace(
                            go.Scatter(
                                x=[marker_x,marker_x],
                                y=[y_range[0],y_range[1]],
                                line=dict(color='rgb(0,0,0)', width=3),
                                mode='lines',
                                name='PDF plot'))
            else:
                fig.add_trace(
                        go.Scatter(
                            x=xs,
                            y=taylor_vals[:,i].flatten(),
                            line=dict(color=D, width=3),
                            mode='lines',
                            showlegend=False))
        return _style_figure(fig, y_title, y_range)

    ####################################### Entropy ###############################################
    fig1 = _comparison_figure(TrueEnt.flatten(), TaylorEnt, TaylorLimit.flatten(), "Entropy")
    fig1.show()

    ################################## CROSS ENTROPY ######################################################
    fig2 = _comparison_figure(TrueCrossEnt.flatten(), TaylorCrossEnt, TaylorCrossLimit.flatten(), "Cross Entropy")
    fig2.show()

    ################################ KL Divergence ###########################################
    KLTrue = TrueCrossEnt.flatten()-TrueEnt.flatten()
    KLTaylor = TaylorCrossEnt-TaylorEnt
    KLLimit = TaylorCrossLimit.flatten()-TaylorLimit.flatten()
    KLRange = [-3.5,2.5]
    if SAVE_KL_GIF:
        fig_list1 = []
        for k in range(K):
            frame = _comparison_figure(KLTrue, KLTaylor, KLLimit, "KL Divergence", y_range=KLRange, marker_x=xs[k])
            # convert the figure to an image and add it to the animation
            fig_list1.append(imageio.imread(pio.to_image(frame, format='png')))
        imageio.mimsave('KLDiv.gif', fig_list1, fps=2)
    # Show a single frame rather than opening one figure per sweep value
    fig3 = _comparison_figure(KLTrue, KLTaylor, KLLimit, "KL Divergence", y_range=KLRange, marker_x=xs[-1])
    fig3.show()