# -*- coding: utf-8 -*-
# @date: 20220429

import numpy as np
from scipy.io import savemat, loadmat
from typing import List, Dict, Tuple, Any, Union

"""
Utilities for the implementation of MP
"""

# Floating-point type shared by this module. It is used both as a dtype
# (e.g. ``np.zeros_like(x, dtype=numeric)``) and as a casting callable
# (e.g. ``numeric(array)``).
numeric = np.float32
# String identifiers for the optimisation method selected by callers.
# NOTE(review): presumably SGD = stochastic (sub)gradient descent and
# SPL = stochastic prox-linear -- confirm against the code that consumes them.
OPT_METHOD_SGD = "SGD"
OPT_METHOD_SPL = "SPL"


def get_matlab_pr_data(fname: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load the ``A``, ``b`` and ``optx`` fields of a MATLAB ``data`` struct.

    The file is read with ``scipy.io.loadmat``; the three fields are taken
    from element ``[0][0]`` of the corresponding struct field and cast once to
    the module-wide ``numeric`` type (previously ``np.float32`` was hard-coded
    and ``b`` / ``optx`` were cast twice).

    Args:
        fname: Path of the ``.mat`` file; it must contain a struct variable
            named ``data`` with fields ``A``, ``b`` and ``optx``.

    Returns:
        The tuple ``(A, b, x)`` where ``x`` is the ``optx`` field, each as an
        array of dtype ``numeric``.
    """
    data = loadmat(fname)['data']
    # loadmat wraps every struct field in a (1, 1) object array, hence [0][0].
    A, b, x = (data[key][0][0].astype(numeric) for key in ('A', 'b', 'optx'))
    return A, b, x


def gen_pr_data(m: int, n: int):
    """Draw a random synthetic instance ``(A, b, opt_x)`` with ``b = (A opt_x)^2``.

    Args:
        m: Number of rows of ``A`` (converted with ``int``).
        n: Length of the ground-truth vector and number of columns of ``A``.

    Returns:
        ``(A, b, opt_x)``, all cast with ``numeric``: ``A`` has standard-normal
        entries and shape ``(m, n)``, ``opt_x`` is a unit-norm vector drawn
        from uniform ``[0, 1)`` entries, and ``b`` holds the element-wise
        squares of ``A @ opt_x``.
    """
    n_rows = int(m)
    # The ground truth is drawn first so the RNG stream is consumed in the
    # order: uniform vector, then Gaussian matrix.
    signal = np.random.rand(n)
    signal = numeric(signal / np.linalg.norm(signal))
    sensing = numeric(np.random.randn(n_rows, n))
    projections = sensing @ signal
    return sensing, numeric(projections ** 2), signal


def proj_onto_unit_box(x: numeric) -> numeric:
    """Clamp every entry of ``x`` to the interval ``[-1, 1]``.

    Args:
        x: Scalar or array to project.

    Returns:
        The clamped value(s), cast with ``numeric``.
    """
    clipped = np.clip(x, -1.0, 1.0)
    return numeric(clipped)


def assemble_sub_gradient(a: np.ndarray, b: numeric, x: np.ndarray) -> np.ndarray:
    """Return a subgradient of ``|(a^T x)^2 - b|`` with respect to ``x``.

    The inner product and the gradient are computed with vectorised NumPy
    operations instead of element-by-element Python loops; results agree with
    the loop formulation up to floating-point rounding.

    Args:
        a: 1-D coefficient vector.
        b: Scalar target value.
        x: 1-D point of evaluation. Only its first ``a.size`` entries take
            part in the inner product.

    Returns:
        Array shaped like ``x`` with dtype ``numeric``. Its first ``a.size``
        entries hold ``2 * s * (a^T x) * a`` where ``s = sign((a^T x)^2 - b)``;
        any remaining entries are zero.
    """
    n = a.size
    inner = np.dot(a, x[:n])
    coeff = np.sign(inner ** 2 - b)
    if coeff == 0.0:
        # At the kink of the absolute value every slope in [-1, 1] is a valid
        # choice; one is drawn uniformly at random from [-1, 1).
        coeff = 2 * np.random.rand() - 1
    g = np.zeros_like(x, dtype=numeric)
    g[:n] = (2.0 * coeff * inner) * a
    return g


def assemble_sub_gradient_blind(uv: np.ndarray, b: numeric, z: np.ndarray) -> np.ndarray:
    """Return a subgradient of ``|(u^T x)(v^T y) - b|`` with respect to ``z``.

    With ``n = z.size // 2``, ``z`` is read as ``[x, y]`` and ``uv`` as
    ``[u, v]``, each half having ``n`` entries. The two inner products and the
    gradient halves are computed with vectorised NumPy operations instead of
    element-by-element Python loops; results agree with the loop formulation
    up to floating-point rounding.

    Args:
        uv: 1-D vector holding ``u`` followed by ``v``.
        b: Scalar target value.
        z: 1-D vector holding ``x`` followed by ``y``.

    Returns:
        Array shaped like ``z`` with dtype ``numeric``: the first ``n`` entries
        are ``s * (v^T y) * u`` and the next ``n`` are ``s * (u^T x) * v``,
        where ``s = sign((u^T x)(v^T y) - b)``. If ``z.size`` is odd the last
        entry is left at zero.
    """
    n = z.size // 2
    u, v = uv[:n], uv[n:2 * n]
    u_dot_x = np.dot(u, z[:n])
    v_dot_y = np.dot(v, z[n:2 * n])
    coeff = np.sign(u_dot_x * v_dot_y - b)
    if coeff == 0.0:
        # At the kink of the absolute value every slope in [-1, 1] is a valid
        # choice; one is drawn uniformly at random from [-1, 1).
        coeff = 2 * np.random.rand() - 1
    g = np.zeros_like(z, dtype=numeric)
    g[:n] = (coeff * v_dot_y) * u
    g[n:2 * n] = (coeff * u_dot_x) * v
    return g


def assemble_delta_zeta(a: np.ndarray, b: numeric, xy: np.ndarray) -> np.ndarray:
    """Deprecated entry point; calling it always raises.

    The former implementation sat after the unconditional ``raise`` and could
    never execute, so it has been removed. It used to return the residual
    ``(a^T x)^2 - b`` (plus a linear correction term when ``xy`` held two
    stacked vectors) followed by ``2 * (a^T x) * a``.

    Args:
        a: Unused.
        b: Unused.
        xy: Unused.

    Raises:
        DeprecationWarning: Always, raised as an exception.
    """
    raise DeprecationWarning("assemble_delta_zeta is deprecated")


def assemble_obj_grad(a: np.ndarray, b: numeric, x: np.ndarray) -> np.ndarray:
    """Return ``f(x) = (a^T x)^2 - b`` stacked on top of its gradient.

    The inner product and the gradient are computed with vectorised NumPy
    operations instead of element-by-element Python loops; results agree with
    the loop formulation up to floating-point rounding.

    Args:
        a: 1-D coefficient vector.
        b: Scalar target value.
        x: 1-D point of evaluation. Only its first ``a.size`` entries take
            part in the inner product.

    Returns:
        1-D array of length ``x.size + 1``: entry 0 is ``f(x)`` and the rest
        is the gradient buffer, whose first ``a.size`` entries are
        ``2 * (a^T x) * a`` (stored as ``numeric``) and whose remaining
        entries, if any, are zero.
    """
    n = a.size
    inner = np.dot(a, x[:n])
    obj = inner ** 2 - b
    grad = np.zeros_like(x, dtype=numeric)
    grad[:n] = (2 * inner) * a
    return np.concatenate([[obj], grad])


def assemble_obj_grad_blind(uv: np.ndarray, b: numeric, z: np.ndarray) -> np.ndarray:
    """Return ``f(z) = (u^T x)(v^T y) - b`` stacked on top of its gradient.

    With ``n = z.size // 2``, ``z`` is read as ``[x, y]`` and ``uv`` as
    ``[u, v]``, each half having ``n`` entries. The two inner products and the
    gradient halves are computed with vectorised NumPy operations instead of
    element-by-element Python loops; results agree with the loop formulation
    up to floating-point rounding.

    Args:
        uv: 1-D vector holding ``u`` followed by ``v``.
        b: Scalar target value.
        z: 1-D vector holding ``x`` followed by ``y``.

    Returns:
        1-D array of length ``z.size + 1``: entry 0 is ``f(z)``, followed by
        ``(v^T y) * u`` and ``(u^T x) * v`` (stored as ``numeric``). If
        ``z.size`` is odd the last gradient entry is left at zero.
    """
    n = z.size // 2
    u, v = uv[:n], uv[n:2 * n]
    u_dot_x = np.dot(u, z[:n])
    v_dot_y = np.dot(v, z[n:2 * n])
    grad = np.zeros_like(z, dtype=numeric)
    # Product rule: each half of the gradient is scaled by the *other* half's
    # inner product.
    grad[:n] = v_dot_y * u
    grad[n:2 * n] = u_dot_x * v
    obj = u_dot_x * v_dot_y - b
    return np.concatenate([[obj], grad])


def pr_obj(A: np.ndarray, b: np.ndarray, x: np.ndarray) -> float:
    """Return the mean absolute residual of ``(A x)^2`` against ``b``.

    Args:
        A: Matrix multiplied with ``x`` through ``A.dot(x)``.
        b: Targets compared element-wise with the squared products; its
            ``size`` is the divisor.
        x: Point at which the objective is evaluated.

    Returns:
        ``sum(|(A x)^2 - b|) / b.size``.
    """
    residual = np.power(A.dot(x), 2) - b
    return np.sum(np.abs(residual)) / b.size


def bl_obj(UV: np.ndarray, b: np.ndarray, x: np.ndarray) -> float:
    """Return the mean absolute residual of the bilinear model against ``b``.

    With ``n_x = x.size // 2``, the first ``n_x`` columns of ``UV`` are
    multiplied with the first ``n_x`` entries of ``x`` and the remaining
    columns with the remaining entries; the two resulting vectors are
    multiplied element-wise.

    Args:
        UV: 2-D matrix whose columns are split at ``n_x``.
        b: Targets compared element-wise with the products; its ``size`` is
            the divisor.
        x: Stacked point of evaluation, split at ``n_x``.

    Returns:
        ``sum(|(U x_1) * (V x_2) - b|) / b.size``.
    """
    n_x = x.size // 2
    left = UV[:, :n_x] @ x[:n_x]
    right = UV[:, n_x:] @ x[n_x:]
    return np.sum(np.abs(left * right - b)) / b.size


def init_pr(n: int, alg, is_master=True) -> np.ndarray:
    """Deprecated initialiser; calling it always raises.

    The former implementation sat after the unconditional ``raise`` and could
    never execute, so it has been removed. It used to draw a unit-norm random
    vector and duplicate it when ``alg`` was ``OPT_METHOD_SPL``.
    NOTE(review): ``init_pr_2`` in this module looks like the replacement --
    confirm before redirecting callers.

    Args:
        n: Unused.
        alg: Unused.
        is_master: Unused.

    Raises:
        DeprecationWarning: Always, raised as an exception.
    """
    raise DeprecationWarning("init_pr is now deprecated")


def init_pr_2(n: int, alg, is_master=True) -> np.ndarray:
    """Draw a random unit-norm starting point of length ``n``.

    Args:
        n: Length of the vector.
        alg: Accepted but not used by this function.
        is_master: When falsy the result is multiplied by ``0.0``, i.e. the
            returned vector is all zeros (the random draw still happens).

    Returns:
        Vector of uniform ``[0, 1)`` draws scaled to unit Euclidean norm,
        multiplied by ``float(is_master)``.
    """
    scale = float(is_master)
    start = np.random.rand(n)
    start = start / np.linalg.norm(start)
    return start * scale


def init_bl(n: int, alg, is_master=True) -> np.ndarray:
    """Draw a random starting point whose two halves each have unit norm.

    Args:
        n: Total length of the vector; it is split at ``n // 2``.
        alg: Accepted but not used by this function.
        is_master: When falsy the result is multiplied by ``0.0``, i.e. the
            returned vector is all zeros (the random draw still happens).

    Returns:
        Vector of uniform ``[0, 1)`` draws in which the first ``n // 2``
        entries and the remaining entries are separately scaled to unit
        Euclidean norm, multiplied by ``float(is_master)``.
    """
    half = n // 2
    z = np.random.rand(n)
    # Basic slices are views, so the in-place divisions rescale ``z`` itself.
    head, tail = z[:half], z[half:]
    head /= np.linalg.norm(head)
    tail /= np.linalg.norm(tail)
    return z * float(is_master)