;; Directory settings for this figure (presumably injected as header
;; variables by a notebook / org-babel export — TODO confirm).  Each value
;; is a quoted string literal, i.e. '"x" is reader shorthand for (quote "x").
(setv attach-dir '"."
      mgs-dir '"../alexnet_mnist/mgs_trace"
      lg2-dir '"../alexnet_mnist/loss_gradient_parameter"
      lg1-dir '"../alexnet_mnist/loss_gradient_data"
      dropout-dir '"../alexnet_mnist/dropout"
      wp-dir '"../alexnet_mnist/weight_penalty"
      lr-dir '"../alexnet_mnist/adam"
      ;; NOTE(review): unlike the others this is not a ../alexnet_mnist/…
      ;; path; confirm the unregularised run really lives in ./y.
      vanilla-dir '"y")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

;; (bound? x) expands to an expression that evaluates `x` once and yields
;; True if that evaluation completes without a NameError, False if it
;; raises one.  Any other exception propagates unchanged.
(defmacro bound? [x]
  `(try (do ~x True)
        (except [NameError] False)))

;; (default x d) expands to the value of `x`, or to `d` when evaluating `x`
;; raises a NameError (e.g. `x` is an unbound variable).
;; `x` is evaluated exactly once.  The previous expansion,
;; (if (bound? x) x d), evaluated it twice and so doubled any side effects.
;; `d` is evaluated only when needed.  A NameError raised by `d` itself
;; propagates, because it is raised inside the handler.
(defmacro default [x d]
  `(try ~x
        (except [NameError] ~d)))

;; One (legend label, matplotlib colour, results directory) tuple per
;; training configuration, in plotting order.  Written column by column and
;; zipped back together into a list of 3-tuples.
(setv dirs (list (zip
                   ;; legend labels
                   ["Unregularised"
                    "Learning rate (ADAM)"
                    "Dropout"
                    "Weight penalty"
                    "Loss gradient penalty (data)"
                    "Loss gradient penalty (parameter)"
                    "MGS penalty (trace)"]
                   ;; line colours
                   ["tab:gray"
                    "tab:orange"
                    "tab:brown"
                    "tab:blue"
                    "tab:green"
                    "tab:purple"
                    "tab:red"]
                   ;; results directories
                   [vanilla-dir
                    lr-dir
                    dropout-dir
                    wp-dir
                    lg1-dir
                    lg2-dir
                    mgs-dir])))

;; Plot aggregated test accuracy, log-trace and log-determinant curves for
;; every configuration in `dirs` side by side, save the figure as an SVG
;; under `attach-dir`, and print the output path.
(setv fname (os.path.join attach-dir "mnist-metrics-plot-2.svg"))

;; Clear whatever figure is currently active before building a new one.
(plt.clf)

(do
  (setv fig (plt.figure :figsize (, 17 5))
        acc-ax (plt.subplot2grid (, 1 3) (, 0 0) 1 1)
        tr-ax (plt.subplot2grid (, 1 3) (, 0 1) 1 1)
        det-ax (plt.subplot2grid (, 1 3) (, 0 2) 1 1))

  ;; Accuracy-panel line of each configuration, in `dirs` order; used as
  ;; the legend handles below.
  (setv handles [])

  ;; The loop variable must not be named `attach-dir`.  `for` binds in the
  ;; enclosing (module) scope, so reusing that name overwrote the global
  ;; output directory with the last entry of `dirs`.
  (for [[_ c run-dir] dirs]
    (with [f (open (os.path.join run-dir "perf.npy") "rb")]
      ;; perf.npy is read as a sequence of arrays stored back to back.
      ;; The first three are skipped.
      ;; NOTE(review): assumes arrays 4-6 are accuracy, trace and det, each
      ;; stacked over repeated runs along axis 0 — confirm against the writer.
      (for [_ (range 3)] (np.load f))
      (setv acc (np.mean (np.load f) :axis 0)
            tr (np.log (np.mean (np.load f) :axis 0))
            ;; NOTE(review): no log is taken here although the panel title
            ;; says log det; presumably the stored value is already a
            ;; log-determinant — confirm.
            det (np.mean (np.load f) :axis 0))
      ;; Each curve is stretched over x = 0..30 (presumably epochs),
      ;; whatever number of points was logged.
      (.extend handles (.plot acc-ax (np.linspace 0 30 (np.size acc)) acc :c c))
      (.plot tr-ax (np.linspace 0 30 (np.size tr)) tr :c c)
      (.plot det-ax (np.linspace 0 30 (np.size det)) det :c c)))

  (.set-title acc-ax "Aggregated test accuracy" :fontsize 18 :pad 15)
  (.set-title tr-ax r"$\log(\mathtt{tr}\/K_{\theta})$" :fontsize 18 :pad 15)
  (.set-title det-ax r"$\log(\mathtt{det}\/K_{\theta})$" :fontsize 18 :pad 15)

  (setv names (lfor [n _ _] dirs n))
  ;; Pass the handles explicitly rather than letting the figure collect every
  ;; line from all three axes and match labels to them purely by position.
  (.legend fig handles names :loc "lower center" :ncol 4 :fontsize 15)

  (plt.tight-layout)
  ;; Overrides every margin tight_layout chose except `right`, which it keeps.
  (plt.subplots-adjust :hspace 0.2 :wspace 0.15 :left 0.03 :top 0.9 :bottom 0.25))

(.savefig fig fname)

(print fname)
