# -*- coding: utf-8 -*-
# @date: 20220429

"""
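An implementation of SPL and SGD with momentum for solving the phase retrieval problem, distributed over MPI ranks (the root rank applies the updates; the other ranks compute gradient information on their own data shards).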

"""
import time
from utils import *
from opt import mpOpt
from mpi4py import MPI
import argparse


# Rank that owns the optimizer and the progress log (the "master").
ROOT = 0
# Message tags the root attaches to its reply to a worker: NOT_DONE keeps the
# worker running, DONE makes it leave its loop.  Tags must not exceed
# MPI_TAG_UB, which the MPI standard only guarantees to be >= 32767, so both
# values are kept small (the former 999999 could be rejected by some MPIs).
DONE = 2
NOT_DONE = 1

COMM = MPI.COMM_WORLD
SIZE = COMM.Get_size()  # total number of ranks (root + workers)
RANK = COMM.Get_rank()  # this process's rank


class prOpt(mpOpt):
    """
    Optimizer for the phase retrieval problem built on ``mpOpt``.

    Two update rules are available, selected by ``algorithm``:

    * ``OPT_METHOD_SGD``: ``x = y - g / gamma``;
    * ``OPT_METHOD_SPL``: ``x = y + c * zeta``, where ``delta = g[0] / gamma``,
      ``zeta = g[1:] / gamma`` and ``c = proj_onto_unit_box(-delta / ||zeta||^2)``.

    Either step is followed by the momentum extrapolation
    ``y = x + beta * (x - x_old)``.

    NOTE(review): the state attributes ``_x``, ``_y``, ``_x_old``, ``_gamma``
    and ``_beta`` are presumably created by ``mpOpt.__init__`` /
    ``mpOpt.initialize`` -- confirm against ``opt.py``.
    """

    def __init__(self, n_dim: int, gamma: numeric,
                 momentum: numeric = 0.0,
                 algorithm=OPT_METHOD_SGD):
        """
        :param n_dim:     Problem dimension, forwarded to ``mpOpt``.
        :param gamma:     Inverse step size; ``iterate`` divides by it
                          (presumably stored by ``mpOpt`` as ``_gamma``).
        :param momentum:  Extrapolation weight, forwarded to ``mpOpt``
                          (presumably stored as ``_beta``).
        :param algorithm: ``OPT_METHOD_SGD`` or ``OPT_METHOD_SPL``.
        """
        super(prOpt, self).__init__(n_dim=n_dim,
                                    gamma=gamma,
                                    momentum=momentum)
        self.algorithm = algorithm

    def iterate(self, g: np.ndarray) -> None:
        """
        Apply one update followed by the momentum extrapolation.

        :param g: Gradient components.
        For SGD method, g is of size  n
        For SPL method, g is of size  n + 1 (g[0] is delta, g[1:] is zeta)
        :return: None
        :raises NotImplementedError: if ``self.algorithm`` is unknown.
        """
        if self.algorithm == OPT_METHOD_SPL:
            delta = g[0] / self._gamma
            zeta = g[1:] / self._gamma
            zeta_sq = np.linalg.norm(zeta) ** 2
            if zeta_sq > 0:
                coeff = proj_onto_unit_box(- delta / zeta_sq)
                self._x = self._y + coeff * zeta
            else:
                # zeta == 0 means the step direction vanishes, so the update
                # is x = y.  Handle it explicitly: -delta / 0 would give inf,
                # or nan when delta == 0 as well, and poison the iterate.
                self._x = self._y.copy()
        elif self.algorithm == OPT_METHOD_SGD:
            self._x = self._y - g / self._gamma
        else:
            raise NotImplementedError("Algorithm is not implemented")

        # Update momentum term
        self._y = self._x + self._beta * (self._x - self._x_old)
        # Safe to alias: self._x is rebound to a new array on every call.
        self._x_old = self._x

        return

    def sync_opt(self, epoch: int, A_data: np.ndarray,
                 b_data: np.ndarray, alg=None):
        """
        Run the optimizer serially over the local data (no MPI communication).

        Each epoch visits every row of ``A_data`` once in random order and
        applies :meth:`iterate` to the gradient information of that row. The
        objective is printed before training and after each epoch.

        :param epoch:  Number of epochs
        :param A_data: Phase retrieval data A, one sample per row
        :param b_data: Phase retrieval data b, one value per row of A_data
        :param alg:    Phase retrieval algorithm; defaults to ``self.algorithm``
        :return:       1-D array of objective values; entry k is the value
                       after k epochs (entry 0 is the starting point)
        """
        if alg is None:
            alg = self.algorithm
        m = A_data.shape[0]
        grad_oracle = assemble_sub_gradient \
            if alg == OPT_METHOD_SGD else assemble_delta_zeta
        fobjs = np.zeros((epoch + 1, 1))
        fobjs[0] = pr_obj(A_data, b_data, self._x)
        print("%10s  %10s" % ("Epoch", "Obj"))
        # Index the scalar explicitly: formatting a shape-(1,) array with %e
        # relies on a conversion deprecated since NumPy 1.25.
        print("%10d  %10e" % (0, fobjs[0, 0]))
        for k in range(epoch):
            for i in np.random.permutation(m):
                a = A_data[i]
                b = b_data[i]
                g = grad_oracle(a, b, self.get_x())
                self.iterate(g)
            fobjs[k + 1] = pr_obj(A_data, b_data, self._x)
            print("%10d  %10e" % (k + 1, fobjs[k + 1, 0]))
        return fobjs[:, 0]

    def get_x(self) -> np.ndarray:
        """
        Return the point at which gradient information should be evaluated.

        For SPL this is the concatenation ``[x, y]``; otherwise it is ``x``.
        """
        if self.algorithm == OPT_METHOD_SPL:
            return np.concatenate([self._x, self._y])
        else:
            return self._x


def train_sync_master(n_iter: int, opt: prOpt):
    """
    Root side of the synchronous scheme.

    Each of the ``n_iter`` rounds does the following:
    1. Gather one gradient vector from every rank. The root contributes a
       zero placeholder, which is excluded from the average.
    2. Average the workers' vectors.
    3. Apply one optimizer step.
    4. Broadcast the new iterate.

    :param n_iter: Number of rounds; must match ``train_sync_worker``.
    :param opt:    Optimizer owned by the root.
    """
    # Vector length must agree with what train_sync_worker sends:
    # n for SGD, n + 1 (delta followed by zeta) for SPL.
    d = opt.n if opt.algorithm == OPT_METHOD_SGD else opt.n + 1
    # Exclude the root itself (the original removed SIZE, which is never in
    # range(SIZE) and therefore always raised ValueError).
    peers = [r for r in range(SIZE) if r != ROOT]
    gs = np.empty((SIZE, d), dtype=numeric)
    placeholder = np.zeros(d, dtype=numeric)
    for _ in range(n_iter):
        COMM.Gather(placeholder, gs, root=ROOT)
        ave_g = gs[peers].mean(axis=0)
        opt.iterate(ave_g)
        COMM.Bcast(opt.get_x(), root=ROOT)


def train_sync_worker(n_iter: int, As: np.ndarray, bs: np.ndarray,
                      x: np.ndarray, alg=OPT_METHOD_SGD) -> np.ndarray:
    """
    Worker side of the synchronous scheme.

    For each of ``n_iter`` rounds, take the next local sample and compute its
    gradient information at ``x``. Send it to the root via ``Gather``, then
    receive the new iterate via ``Bcast``. Samples are taken from a random
    permutation of the local rows, which is reshuffled whenever it runs out.
    ``n_iter`` may therefore exceed the number of local rows.

    :param n_iter: Number of rounds; must match ``train_sync_master``.
    :param As:     Local rows of A.
    :param bs:     Local entries of b, aligned with ``As``.
    :param x:      Current iterate; overwritten in place by the broadcast.
    :param alg:    ``OPT_METHOD_SGD`` or the SPL method.
    :return:       ``x`` after the last round.
    """
    m = As.shape[0]
    gs = None  # receive buffer is only significant on the root
    perm = np.random.permutation(m)
    idx = 0

    for _ in range(n_iter):
        # The original iterated over permutation(n_iter), which indexed rows
        # beyond the local data (IndexError) whenever n_iter > m.
        if idx >= m:
            perm = np.random.permutation(m)
            idx = 0
        i = perm[idx]
        idx += 1

        a = As[i]
        b = bs[i]
        if alg == OPT_METHOD_SGD:
            g = assemble_sub_gradient(a, b, x)
        else:
            g = assemble_delta_zeta(a, b, x)
        # Gather needs the same dtype on every rank; convert instead of
        # asserting (asserts vanish under ``python -O``).
        if g.dtype != numeric:
            g = g.astype(numeric)
        COMM.Gather(g, gs, root=ROOT)
        COMM.Bcast(x, root=ROOT)

    return x


def train_async_master(n_iter: int, opt: prOpt):
    """
    Root side of the asynchronous scheme for one epoch.

    Each worker always has one receive posted for it. Whenever gradient
    information arrives from a worker, the root applies ``opt.iterate`` to it
    and replies with the new iterate. The reply is tagged ``NOT_DONE`` until
    ``n_iter`` gradients have been applied in this call. After that, each
    worker's next gradient is still applied, but the reply is tagged ``DONE``.
    The call returns once every worker has been told ``DONE``.

    :param n_iter: Number of updates after which workers are stopped.
    :param opt:    Optimizer owned by the root.
    """
    # Must agree with the workers: n for SGD, n + 1 for SPL.
    d = opt.n if opt.algorithm == OPT_METHOD_SGD else opt.n + 1
    peers = [r for r in range(SIZE) if r != ROOT]
    n_peers = len(peers)

    # One receive buffer per worker, allocated unconditionally so it always
    # exists (an ``if RANK == ROOT`` guard here would leave it unbound).
    gg = np.empty((n_peers, d), dtype=numeric)
    requests = [COMM.Irecv(gg[i], source=src) for i, src in enumerate(peers)]

    n_master_rcv_epoch = 0
    n_active_workers = n_peers

    while n_active_workers > 0:
        # Indices (into ``requests``/``peers``) of the receives that completed.
        list_received = MPI.Request.Waitsome(requests)
        for i in list_received:
            opt.iterate(gg[i])
            n_master_rcv_epoch += 1
            if n_master_rcv_epoch < n_iter:
                COMM.Send(opt.get_x(), dest=peers[i], tag=NOT_DONE)
                # Re-arm the receive for this worker's next gradient.
                requests[i] = COMM.Irecv(gg[i], source=peers[i])
            else:
                COMM.Send(opt.get_x(), dest=peers[i], tag=DONE)
                n_active_workers -= 1


def train_async_worker(n_iter: int, As: np.ndarray, bs: np.ndarray,
                       x: np.ndarray, alg=OPT_METHOD_SGD) -> np.ndarray:
    """
    Worker side of the asynchronous scheme.

    The worker repeats the following steps:
    1. Take the next row from a random permutation of the local rows,
       reshuffling whenever it runs out.
    2. Compute that row's gradient information at ``x``.
    3. Send it to the root.
    4. Block until the root replies with a new iterate, written into ``x``.

    The loop stops as soon as a reply arrives whose tag is not ``NOT_DONE``.
    ``n_iter`` is not used here; the root decides when to stop.

    :return: ``x`` holding the last iterate received from the root.
    """
    n_rows = As.shape[0]
    oracle = assemble_sub_gradient if alg == OPT_METHOD_SGD \
        else assemble_delta_zeta

    status = MPI.Status()
    status.tag = NOT_DONE
    pos = 0
    order = np.random.permutation(n_rows)

    while status.tag == NOT_DONE:
        if pos >= n_rows:
            order = np.random.permutation(n_rows)
            pos = 0
        row = order[pos]
        pos += 1

        grad = oracle(As[row], bs[row], x)
        if grad.dtype != numeric:
            grad = grad.astype(numeric)

        COMM.Send(grad, dest=ROOT)
        # The reply's tag tells us whether to continue.
        COMM.Recv(x, source=ROOT, tag=MPI.ANY_TAG, status=status)

    return x


def get_pr_data(m: int, n: int):
    """
    Build this rank's share of a synthetic instance ``b = (A @ x_star) ** 2``.

    The root draws a unit-norm ``x_star`` (non-negative entries from
    ``np.random.rand``) and broadcasts it to all ranks. Each data-holding rank
    then draws its own Gaussian ``A`` with ``int(m / (SIZE - 1))`` rows, or
    ``m`` rows when running on a single rank. When there are several ranks,
    the root holds no data.

    :param m: Total number of measurements.
    :param n: Signal dimension.
    :return:  ``(A, b, x_star)``; ``A`` and ``b`` are ``None`` on the root
              when ``SIZE != 1``.
    """
    rows_per_rank = m if SIZE == 1 else int(m / (SIZE - 1))
    is_root = RANK == ROOT

    if not is_root:
        x_star = np.empty(n, dtype=numeric)
    else:
        x_star = np.random.rand(n)
        x_star /= np.linalg.norm(x_star)
        x_star = numeric(x_star)
    COMM.Bcast(x_star, root=ROOT)

    if is_root and SIZE != 1:
        return None, None, x_star

    A = np.random.randn(rows_per_rank, n).astype(numeric)
    b = ((A @ x_star) ** 2).astype(numeric)
    return A, b, x_star


# Command-line interface.  Help strings describe how main() uses each value.
parser = argparse.ArgumentParser(
    description="SPL / SGD with momentum for phase retrieval over MPI.")
parser.add_argument("--m", type=int, default=50000,
                    help="total number of measurements, split evenly across "
                         "worker ranks; also the number of updates per epoch "
                         "(default: %(default)s)")
parser.add_argument("--n", type=int, default=20000,
                    help="signal dimension (default: %(default)s)")

parser.add_argument("--alpha", type=float, default=1.0,
                    help="step-size scale: updates are divided by "
                         "gamma = sqrt(epochs * m) / alpha "
                         "(default: %(default)s)")
parser.add_argument("--beta", type=float, default=0.0,
                    help="momentum (extrapolation) weight "
                         "(default: %(default)s)")
parser.add_argument("--alg", type=str, default="SGD",
                    help="optimization algorithm, SGD or SPL "
                         "(default: %(default)s)")


def main():
    """
    Run the phase-retrieval experiment configured by the command line.

    The root rank (``ROOT``) owns the optimizer and appends progress to a
    ``res_job_*.txt`` log. Every other rank is a worker holding its own data
    shard. After each epoch the root broadcasts the current iterate and each
    worker reports its local objective value. The root then logs:

    * the distance to the ground truth, up to a global sign;
    * the average worker objective (column "f - f^*");
    * the elapsed time, excluding evaluation time.

    The log file is closed even if training raises.
    """
    # Parse arguments
    args = parser.parse_args()
    m = args.m
    n = args.n
    alpha_0 = args.alpha
    beta = args.beta
    alg = args.alg

    if RANK == ROOT:
        print("Test begins on {0}".format(time.asctime()))
        print("Parameters m: %d  n: %d  alpha: %f  beta: %f  alg: %s" %
              (m, n, alpha_0, beta, alg))

    # Set parameters
    sync = False
    seed = RANK + 20220510  # distinct but reproducible stream per rank
    np.random.seed(seed)
    epoch = 0
    n_epoch = 100
    n_iter_epoch = m
    n_iter_all = n_epoch * m
    # Updates are divided by gamma, i.e. scaled by alpha_0 / sqrt(total iters).
    gamma = np.sqrt(n_iter_all) / alpha_0

    # Generate data
    A_data, b_data, opt_x = get_pr_data(m, n)

    if sync:
        # Serial run on this rank's local data (no communication).
        # NOTE(review): with more than one rank the root has A_data = None.
        x = init_pr(n, alg, True)
        pr_opt = prOpt(n, gamma, numeric(beta), alg)
        pr_opt.initialize(x[0:n])
        pr_opt.sync_opt(n_epoch, A_data, b_data, alg)
        return

    COMM.Barrier()

    # Initial point; the root's copy is broadcast to every rank.
    if RANK == ROOT:
        x = init_pr(n, alg, True)
    else:
        x = init_pr(n, alg, False)

    if x.dtype != numeric:
        x = x.astype(numeric)
    COMM.Bcast(x, root=ROOT)

    pr_opt = prOpt(n, gamma, numeric(beta), alg)
    pr_opt.initialize(x[0:n])

    f = None  # progress log; only opened on the root
    try:
        if RANK == ROOT:
            output_stream = "res_job_{6}_epoch_{0}_alpha_{1}_beta_{2}_m_{3}_n_{4}_{5}.txt"\
                            "".format(n_epoch, alpha_0, beta * 10, m, n, alg, SIZE - 1)

            f = open(output_stream, "a")
            output_str = "%12s  %12s  %12s  %12s" % ("Epoch", "||x-x^*||", "f - f^*", "Time")

            print(output_str, flush=True)
            f.write(output_str + "\n")

            total_time = 0.0  # evaluation time, subtracted from "Time"
            time_start = time.time()
            test_time = time.time()
            peers = list(range(SIZE))
            peers.remove(ROOT)
            # The root's own entry in the gather is a dummy 0, excluded below.
            obj_gap = np.mean(0.0).astype(numeric)
            obj_gaps = np.empty(SIZE, dtype=numeric)
            COMM.Gather(obj_gap, obj_gaps, root=ROOT)
            obj_gaps_avg = obj_gaps[peers].mean(axis=0)
            # b = (A x)^2 cannot distinguish x from -x, so measure both.
            nrm = np.minimum(np.linalg.norm(x[0:n] - opt_x),
                             np.linalg.norm(x[0:n] + opt_x))
            total_time += time.time() - test_time

            output_str = "%12d  %12e  %12e  %12e" % (epoch,
                                                     nrm,
                                                     obj_gaps_avg,
                                                     time.time() - time_start - total_time)
            f.write(output_str + "\n")
            print(output_str, flush=True)
        else:
            obj_gap = numeric(pr_obj(A_data, b_data, x[0:n]))
            obj_gaps = None
            COMM.Gather(obj_gap, obj_gaps, root=ROOT)

        while epoch < n_epoch:

            if RANK == ROOT:
                train_async_master(n_iter_epoch, pr_opt)
            else:
                x = train_async_worker(n_iter_epoch, A_data, b_data, x, alg)
            epoch += 1

            # Synchronize everyone on the root's iterate before evaluating.
            if RANK == ROOT:
                x = pr_opt.get_x()
            COMM.Bcast(x, root=ROOT)

            if RANK == ROOT:
                test_time = time.time()
                COMM.Gather(obj_gap, obj_gaps, root=ROOT)
                obj_gaps_avg = obj_gaps[peers].mean(axis=0)
                nrm = np.minimum(np.linalg.norm(x[0:n] - opt_x),
                                 np.linalg.norm(x[0:n] + opt_x))
                total_time += time.time() - test_time

                output_str = "%12d  %12e  %12e  %12e" % (epoch,
                                                         nrm,
                                                         obj_gaps_avg,
                                                         time.time() - time_start - total_time)
                f.write(output_str + "\n")
                print(output_str, flush=True)
            else:
                obj_gap = numeric(pr_obj(A_data, b_data, x[0:n]))
                COMM.Gather(obj_gap, obj_gaps, root=ROOT)
    finally:
        # The original never closed the log file.
        if f is not None:
            f.close()


if __name__ == "__main__":
    # Script entry point; intended to be launched on every rank (e.g. via mpiexec).
    main()
