;; Directory configuration.
;; `attach-dir` is where this script writes its own figure; every other
;; entry is the directory of one training run whose saved artefacts are
;; read back further down when the figure is drawn.
(setv attach-dir  "."
      vanilla-dir '"../two_circles/unregularised"
      lr-dir      '"../two_circles/learning_rate"
      wp-dir      '"../two_circles/weight_penalty"
      dropout-dir '"../two_circles/dropout"
      lg1-dir     '"../two_circles/loss_gradient_parameter"
      lg2-dir     '"../two_circles/loss_gradient_data"
      rkhs-dir    '"../two_circles/functional_regularisation"
      mgs1-dir    '"../two_circles/mgs_trace"
      mgs2-dir    '"../two_circles/mgs_det")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

;; (bound? x) expands to an expression that evaluates `x` and yields True
;; when that succeeds, or False when it raises NameError.  Any other
;; exception raised while evaluating `x` propagates unchanged.
(defmacro bound? [x]
  `(try
     (do ~x True)
     (except [NameError] False)))

;; (default x d) expands to `x` when `x` can be evaluated without a
;; NameError (as decided by `bound?`), and to the fallback `d` otherwise.
(defmacro default [x d]
  `(if (not (bound? ~x)) ~d ~x))

;; Layer list of the classifier: three (Dense 300 -> Relu) stages followed
;; by a single-unit Dense output layer, all in "standard" parameterization.
;; The inner list is re-evaluated on every pass, so each stage gets its own
;; freshly constructed Dense/Relu layer objects.
(setv net
      (+ (lfor _ (range 3)
               layer [(nt-stax.Dense 300 :parameterization "standard")
                      (nt-stax.Relu)]
               layer)
         [(nt-stax.Dense 1 :parameterization "standard")]))

(import [sklearn.datasets [make_circles]])

(defn two-circles-setup [net step-size [class-sizes [150 75]] [label-noise 0.0]
                         [data-noise 0.15] [data-seed 2] [np-rng-seed 3]]
  "Build the two-circles dataset and the training helpers for `net`.

  net          -- list of layers, combined with nt-stax.serial.
  step-size    -- passed to optimizers.sgd as its step size.
  class-sizes  -- forwarded to make-circles as n-samples.
  label-noise  -- fraction of labels that are overwritten with random ones.
  data-noise   -- forwarded to make-circles as noise.
  data-seed    -- forwarded to make-circles as random-state.
  np-rng-seed  -- seed of the NumPy generator used for the label noise.

  Returns the tuple
  (x, y, net-apply, calc-loss, opt-update, opt-get, new-opt-state),
  where y has been reshaped to a single column, net-apply is jitted,
  calc-loss takes (params x y [rng]) and new-opt-state takes a PRNG key
  and returns a freshly initialised optimizer state."
  ;; --- data ---------------------------------------------------------------
  (setv [x y] (make-circles :n-samples class-sizes
                            :noise data-noise
                            :random-state data-seed))

  ;; --- label noise --------------------------------------------------------
  ;; Pick floor(label-noise * n) distinct positions and overwrite them with
  ;; labels drawn from 0 .. max(y); a drawn label may equal the old one.
  ;; The generator is consumed in this order: positions first, labels second.
  (setv noise-rng (np.random.default-rng np-rng-seed)
        n-points (len y)
        n-noisy (int (np.floor (* label-noise n-points)))
        noisy-idx (.choice noise-rng (range n-points) n-noisy :replace False)
        noisy-labels (.integers noise-rng 0 (+ (np.max y) 1) (len noisy-idx)))
  (assoc y noisy-idx noisy-labels)
  (setv y (np.reshape y (, -1 1)))

  ;; --- network and optimizer ----------------------------------------------
  (setv [net-init raw-apply _] (nt-stax.serial #* net)
        net-apply (jax.jit raw-apply)
        [opt-init opt-update opt-get] (optimizers.sgd :step-size step-size)
        n-features (get (np.shape x) 1))

  (defn calc-loss [p x y [rng None]]
    (nn-utils.bce-with-logits-loss (net-apply p x :rng rng) y))

  (defn new-opt-state [rng]
    ;; net-init's result is indexed at 1 to obtain the parameters.
    (-> (net-init rng (, -1 n-features))
        (get 1)
        (opt-init)))

  (, x y net-apply calc-loss opt-update opt-get new-opt-state))

;; One row per training run shown in the figure:
;;   (legend label, matplotlib colour, run directory, draw decision boundary?)
;; Rows flagged True get a decision-boundary panel in the top row of the
;; figure; every row contributes a curve to the trace and determinant plots.
;;
;; The two loss-gradient rows are paired with the directory whose name
;; matches the label: lg2-dir is ".../loss_gradient_data" and lg1-dir is
;; ".../loss_gradient_parameter".  They were previously crossed, which
;; labelled each of those two curves with the other method's name.
(setv dirs [(, "Unregularised" "tab:gray" vanilla-dir True)
            (, "Learning rate (Inverse time decay)" "tab:orange" lr-dir False)
            (, "Dropout" "tab:brown" dropout-dir False)
            (, "Weight penalty" "tab:blue" wp-dir True)
            (, "Loss gradient penalty (data)" "tab:green" lg2-dir False)
            (, "Loss gradient penalty (parameter)" "tab:purple" lg1-dir False)
            (, "RKHS norm penalty" "tab:olive" rkhs-dir False)
            (, "MGS penalty (trace)" "tab:red" mgs1-dir False)
            (, "MGS penalty (det)" "tab:pink" mgs2-dir True)])

;; Path of the SVG this script produces.
(setv fname (os.path.join attach-dir "two-circles-metrics-plot.svg"))

;; Rebuild the dataset and the network functions with the setup defaults.
;; Only the data, the jitted apply function and the optimizer's parameter
;; getter are needed here; the remaining results are discarded.
(setv [x y net-apply _ _ opt-get _]
      (two-circles-setup net (optimizers.constant 1e-1)))

;; Fixed 50/50 train/test split.
(setv [x-train x-test y-train y-test]
      (train-test-split x y :test-size 0.5 :random-state 3))

;; PRNG key handed to net-apply when drawing decision boundaries.
(setv jax-rng (jax.random.PRNGKey 2))

;; Render the summary figure on a 2 x 6 grid and save it to `fname`.
;;   top row    : one decision-boundary panel (2 columns wide) for every
;;                entry of `dirs` whose last field is True
;;   bottom row : NTK trace curves (left half) and NTK determinant curves
;;                (right half), one curve per entry of `dirs`
(do
  (setv fig (plt.figure :figsize (, 18 11))
        tr-ax (plt.subplot2grid (, 2 6) (, 1 0) 1 3)
        det-ax (plt.subplot2grid (, 2 6) (, 1 3) 1 3)
        ;; Grid column of the next decision-boundary panel.  A plain counter
        ;; replaces the previous itertools.count, which was used without
        ;; itertools ever being imported.  The 6-column grid has room for
        ;; three such panels.
        boundary-col 0)

  ;; The loop variable is deliberately NOT called attach-dir, so that the
  ;; module-level output directory is not overwritten while iterating.
  (for [[label colour run-dir plot-boundary] dirs]
    (when plot-boundary
      ;; SECURITY: pickle.load can execute arbitrary code; only point this
      ;; at opt-state.obj files produced by trusted training runs.
      (with [f (open (os.path.join run-dir "opt-state.obj") "rb")]
        (setv opt-state (optimizers.pack-optimizer-state (pickle.load f))))
      (setv params (opt-get opt-state))

      (plt.subplot2grid (, 2 6) (, 0 boundary-col) 1 2)
      (+= boundary-col 2)

      ;; Predicted class = sigmoid(logit) > 0.5.
      (nn-utils.plot-decision-boundary (fn [pts] (-> (net-apply params pts :rng jax-rng)
                                                     (jax.nn.sigmoid)
                                                     (> 0.5)))
                                       x-train
                                       :bounds (, -1.2 1.2 -1.2 1.2)
                                       :alpha 0.5)
      (plt.scatter (np.take x-train 0 :axis 1)
                   (np.take x-train 1 :axis 1)
                   :c y-train
                   :edgecolors "black"
                   :cmap plt.cm.viridis
                   :s 70)
      (plt.title label :fontsize 18 :pad 15)
      (plt.xlim -1.2 1.2)
      (plt.ylim -1.2 1.2))

    (with [f (open (os.path.join run-dir "ntk-trace.npy") "rb")]
      (.plot tr-ax (np.log (np.load f)) :c colour))
    ;; NOTE(review): plotted as stored, while the title below says log(det);
    ;; presumably ntk-determinant.npy already holds log-determinants -- confirm.
    (with [f (open (os.path.join run-dir "ntk-determinant.npy") "rb")]
      (.plot det-ax (np.load f) :c colour)))

  (.set-title tr-ax r"$\log(\mathtt{tr}\/K_{\theta})$" :fontsize 18 :pad 15)
  (.set-title det-ax r"$\log(\mathtt{det}\/K_{\theta})$" :fontsize 18 :pad 15)

  ;; Pair every label with the matching trace curve explicitly rather than
  ;; letting matplotlib collect handles from all axes in creation order.
  (setv names (lfor [label _ _ _] dirs label))
  (.legend fig (.get-lines tr-ax) names :loc "lower center" :ncol 5 :fontsize 15)

  (plt.tight-layout)
  (plt.subplots-adjust :hspace 0.2 :wspace 0.25 :left 0.03 :top 0.95 :bottom 0.1))
(plt.savefig fname)

(print fname)
