import torch
import math
from torch import nn
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Problem size: filter dimension d and number k of rows in each input Z.
d, k = 25, 10

# groundtruth (teacher) parameters, each scaled to unit Euclidean norm.
w_star = torch.ones(d, 1) / math.sqrt(d)
a_star = torch.ones(k, 1) / math.sqrt(k)
# Alternative teacher used in other experiments (one-hot output weights):
#   a_star = torch.zeros(k, 1); a_star[0, 0] = 1


# auxiliary functions
def acos_safe(x, eps=1e-7):
    """Arc-cosine of tensor ``x`` that stays finite slightly outside [-1, 1].

    For ``|x| <= 1 - eps`` this is exactly ``torch.acos(x)``. Beyond that it
    follows the straight chord from ``(1 - eps, acos(1 - eps))`` to ``(1, 0)``
    (mirrored for negative ``x``, ending at ``(-1, pi)``) and keeps extending
    linearly, so values pushed past +/-1 by rounding do not produce NaN.
    """
    magnitude = abs(x)
    direction = torch.sign(x)
    inner = magnitude <= 1 - eps
    # Slope of the chord between (1 - eps, acos(1 - eps)) and (1, 0).
    chord_slope = np.arccos(1 - eps) / eps
    tail = torch.acos(direction * (1 - eps)) - chord_slope * direction * (magnitude - 1 + eps)
    # torch.where evaluates both branches; the NaNs torch.acos yields for
    # |x| > 1 are discarded by the mask.
    return torch.where(inner, torch.acos(x), tail)

def angle(x1, x2):
    """Return the angle in radians between tensors ``x1`` and ``x2``.

    Both tensors are treated as flat vectors. The cosine is
    ``sum(x1 * x2) / (||x1|| * ||x2||)``. ``acos_safe`` keeps the result
    finite when rounding pushes that cosine slightly past +/-1.
    A zero-norm input yields NaN (0 / 0).
    """
    cos_sim = torch.sum(x1 * x2) / (torch.norm(x1) * torch.norm(x2))
    return acos_safe(cos_sim)

# gradients with closed form
def closed_grad(v, a):
    """Closed-form gradients of the objective w.r.t. the filter ``v`` and weights ``a``.

    With ``phi = angle(v, w_star)`` and
    ``g(phi) = (pi - phi) cos(phi) + sin(phi)``:

        grad_v = -(a^T a*) (pi - phi) / (2 pi ||v||) * (I - v v^T / ||v||^2) w*
        grad_a = [(11^T + (pi - 1) I) a - ||w*|| (11^T + (g(phi) - 1) I) a*] / (2 pi)

    These match the derivatives of ``cnn_loss(angle(v, w_star), a)`` when
    ``||w_star|| == 1``. Shapes follow ``v`` (d, 1) and ``a`` (k, 1).
    """
    two_pi = 2 * math.pi
    phi = angle(v, w_star)
    g_phi = (math.pi - phi) * torch.cos(phi) + torch.sin(phi)

    # Filter gradient: w_star projected onto the complement of v.
    v_norm = torch.norm(v)
    perp_v = torch.eye(d) - torch.matmul(v, torch.t(v)) / v_norm ** 2
    v_scale = -torch.sum(a * a_star) * (math.pi - phi) / (two_pi * v_norm)
    grad_v = v_scale * torch.matmul(perp_v, w_star)

    # Output-weight gradient: student Gram term minus teacher cross term.
    self_term = torch.matmul(torch.ones(k, k) + (math.pi - 1) * torch.eye(k), a)
    cross_term = torch.matmul(torch.ones(k, k) + (g_phi - 1) * torch.eye(k), a_star)
    grad_a = 1 / two_pi * self_term - torch.norm(w_star) / two_pi * cross_term
    return grad_v, grad_a

# loss of cnn
def cnn_loss(phi, a):
    """Closed-form objective as a function of the angle ``phi`` and weights ``a``.

    With ``A = 11^T + (pi - 1) I``, ``B(phi) = 11^T + (g(phi) - 1) I`` and
    ``g(phi) = (pi - phi) cos(phi) + sin(phi)``, this returns

        0.5 * [a^T A a / (2 pi) - a^T B(phi) a_star / pi + a_star^T A a_star / (2 pi)]

    ``phi`` must be a tensor, because it is passed to ``torch.cos`` / ``torch.sin``.
    """
    g_phi = (math.pi - phi) * torch.cos(phi) + torch.sin(phi)
    gram_self = torch.ones(k, k) + (math.pi - 1) * torch.eye(k)
    gram_cross = torch.ones(k, k) + (g_phi - 1) * torch.eye(k)

    student_term = torch.sum(torch.matmul(gram_self, a) * a) / (2 * math.pi)
    cross_term = -torch.sum(torch.matmul(gram_cross, a_star) * a) / math.pi
    teacher_term = torch.sum(torch.matmul(gram_self, a_star) * a_star) / (2 * math.pi)
    return (student_term + cross_term + teacher_term) / 2

