import numpy as np
from scipy.stats import ortho_group
from matplotlib import pyplot as plt

# Global matplotlib defaults shared by every figure produced by this script.
plt.rcParams.update({
    "font.size": 16,
    "legend.fontsize": 16,
    "axes.labelsize": 18,
    "lines.linewidth": 3,
})


def generate_A(dim, seed=123):
    """Build a reproducible ``dim x dim`` matrix with exactly one zero eigenvalue.

    The spectrum consists of ``dim - 1`` values drawn uniformly from
    [0.1, 10) followed by a single zero; it is conjugated by a random
    orthogonal matrix ``P`` so that ``A = P.T @ diag(eig) @ P``.

    Args:
        dim: Size of the square matrix.
        seed: Seed used both for NumPy's global random state (which this
            function re-seeds as a side effect) and for the orthogonal
            matrix sampler.

    Returns:
        The ``(dim, dim)`` matrix ``A``.
    """
    n_nonzero = dim - 1
    np.random.seed(seed)

    # Non-zero eigenvalues first, remaining slot(s) stay at zero.
    nonzero_spectrum = np.random.uniform(low=0.1, high=10, size=n_nonzero)
    spectrum = np.zeros(dim)
    spectrum[:n_nonzero] = nonzero_spectrum

    basis = ortho_group.rvs(dim, random_state=seed)
    left = basis.T @ np.diag(spectrum)
    return left @ basis


def proj(x, a=1000):
    """Project ``x`` onto the Euclidean ball of radius ``a`` centred at 0.

    Args:
        x: Array to project.
        a: Radius of the ball.

    Returns:
        ``x`` itself (the same object) when its norm is at most ``a``;
        otherwise a rescaled copy whose norm equals ``a``.
    """
    length = np.linalg.norm(x)
    # The scale factor is only evaluated when the point lies outside the ball.
    return x if length <= a else x * (a / length)


def oracle(A, x, y, x_0, y_0, reg=0.0, inexact=0.0):
    """Return the gap measure and (possibly perturbed) gradients at ``(x, y)``.

    The base gradients are those of the bilinear form ``x^T A y``:
    ``A @ y`` with respect to ``x`` and ``A.T @ x`` with respect to ``y``.

    Args:
        A: Coupling matrix.
        x, y: Current point.
        x_0, y_0: Anchor point of the regularisation terms.
        reg: Regularisation weight; adds ``reg * (x - x_0)`` to the x-gradient
            and ``-reg * (y - y_0)`` to the y-gradient.
        inexact: Constant added to every gradient entry.

    Returns:
        ``(gap, grad_x, grad_y)`` where ``gap`` is the sum of the Euclidean
        norms of the two *unperturbed* bilinear gradients (computed before
        the regularisation and ``inexact`` terms are added).
    """
    g_x = A @ y
    g_y = A.T @ x
    # Measured on the clean gradients, before any perturbation below.
    gap = np.linalg.norm(g_y) + np.linalg.norm(g_x)

    shift_x = reg * (x - x_0) + inexact * np.ones_like(g_x)
    shift_y = -reg * (y - y_0) + inexact * np.ones_like(g_y)
    g_x += shift_x
    g_y += shift_y
    return gap, g_x, g_y


def norm2(x, y):
    """Return the Euclidean norm of the stacked pair ``(x, y)``.

    Computed as ``sqrt(||x||^2 + ||y||^2)`` from the two individual norms.
    """
    norm_x, norm_y = (np.linalg.norm(part) for part in (x, y))
    return np.sqrt(norm_x * norm_x + norm_y * norm_y)


def gda(A, x_0, y_0, step, T, reg=0.0, inexact=0.0):
    """Run ``T`` steps of projected simultaneous gradient descent-ascent.

    Each step queries ``oracle`` at the current point, descends in ``x``,
    ascends in ``y`` (both using gradients from the same query) and
    projects each variable with ``proj``.

    Args:
        A: Coupling matrix passed to ``oracle``.
        x_0, y_0: Starting point; also the regularisation anchor.
        step: Step size.
        T: Number of iterations.
        reg: Regularisation weight forwarded to ``oracle``.
        inexact: Gradient perturbation forwarded to ``oracle``.

    Returns:
        Two lists of length ``T + 1`` (x-trajectory, y-trajectory), both
        starting with the initial point. When ``reg == 0`` they hold the
        running averages of the iterates; otherwise the raw iterates.
    """
    cur_x, cur_y = x_0, y_0
    avg_x, avg_y = x_0, y_0
    iterates_x, iterates_y = [x_0], [y_0]
    averages_x, averages_y = [x_0], [y_0]

    # k counts how many points (including the start) are already averaged.
    for k in range(1, T + 1):
        grad_x, grad_y = oracle(A, cur_x, cur_y, x_0, y_0, reg, inexact)[1:]
        cur_x = proj(cur_x - step * grad_x)
        cur_y = proj(cur_y + step * grad_y)
        iterates_x.append(cur_x)
        iterates_y.append(cur_y)

        avg_x = (avg_x * k + cur_x) / (k + 1)
        avg_y = (avg_y * k + cur_y) / (k + 1)
        averages_x.append(avg_x)
        averages_y.append(avg_y)

    return (averages_x, averages_y) if reg == 0 else (iterates_x, iterates_y)


