/* coupled_mckean_vlasov.h


   (c) Grant Rotskoff, Carles Domingo-Enrich, 2020.
*/

#ifndef POLY_H
#define POLY_H

#include <math.h>

#include <limits>

#include <armadillo>
using namespace arma;

class xsSystem : public CoupledMcKeanVlasov {
  public:

    // compute the interaction between pairs of particles x,y
    virtual double compute_loss_weights(CoupledMcKeanVlasov *ys_system)
    {
      mat &ys = ys_system->xs;
      mat &yws_matrix = ys_system->ws_avg_matrix;
      return -accu(((ws_avg_matrix%xs)*A_1)*(yws_matrix%ys).t()) + trace((xs.t()*(xs%ws_avg_matrix))*A_0) - trace((ys.t()*(ys%yws_matrix))*A_2) 
             + accu((xs%ws_avg_matrix)*a_0) - accu((ys%yws_matrix)*a_1) + accu(((xs%(xs%ws_avg_matrix))*A_3*(ys%yws_matrix).t()));
    }

    virtual double compute_loss_weights_full(CoupledMcKeanVlasov *ys_system, bool avg)
    {
      if (avg) 
      {
        ws_avg_matrix.each_col() = normalise(ws_avg, 1, 0);
        ys_system->ws_avg_matrix.each_col() = normalise(ys_system->ws_avg, 1, 0);
      }
      else
      {
        ws_avg_matrix.each_col() = normalise(ws, 1, 0);
        ys_system->ws_avg_matrix.each_col() = normalise(ys_system->ws, 1, 0);
      }
      return compute_loss_weights(ys_system);
    }

    virtual double compute_loss(CoupledMcKeanVlasov *ys_system)
    {
      mat &ys = ys_system->xs;
      return -accu((xs*A_1)*ys.t())/(n*ys_system->n) + trace((xs.t()*xs)*A_0)/n - trace((ys.t()*ys)*A_2)/(ys_system->n) 
             + accu(xs*a_0)/n - accu(ys*a_1)/(ys_system->n) + accu(((xs%xs)*A_3*ys.t()))/(n*ys_system->n);
    }

    // compute gradients
    virtual void compute_grad(CoupledMcKeanVlasov *ys_system)
    {
      grad_xs = xs*(A_0+A_0.t());
      xs_copy = 2*xs;
      xs_copy.each_row() %= sum((ys_system->xs)*A_3.t(),0)/(ys_system->n);
      grad_xs += xs_copy;
      grad_xs.each_row() -= sum((ys_system->xs)*A_1.t(),0)/(ys_system->n);
      grad_xs.each_row() += a_0.t();
    }

    virtual void compute_grad_weights(CoupledMcKeanVlasov *ys_system)
    {
      mat &ys = ys_system->xs;
      mat &yws_matrix = ys_system->ws_matrix;
      grad_xs = xs*(A_0+A_0.t());
      xs_copy = 2*xs;
      xs_copy.each_row() %= sum((ys%yws_matrix)*A_3.t(),0);
      grad_xs += xs_copy;
      grad_xs.each_row() -= sum((ys%yws_matrix)*A_1.t(),0);
      grad_xs.each_row() += a_0.t();
    }

    virtual void weights_update(CoupledMcKeanVlasov *ys_system)
    {
      mat &ys = ys_system->xs;
      mat &yws_matrix = ys_system->ws_matrix;
      upd_ws = -sum((xs*A_1)*(ys%yws_matrix).t(),1)
             + xs*a_0 + sum(((xs%xs)*A_3*(ys%yws_matrix).t()),1);
      upd_ws += diagvec((xs*A_0.t()*xs.t()));
    }

    virtual double computeNI(CoupledMcKeanVlasov *ys_system, CoupledMcKeanVlasov *xs_system2, 
                             CoupledMcKeanVlasov *ys_system2, CoupledMcKeanVlasov *ind_xs_system, 
                             CoupledMcKeanVlasov *ind_ys_system, int iter, double lr)
    {
      for (int i=0; i<iter; i++)
      {
        xs_system2->gd_step(ys_system, lr);
        ys_system2->gd_step(this, lr);
      }
      double loss_x = 1000000;
      double loss_y = -1000000;
      for (int j=0; j < xs_system2->n; j++)
      {
        *ind_xs_system = *xs_system2;
        ind_xs_system->restrict_row(j);
        if (loss_x > ind_xs_system->compute_loss(ys_system))
        {
          loss_x = ind_xs_system->compute_loss(ys_system);
        }
      }
      for (int j=0; j < ys_system2->n; j++)
      {
        *ind_ys_system = *ys_system2;
        ind_ys_system->restrict_row(j); 
        if (loss_y < compute_loss(ind_ys_system))
        {
          loss_y = compute_loss(ind_ys_system);
        }
      }
      return loss_y - loss_x;
    }

    virtual double computeNI_weights(CoupledMcKeanVlasov *ys_system, CoupledMcKeanVlasov *xs_system2, 
                             CoupledMcKeanVlasov *ys_system2, CoupledMcKeanVlasov *ind_xs_system, 
                             CoupledMcKeanVlasov *ind_ys_system, int iter, double lr, double avg)
    {
      for (int i=0; i<iter; i++)
      {
        xs_system2->transport_step_weights(ys_system, lr, avg); 
        ys_system2->transport_step_weights(this, lr, avg);
      }
      double loss_x = 1000000;
      double loss_y = -1000000;
      for (int j=0; j < xs_system2->n; j++)
      {
        *ind_xs_system = *xs_system2;
        ind_xs_system->restrict_row(j);
        if (loss_x > ind_xs_system->compute_loss_weights_full(ys_system, avg))
        {
          loss_x = ind_xs_system->compute_loss_weights_full(ys_system, avg);
        }
      }
      for (int j=0; j < ys_system2->n; j++)
      {
        *ind_ys_system = *ys_system2;
        ind_ys_system->restrict_row(j); 
        if (loss_y < compute_loss_weights_full(ind_ys_system, avg))
        {
          loss_y = compute_loss_weights_full(ind_ys_system, avg);
        }
      }
      return loss_y - loss_x;
    }
};


class ysSystem : public CoupledMcKeanVlasov {
  public:
    // compute gradients
    virtual void compute_grad(CoupledMcKeanVlasov *xs_system)
    {
      grad_xs = xs*(A_2+A_2.t());
      grad_xs.each_row() += sum((xs_system->xs)*A_1,0)/(xs_system->n);
      grad_xs.each_row() -= sum(((xs_system->xs)%(xs_system->xs))*A_3,0)/(xs_system->n);
      grad_xs.each_row() += a_1.t();
    }

    virtual void compute_grad_weights(CoupledMcKeanVlasov *xs_system)
    {
      mat &xxs = xs_system->xs;
      mat &xws_matrix = xs_system->ws_matrix;
      grad_xs = xs*(A_2+A_2.t());
      grad_xs.each_row() += sum((xxs%xws_matrix)*A_1,0);
      grad_xs.each_row() -= sum((xxs%xxs%xws_matrix)*A_3,0);
      grad_xs.each_row() += a_1.t();
    }

    virtual void weights_update(CoupledMcKeanVlasov *xs_system)
    {
      mat &xxs = xs_system->xs;
      mat &xws_matrix = xs_system->ws_matrix;
      upd_ws = sum((xs*A_1.t())*(xxs%xws_matrix).t(),1)
             + xs*a_1 - sum((xs*A_3.t()*(xxs%xxs%xws_matrix).t()),1);
      upd_ws += diagvec((xs*A_2.t()*xs.t())); 
    }
};


#endif