# stochastic gradients
def stoc_grad(v, a, bs, w_ref=None, a_ref=None):
    """Return a mini-batch stochastic gradient ``(grad_v, grad_a)``.

    Draws ``Z ~ N(0, I)`` of shape ``(bs, k, d)`` and uses the student
    ``f_w(Z) = sum_j a_j * relu(Z_j . w)`` with ``w = v / ||v||``, compared
    against the teacher ``f_*(Z) = sum_j a*_j * relu(Z_j . w*)``. Results
    are averaged over the ``bs`` samples.

    Args:
        v: unnormalized filter, shape ``(d, 1)``.
        a: output weights, shape ``(k, 1)``.
        bs: number of samples to average over.
        w_ref: teacher filter ``(d, 1)``; defaults to module-level ``w_star``.
        a_ref: teacher weights ``(k, 1)``; defaults to module-level ``a_star``.

    Returns:
        ``grad_v`` shaped like ``v`` and ``grad_a`` shaped like ``a``.
    """
    if w_ref is None:
        w_ref = w_star
    if a_ref is None:
        a_ref = a_star
    dim, n_patches = v.shape[0], a.shape[0]

    Z = torch.randn(bs, n_patches, dim)
    relu = torch.nn.ReLU()

    w = v / torch.norm(v)  # the model only sees the normalized filter
    pre = torch.matmul(Z, w)                    # (bs, k, 1): Z_j . w
    w_relu = relu(pre)                          # sigma(Zw)
    w_active = torch.sgn(w_relu)                # sgn(sigma(Zw)) = 1{Zw > 0}
    wstar_relu = relu(torch.matmul(Z, w_ref))   # sigma(Zw*)

    # Per-sample network outputs, shape (bs, 1, 1).
    f_w = torch.sum(w_relu * a, dim=1, keepdim=True)
    f_star = torch.sum(wstar_relu * a_ref, dim=1, keepdim=True)

    # d f_w / d w = sum_j a_j 1{Z_j . w > 0} Z_j, shape (bs, d, 1).
    df_dw = torch.matmul(torch.transpose(Z, 1, 2), w_active * a)

    # gradient of v. Only the cross term -f_* * df_w/dw is used. f_* is a
    # per-sample scalar, so this avoids materializing a (bs, d, d) outer
    # product.
    # NOTE(review): the self term f_w * df_w/dw is omitted. Its expectation is
    # parallel to w and is removed by Pv, so the estimate stays unbiased for
    # the projected population gradient. Unlike the exact per-sample
    # gradient, however, it is not zero at the teacher. Confirm intentional.
    grad_w = torch.mean(-f_star * df_dw, 0)
    # Chain rule through w = v / ||v||: (I - w w^T) / ||v||.
    Pv = (torch.eye(dim) - torch.matmul(w, torch.t(w))) / torch.norm(v)
    grad_v = torch.matmul(Pv, grad_w)

    # gradient of a: per-sample residual (f_w - f_*) times sigma(Zw).
    grad_a = torch.mean((f_w - f_star) * w_relu, 0)
    return grad_v, grad_a

# random initial points
def weights_initial(d=d, k=k):
    """Draw a random starting point ``(w, a)``.

    ``w`` is a standard-normal draw in R^d rescaled to unit norm. ``a`` is a
    standard-normal draw in R^k rescaled to unit norm, then multiplied by
    ``U(0, 1) * |sum(ones(k, 1) * a_star)| / sqrt(k)`` (one uniform draw).
    NOTE(review): ``torch.ones(k, 1) * a_star`` only broadcasts when ``k``
    equals the length of ``a_star`` (or is 1).
    """
    w_dir = torch.randn(d, 1)
    w_dir = w_dir / torch.norm(w_dir)
    a_dir = torch.randn(k, 1)
    a_dir = a_dir / torch.norm(a_dir)
    a_scaled = a_dir * torch.rand(1) * torch.abs(torch.sum(torch.ones(k, 1) * a_star)) / math.sqrt(k)
    return w_dir, a_scaled

## hyper parameters
# Defaults for the eta / bs arguments of SGD_cnn and localSGD.
lr, bs = 0.2, 4  # step size, per-step batch size