def eg(A, x_0, y_0, step, T, reg=0.0, inexact=0.0):
    """Run ``T`` steps of the projected extragradient method.

    Each iteration first takes a look-ahead step from the current point,
    then re-queries ``oracle`` at the look-ahead point and applies that
    gradient to the current point. Every step is projected with ``proj``.

    Args:
        A: Coupling matrix passed to ``oracle``.
        x_0, y_0: Starting point; also the regularisation anchor.
        step: Step size used for both the look-ahead and the update.
        T: Number of iterations.
        reg: Regularisation weight forwarded to ``oracle``.
        inexact: Gradient perturbation forwarded to ``oracle``.

    Returns:
        Two lists of length ``T + 1`` (x-trajectory, y-trajectory), both
        starting with the initial point. When ``reg == 0`` they hold the
        running averages of the look-ahead points; otherwise the updated
        iterates.
    """
    cur_x, cur_y = x_0, y_0
    avg_x, avg_y = x_0, y_0
    iterates_x, iterates_y = [x_0], [y_0]
    averages_x, averages_y = [x_0], [y_0]

    # k counts how many points (including the start) are already averaged.
    for k in range(1, T + 1):
        # Look-ahead (extrapolation) step from the current point.
        grad_x, grad_y = oracle(A, cur_x, cur_y, x_0, y_0, reg, inexact)[1:]
        ahead_x = proj(cur_x - step * grad_x)
        ahead_y = proj(cur_y + step * grad_y)

        # The running average tracks the look-ahead points.
        avg_x = (avg_x * k + ahead_x) / (k + 1)
        avg_y = (avg_y * k + ahead_y) / (k + 1)
        averages_x.append(avg_x)
        averages_y.append(avg_y)

        # Actual update: gradient at the look-ahead point, applied to the
        # current point.
        grad_x, grad_y = oracle(A, ahead_x, ahead_y, x_0, y_0, reg, inexact)[1:]
        cur_x = proj(cur_x - step * grad_x)
        cur_y = proj(cur_y + step * grad_y)
        iterates_x.append(cur_x)
        iterates_y.append(cur_y)

    return (averages_x, averages_y) if reg == 0 else (iterates_x, iterates_y)


def run_minimax_grad(n, T, delta, reg, step_gda, step_eg, seed=123):
    """Compare GDA/EG with and without regularisation and save the plots.

    Four methods are run from the all-ones starting point on the matrix
    returned by ``generate_A(n, seed)``: EG, GDA, Reg-EG and Reg-GDA. Each
    is run twice, once with exact gradients and once with every gradient
    entry shifted by ``delta``. Two log-log panels are drawn:

    * left: the gap value returned by ``oracle`` (evaluated without
      regularisation or perturbation) along the *inexact* trajectories;
    * right: the distance between the exact and inexact trajectories.

    The figure is written to ``./result/minimax.png``; the ``./result``
    directory is created if it does not exist yet.

    Args:
        n: Problem dimension.
        T: Number of iterations per run.
        delta: Gradient perturbation used for the inexact runs.
        reg: Regularisation weight for the Reg-* methods.
        step_gda: Indexable pair of GDA step sizes; ``[0]`` is used without
            regularisation, ``[1]`` with regularisation.
        step_eg: Indexable pair of EG step sizes, same layout as
            ``step_gda``.
        seed: Seed forwarded to ``generate_A``.
    """
    import os  # local import: only needed to prepare the output directory

    A = generate_A(n, seed)
    x = y = np.ones((n, 1))

    # Spectral norm of A; computed once and reused by all four formulas.
    spectral_norm = np.linalg.norm(A, 2)
    print('stepsize used in theory for GDA:', 1 / (np.sqrt(T) * spectral_norm))
    print('stepsize used in theory for EG:', 1 / spectral_norm)
    print('stepsize used in theory for Reg-GDA:', 0.25 * reg / ((spectral_norm + reg) ** 2))
    print('stepsize used in theory for Reg-EG:', 0.5 / (spectral_norm + reg))

    # (legend label, solver, step size, regularisation weight), listed in
    # the order the curves are drawn so colours and legend order are stable.
    methods = [
        ('EG', eg, step_eg[0], 0.0),
        ('GDA', gda, step_gda[0], 0.0),
        ('Reg-EG', eg, step_eg[1], reg),
        ('Reg-GDA', gda, step_gda[1], reg),
    ]
    colorlist = ['#448bff', '#3bc335', '#ff9600', '#f84c00']
    iterations = list(range(T + 1))

    fig = plt.figure(figsize=(12, 5))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    for (label, solver, step, reg_weight), color in zip(methods, colorlist):
        xs_exact, ys_exact = solver(A, x, y, step, T, reg=reg_weight)
        xs_noisy, ys_noisy = solver(A, x, y, step, T, reg=reg_weight, inexact=delta)

        # Gap along the inexact trajectory, measured with a clean oracle.
        gaps = [oracle(A, xi, yi, x, y)[0] for xi, yi in zip(xs_noisy, ys_noisy)]
        ax1.plot(iterations, gaps, label=label, color=color)

        # Distance between the exact and the inexact trajectory.
        deviation = [
            norm2(xe - xn, ye - yn)
            for xe, xn, ye, yn in zip(xs_exact, xs_noisy, ys_exact, ys_noisy)
        ]
        ax2.plot(iterations, deviation, label=label, color=color)

    for ax, title in ((ax1, 'Duality Gap'), (ax2, 'Deviation in Trajectory')):
        ax.set_title(title)
        ax.legend()
        ax.set_xlabel('# Iterations')
        ax.set_xscale('log')
        ax.set_yscale('log')

    # savefig does not create missing directories, so make sure the target
    # folder exists instead of failing after all the runs have completed.
    os.makedirs('./result', exist_ok=True)
    plt.savefig('./result/minimax.png', dpi=500, bbox_inches='tight')