#include "run.h"
#include "data.h"
#include "log.h"

#include "moss.h"
#include "umoss.h"
#include "gaussian_bandit.h"

#include <cstring>
#include <random>
#include <iostream>
#include <chrono>
#include <thread>

using namespace std;


void exp2() {
  /* seed random number generate */
  default_random_engine gen;
  random_device rd;
  gen.seed(rd());

  /* initialise some algorithms (just watch me free this...) */
  int n = 5000;

  vector<double> bias = {sqrt(n)};
  
  double H = 0.0;
  for (int i = 1;i!=9;++i) {
    H+=1.0 / i;
  }

  for (int i = 1;i!=10;++i) {
    bias.push_back(i * H * sqrt(n));
  }

  vector<BanditAlgorithm*> algs = {
    new MOSS(4.0),
    new UMOSS(4.0, bias),
  };


  Logger<LogEntry> log("data/fig2.log");
  for (int t = 0;t!=500;++t) {
    cout << "running trial: " << t << "\n";
    for (unsigned int i = 0;i != algs.size();++i) {
      for (int j = 0;j!=10;++j) {
        for (double delta = 0.01;delta <= 0.5;delta+=0.01) {
          vector<double> mus(10, 0.0);
          mus[j] = delta;
          GaussianBandit p(n, mus, gen);
          double regret = RunData(*algs[i], p, false).regret;
          double theta = 0.5 * j + delta;
          log.log(LogEntry(i, theta, regret));
        }
      }
    }

    if (t % 20 == 0) {
      log.save();
    }
  }
  log.save();
}


void exp1() {
  /* seed random number generate */
  default_random_engine gen;
  random_device rd;
  gen.seed(rd());

  /* initialise some algorithms (just watch me free this...) */
  int n = 5000;
  vector<BanditAlgorithm*> algs = {
    new MOSS(4.0),
    new UMOSS(4.0, {pow((double)n, 1.0 / 3), pow((double)n, 2.0 / 3)}),
  };


  Logger<LogEntry> log("data/fig1.log");
  for (int t = 0;t!=500;++t) {
    cout << "running trial: " << t << "\n";
    for (unsigned int i = 0;i != algs.size();++i) {
      for (double delta = -0.5;delta <= 0.5;delta+=0.002) {
        vector<double> mus = {0, -delta};
        GaussianBandit p(n, mus, gen);
        double regret = RunData(*algs[i], p, false).regret;
        log.log(LogEntry(i, delta, regret));
      }
    }

    if (t % 20 == 0) {
      log.save();
    }
  }
  log.save();
}


/* Program entry point: runs experiment 1 and then experiment 2, in that order. */
int main() {
  exp1();
  exp2();
  return 0;
}