## mini-batch SGD
def SGD_cnn(v0, a0, K, R, N = 8, bs = bs, eta = lr):
    """Run R steps of minibatch SGD and return the metric trajectories.

    Each step uses one stochastic gradient from ``stoc_grad`` with batch size
    ``bs * K * N``.

    Returns:
        ``(v_cos, a_prod, loss)``: lists of length R + 1 (initial point, then
        after each step) holding cos(angle(v, w_star)), ``a^T a_star`` and
        ``cnn_loss``.
    """
    v_cos, a_prod, loss = [], [], []

    def record(v, a):
        # Append the three tracked metrics for iterate (v, a).
        cos_v = torch.sum(v * w_star) / (torch.norm(v) * torch.norm(w_star))
        v_cos.append(cos_v)
        a_prod.append(torch.sum(a * a_star))
        loss.append(cnn_loss(acos_safe(cos_v), a))

    record(v0, a0)

    print("\nMinibatch SGD begins!\n")
    batch = bs * K * N
    vt, at = v0, a0
    for _ in range(R):
        v_grad, a_grad = stoc_grad(vt, at, batch)  # compute stochastic gradients
        vt = vt - eta * v_grad
        at = at - eta * a_grad
        record(vt, at)
    return v_cos, a_prod, loss

## Local SGD
def localSGD(v0, a0, K, R, N = 8, bs = bs, eta = lr):
    """Simulate local SGD with N workers and return the metric trajectories.

    In each of the R rounds:
      1. Every worker copies the current global iterate (vr, ar).
      2. Each worker takes K local SGD steps, using stochastic gradients of
         batch size ``bs`` from ``stoc_grad``.
      3. The new global iterate is the plain average of the workers' final
         iterates.

    A round therefore draws ``N * K * bs`` samples, the same number one
    ``SGD_cnn`` step draws.

    Args:
        v0, a0: initial filter and output weights.
        K: local steps per round.
        R: number of communication rounds.
        N: number of workers.
        bs: batch size of each local step.
        eta: step size.

    Returns:
        ``(v_cos, a_prod, loss)``: lists of length R + 1 (initial point, then
        after each round) holding cos(angle(v, w_star)), ``a^T a_star`` and
        ``cnn_loss``.
    """
    v_cos, a_prod, loss = [], [], []

    def record(v, a):
        # Append the three tracked metrics for the global iterate (v, a).
        cos_v = torch.sum(v * w_star) / (torch.norm(v) * torch.norm(w_star))
        v_cos.append(cos_v)
        a_prod.append(torch.sum(a * a_star))
        loss.append(cnn_loss(acos_safe(cos_v), a))

    record(v0, a0)

    print("\nLocal SGD begins!\n")
    vr, ar = v0, a0
    for _ in range(R):
        worker_v, worker_a = [], []
        for _ in range(N):
            vt, at = vr, ar
            # Exactly K local steps. range(K - 1) would take one step too few,
            # and none at all when K == 1.
            for _ in range(K):
                v_grad, a_grad = stoc_grad(vt, at, bs)  # compute stochastic gradients
                vt = vt - eta * v_grad
                at = at - eta * a_grad
            worker_v.append(vt)
            worker_a.append(at)

        # Communication step: average the workers' final iterates.
        vr = torch.mean(torch.stack(worker_v), 0)
        ar = torch.mean(torch.stack(worker_a), 0)
        record(vr, ar)
    return v_cos, a_prod, loss

## simulation
# Round length K and number of rounds R, passed to both algorithms below.
K = 8
R = 100
T = K*R

# Initial point: a0 = -a_star, v0 = -e_1.
a0 = - a_star
v0 = torch.zeros(d, 1)
v0[0, 0] = -1
# Alternative initializations:
#v0[1, 0] = -1
#v0, a0 = weights_initial(d, k)

v_cos, a_prod, loss = localSGD(v0, a0, K, R)
SGD_cos, SGD_a_prod, SGD_loss = SGD_cnn(v0, a0, K, R)

fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(3.5, 6.5), sharex=True)

ax1.plot(v_cos, label="Local SGD")
ax1.plot(SGD_cos, label="Minibatch SGD")
# Raw string: "\c" and "\p" are invalid escapes in a plain literal
# (SyntaxWarning on Python 3.12+), and mathtext needs the backslashes verbatim.
ax1.set_ylabel(r'$\cos\phi_t$', fontsize=15)
#ax1.legend(prop={'size': 12})
ax2.plot(a_prod)
ax2.plot(SGD_a_prod)
ax2.set_ylabel('$a_t^{T}a^*$', fontsize=15)
ax3.plot(loss)
ax3.plot(SGD_loss)
ax3.set_ylabel("loss", fontsize=15)
ax3.set_xlabel("Rounds", fontsize=15)
fig.tight_layout()
plt.show()
