R"""Script for running GLS on synthetic datasets.

Example usage (adjust the repository path to your own checkout):
cd ~/Desktop/projects/zonotopic_relu
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/zonotopic_relu


python3 scripts/synthetic/gls.py \
    --outdir="/tmp" \
    --configs_path="exps.synthetic.gls_configs.CONFIGS" \
    --config="test" \
    --n_runs=2

"""
import dataclasses
import json
import os
from pydoc import locate

from absl import app
from absl import flags
from absl import logging

from xoid.datasets import synthetic
from xoid.gls import base_gls

from xoid.util import misc_util


FLAGS = flags.FLAGS

# Default Python (dotted) path of the configs dict; override with --configs_path.
_CONFIGS_PATH = "exps.synthetic.gls_configs.CONFIGS"

if __name__ == "__main__":
    # Flags are only defined when this file is executed as a script, so merely
    # importing it as a module does not register them.
    # The output directory must already exist: nothing in this script creates it.
    flags.DEFINE_string('outdir', None, 'Path to an existing directory where output will be written.')

    flags.DEFINE_string('configs_path', _CONFIGS_PATH, 'Python path to configs dict.')
    flags.DEFINE_string('config', None, 'Name of the entry in the configs dict to use as configuration.')

    flags.DEFINE_integer('n_runs', 1, 'Number of times to repeat the experiment.')

    # configs_path is intentionally not listed: it has a non-None default, so
    # marking it required would never fail and only makes absl emit a warning.
    flags.mark_flags_as_required(['outdir', 'config'])


@dataclasses.dataclass
class Config:
    """Settings for one synthetic GLS experiment.

    Positional construction follows the field order below, and
    ``dataclasses.asdict`` emits the keys in that same order.
    """

    # Experiment name; used to build the output filename in main().
    name: str

    # Passed to synthetic.make_dataset; the dataset size is (d + 1) * m_gen.
    m_gen: int
    # Passed to misc_util.make_pm_1_v when building the GLS `v` argument.
    m_train: int
    # Passed to synthetic.make_dataset and used in the dataset size.
    d: int

    # Step budget handed to Gls.solve.
    max_steps: int

    # Forwarded as the `eps` keyword of base_gls.Gls.
    eps: float = dataclasses.field(default=1e-7)


def do_run(cfg, run_index):
    """Build one synthetic dataset and solve it with GLS under the L2 loss.

    Args:
        cfg: experiment settings (``d``, ``m_gen``, ``m_train``, ``max_steps``,
            ``eps`` are read).
        run_index: index of this repetition; not read by the body.

    Returns:
        ``(loss, steps_taken)``, unpacked from ``Gls.solve(cfg.max_steps)``.
    """
    num_points = (cfg.d + 1) * cfg.m_gen
    X, Y = synthetic.make_dataset(cfg.d, cfg.m_gen, num_points)

    solver = base_gls.Gls(
        X,
        Y,
        misc_util.make_pm_1_v(cfg.m_train, X.dtype),
        base_gls.GlsOptions(loss_fn='l2'),
        eps=cfg.eps,
    )

    final_loss, n_steps = solver.solve(cfg.max_steps)
    return final_loss, n_steps


def main(_):
    """Run the selected GLS config ``FLAGS.n_runs`` times and write a JSON report.

    The config is ``FLAGS.config`` looked up inside the object that
    ``FLAGS.configs_path`` resolves to via ``pydoc.locate``. The report is
    written to ``<FLAGS.outdir>/gls_<cfg.name>.json`` (with ``~`` expanded). It
    holds the per-run final losses, the per-run step counts, and the config as
    a dict.

    Raises:
        ValueError: if ``FLAGS.configs_path`` cannot be located.
        KeyError: if ``FLAGS.config`` is not an entry of the located configs.
    """
    configs = locate(FLAGS.configs_path)
    if configs is None:
        # locate() returns None rather than raising for an unresolvable path;
        # without this check the failure would be an opaque TypeError below.
        raise ValueError(
            f'--configs_path={FLAGS.configs_path!r} could not be located.')
    try:
        cfg = configs[FLAGS.config]
    except KeyError:
        raise KeyError(
            f'--config={FLAGS.config!r} not found in {FLAGS.configs_path}; '
            f'available configs: {list(configs)}') from None

    losses = []
    steps_takens = []
    for i in range(FLAGS.n_runs):
        loss, steps_taken = do_run(cfg, i)
        logging.info('Run %d/%d: final loss %s after %s steps.',
                     i + 1, FLAGS.n_runs, loss, steps_taken)
        losses.append(loss)
        steps_takens.append(steps_taken)

    results = {
        'final_losses': losses,
        'steps_takens': steps_takens,
        'config': dataclasses.asdict(cfg),
    }

    # NOTE(review): json.dump only accepts plain Python scalars here; confirm
    # Gls.solve does not return e.g. numpy float32 or tensor losses.
    filepath = os.path.expanduser(
        os.path.join(FLAGS.outdir, f'gls_{cfg.name}.json'))
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f)
    logging.info('Wrote results to %s', filepath)


if __name__ == "__main__":
    app.run(main)
