import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
import pickle
import os
from datetime import datetime

#Generate inverted measurements
def get_meas(D, m, n, tau, Sig, y=200, eta_up=10):
    """Generate n truncated, noise-averaged rank-one sensing matrices.

    For each measurement i a standard-normal vector ``a_i`` of length D is
    drawn and the quadratic form ``a_i^T Sig a_i`` is computed. The target
    value ``y`` is perturbed by m independent Uniform(-eta_up, eta_up) draws.
    Each noisy value is divided by the quadratic form, the m ratios are
    averaged, and the average is truncated from above at ``tau``.

    Args:
        D: length of each ``a_i``; ``Sig`` is used as a D x D matrix.
        m: number of noisy samples averaged per measurement.
        n: number of measurements.
        tau: upper truncation threshold applied to each averaged gamma.
        Sig: D x D matrix defining the quadratic measurement.
        y: noiseless target value (default 200).
        eta_up: half-width of the uniform noise (default 10).

    Returns:
        As: array of shape (D, D, n) with
            ``As[:, :, i] = gamma_i * outer(a_i, a_i)``.
        gammas: 1-D numpy array of length n holding the truncated gammas.
    """
    # Sensing matrices, one D x D slice per measurement
    As = np.zeros((D, D, n))

    # gamma values after m-averaging and truncation
    gammas = np.empty(n)
    for i in range(n):
        # Sample a_i from a standard normal, take the quadratic measurement
        a_i = np.random.normal(size=D)
        quad_i = a_i @ Sig @ a_i

        # m noisy copies of the target value
        y_tilde = y + np.random.uniform(low=-eta_up, high=eta_up, size=m)

        # Average the m ratios, then truncate at tau
        gamma_bar = np.mean(y_tilde / quad_i)
        gamma_tilde = np.minimum(tau, gamma_bar)

        # Store the scaled rank-one sensing matrix and its gamma
        As[:, :, i] = gamma_tilde * np.outer(a_i, a_i)
        gammas[i] = gamma_tilde

    return As, gammas

#Generate ground truth Sigma and tau for outer product of orthonormal matrices
def get_gt_orth(D, r, y, eta_up, n):
    """Build a scaled rank-r projection ``Sig`` and its truncation level tau.

    ``Sig = (D / sqrt(r)) * Q Q^T``, where Q holds orthonormal columns from
    the QR decomposition of a random D x r Gaussian matrix.

    tau is ``M_14 * sqrt(n / D)`` with
    ``M_14 = (y + eta_up) / (4 * s_r * k)``. Here ``s_r`` is the r-th largest
    singular value of ``Sig`` and ``k = r - 8`` when ``r > 8``, else ``k = r``.

    Returns:
        (Sig, tau)
    """
    gauss = np.random.randn(D, r)
    basis, _ = np.linalg.qr(gauss)
    scale = D / np.sqrt(r)
    Sig = scale * basis @ basis.T

    # Ascending order, so index -r is the r-th largest singular value
    svals_ascending = np.sort(np.linalg.svd(Sig)[1])
    sigma_r = svals_ascending[-r]

    rank_factor = (r - 8) if r > 8 else r
    M_14 = (y + eta_up) / (4 * sigma_r * rank_factor)

    return Sig, M_14 * np.sqrt(n / D)

#Get reg. parameter lambda based on theory, with tuned constant C
def get_lda(D, n, m, r, C):
    """Return the regularisation weight ``C * (sqrt(D/n) + 1/m) / r``."""
    rate = np.sqrt(D / n) + 1 / m
    return C * rate / r


#Solve estimation problem
def get_est(A_flat, y, lda, D, n, max_iter=10000):
    """Solve the nuclear-norm-regularised least-squares problem with SCS.

    Minimises ``sum_squares(y - vec(S) @ A_flat) / n + lda * ||S||_*`` over
    D x D PSD matrices S. ``A_flat`` presumably has shape (D*D, n), one
    vectorised sensing matrix per column (TODO confirm against the caller).

    Returns:
        The optimal S as a numpy array; this may be None if the solver does
        not produce a value.
    """
    est = cp.Variable((D, D), PSD=True)

    residual = y - cp.vec(est) @ A_flat
    data_fit = cp.sum_squares(residual) / n
    penalty = lda * cp.norm(est, 'nuc')

    problem = cp.Problem(cp.Minimize(data_fit + penalty))
    problem.solve(solver=cp.SCS, max_iters=max_iter)

    return est.value

def save_exp(norm_err_sweep, x_range, type_plot, params, filename):
    """Pickle a sweep's results to ``filename + '.pkl'``.

    Args:
        norm_err_sweep: error values recorded during the sweep.
        x_range: values of the swept variable.
        type_plot: label of the sweep type (stored under 'sweep_type').
        params: experiment parameters to store alongside the results.
        filename: output path without the '.pkl' extension.
    """
    exp_out = {
        'norm_err' : norm_err_sweep,
        'x_range' : x_range,
        'sweep_type' : type_plot,
        'params' : params
    }

    # Context manager guarantees the file is flushed and closed,
    # even if pickling raises.
    with open(filename + '.pkl', 'wb') as f:
        pickle.dump(exp_out, f)
