;; Output directory for result files: every `mnist-eval-simple` call below
;; receives it as `:attach-dir` and writes its `fname` inside it.
(setv attach-dir ".")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

(defmacro bound? [x]
  "Expand to True when evaluating `x` succeeds and False when it raises NameError.

  Exceptions other than NameError raised while evaluating `x` propagate."
  `(try (do ~x True)
        (except [NameError] False)))

(defmacro default [x d]
  "Expand to the value of `x`, or to `d` if evaluating `x` raises NameError.

  `x` is evaluated exactly once, so an expression with side effects is not
  run twice. `d` is evaluated only in the fallback case."
  `(try ~x
        (except [NameError] ~d)))

(import [sklearn.preprocessing [normalize]])

(defn partial-flatten [x]
  "Collapse every axis of `x` after the first into one, giving a 2-D array.

  The trailing size is computed explicitly instead of passing -1 to reshape:
  NumPy cannot infer a -1 dimension for an array with zero elements, so the
  -1 form fails on an empty batch. A 1-D input of length N becomes (N, 1)."
  (setv shape (np.shape x))
  (np.reshape x (, (get shape 0)
                   (int (np.prod (get shape (slice 1 None)))))))

(defn mnist-data [[train-set "vanilla"] [test-set "vanilla"] [conv False]
                  [data-dir "../../mnist_c"]]
  "Load an MNIST-C train/test pair from `.npy` files and scale the images by 1/255.

  Arrays are read from `{data-dir}/{train-set}/train_{images,labels}.npy` and
  `{data-dir}/{test-set}/test_{images,labels}.npy`. Images are squeezed and
  given a trailing axis at position 3. Unless `conv` is true, they are then
  flattened to one row per image with `partial-flatten`.

  Returns `(, train-images test-images train-labels test-labels)`."
  (defn load-array [subset name]
    ;; e.g. subset "vanilla", name "train_images" -> {data-dir}/vanilla/train_images.npy
    (with [f (open f"{data-dir}/{subset}/{name}.npy" "rb")]
      (np.load f)))

  (defn prepare-images [images]
    ;; NOTE(review): np.squeeze also drops the batch axis when it has length 1.
    (setv images (np.expand-dims (np.squeeze images) 3))
    (/ (if conv images (partial-flatten images)) (np.float32 255)))

  (setv train-images (prepare-images (load-array train-set "train_images"))
        train-labels (load-array train-set "train_labels")
        test-images (prepare-images (load-array test-set "test_images"))
        test-labels (load-array test-set "test_labels"))
  (, train-images test-images train-labels test-labels))

(defn mnist-train-net [train-net input-shape [conv False] [optimizer None]]
  "Compile the training network and pair it with an optimizer.

  `train-net` is a list of stax layers combined with `stax.serial`. When
  `optimizer` is None, `optimizers.sgd` is used with an
  `optimizers.exponential-decay` step size (1e-1, decay rate 0.99995, decay
  steps 1). Otherwise `optimizer` is a quoted Hy form passed to `hy.eval`. It
  must produce an `(init update get-params)` triple.

  Returns `(, net-apply calc-loss opt-update opt-get new-opt-state)`:
  - `net-apply` and `calc-loss` are jitted.
  - `new-opt-state` maps a PRNG key to a freshly initialised optimizer state.
  - Initialisation uses input shape `(-1 input-shape input-shape 1)` when
    `conv` is true and `(-1 input-shape)` otherwise."
  (setv [init-layers apply-layers] (stax.serial #* train-net)
        net-apply (jax.jit apply-layers)
        init-shape (if conv
                       (, -1 input-shape input-shape 1)
                       (, -1 input-shape))
        [opt-init opt-update opt-get]
        (if (is-not optimizer None)
            (hy.eval optimizer)
            ;; Alternative step size previously tried here:
            ;; (optimizers.piecewise-constant [1500] [1e-1 1e-2])
            (optimizers.sgd :step-size (optimizers.exponential-decay
                                         :step-size 1e-1
                                         :decay-rate 0.99995
                                         :decay-steps 1)))
        calc-loss (jax.jit (fn [p x y [rng None]]
                             (nn-utils.ce-with-logits-loss (net-apply p x :rng rng) y))))

  (defn new-opt-state [rng]
    ;; Element 1 of the stax init result is used as the parameters to optimise.
    (opt-init (get (init-layers rng init-shape) 1)))

  (, net-apply calc-loss opt-update opt-get new-opt-state))

(defn mnist-test-net [test-net]
  "Compile the evaluation network.

  `test-net` is a list of stax layers combined with `stax.serial`; its init
  function is discarded.

  Returns `(, net-apply calc-loss)`:
  - `net-apply` is the jitted apply function.
  - `calc-loss` is a jitted `(fn [p x y [rng None]])` that passes the output of
    `net-apply` and `y` to `nn-utils.ce-with-logits-loss`."
  (setv [_ apply-layers] (stax.serial #* test-net)
        net-apply (jax.jit apply-layers))
  (, net-apply
     (jax.jit (fn [p x y [rng None]]
                (nn-utils.ce-with-logits-loss (net-apply p x :rng rng) y)))))
(import [sklearn.model_selection [StratifiedShuffleSplit]])

(defn mnist-eval-simple [train-net test-net create-penalty epochs batch-size label-noise train-size optimizer
                         jax-rng attach-dir conv
                         [fname "perf.npy"] [show-progress True]
                         [n-splits 10] [train-set "motion_blur"] [test-set "vanilla"]
                         [split-seed None]]
  "Train `train-net` on `n-splits` stratified subsamples of `train-set` and save metrics.

  Setup:
  - Training labels go through `nn-utils.add-label-noise` with `label-noise`
    and are then one-hot encoded with 10 classes.
  - `StratifiedShuffleSplit` keeps a `train-size` fraction of the training
    images in each split. `split-seed` is forwarded as its `random_state`;
    the default None leaves the splits unseeded.

  Training, for each split:
  - A fresh optimizer state is built from a per-split PRNG key.
  - `nn-utils.train-model` runs for `epochs` epochs with `batch-size`.
  - Four trackers are updated: loss and accuracy built from the `test-set`
    images, plus trace and determinant trackers on the training network.

  Output:
  - Prints the mean and std, over splits, of the last test loss and accuracy.
  - Writes train/test loss, train/test accuracy, trace and determinant with
    consecutive `np.save` calls to `attach-dir`/`fname`.

  `optimizer` is a quoted Hy form evaluated by `mnist-train-net`, or None to
  use its default SGD schedule."
  (setv

    ;; Setup
    [train-images test-images train-labels test-labels] (mnist-data :train-set train-set
                                                                    :test-set test-set
                                                                    :conv conv)
    input-shape (get (np.shape train-images) 1)
    [train-apply calc-loss-train opt-update opt-get new-opt-state]
    (mnist-train-net train-net input-shape :conv conv
                     :optimizer optimizer)
    [test-apply calc-loss-test] (mnist-test-net test-net)
    penalty (create-penalty train-apply calc-loss-train)
    opt-step (nn-utils.create-opt-step calc-loss-train penalty opt-update opt-get)
    splitter (StratifiedShuffleSplit :n-splits n-splits
                                     :train-size train-size
                                     :random-state split-seed)
    train-labels-one-hot (jax.nn.one-hot (nn-utils.add-label-noise train-labels label-noise) 10)
    test-labels-one-hot (jax.nn.one-hot test-labels 10)
    ;; Trackers, in order: loss, accuracy, trace, determinant.
    ;; NOTE(review): the trailing 100 is forwarded to nn-utils; presumably an
    ;; update interval -- confirm.
    metrics [(nn-utils.setup-loss-tracker
               calc-loss-test test-images test-labels-one-hot
               opt-get 100)
             (nn-utils.setup-accuracy-tracker
               test-images test-labels-one-hot test-apply
               opt-get True 100)
             (nn-utils.setup-trace-tracker
               train-apply opt-get
               100)
             (nn-utils.setup-determinant-tracker
               train-apply opt-get
               100)]

    ;; Training: one PRNG key per split. The first key of the split is discarded.
    [_ #* subrng] (jax.random.split jax-rng (inc (.get-n-splits splitter)))
    perf (lfor [i data] (enumerate (.split splitter :X train-images :y train-labels-one-hot))
               (let [[train test] data
                     x-train (jnp.take train-images train :axis 0)
                     ;; x-test (np.take train-images test :axis 0)
                     y-train (jnp.take train-labels-one-hot train :axis 0)
                     ;; y-test (np.take train-labels-one-hot test :axis 0)
                     ;; `metrics` is rebound here to the trackers returned by train-model.
                     [opt-state metrics]
                     (nn-utils.train-model
                       epochs (new-opt-state (get subrng i))
                       opt-step x-train y-train
                       :batch-size batch-size
                       :metrics metrics
                       :show-progress show-progress
                       :progress-pos 2
                       :jax-rng (get subrng i))]
                 ;; Final metric value of each tracker, in tracker order.
                 [(:metric (:state (get metrics 0)))
                  (:metric (:state (get metrics 1)))
                  (:metric (:state (get metrics 2)))
                  (:metric (:state (get metrics 3)))]))

    ;; Save metrics: stack the per-split results. Index 0 / 1 on axis 2 of the
    ;; loss and accuracy arrays is stored as train / test respectively.
    loss (np.array (lfor p perf (get p 0)))
    acc (np.array (lfor p perf (get p 1)))
    trace (np.array (lfor p perf (get p 2)))
    det (np.array (lfor p perf (get p 3)))
    train-loss (np.take loss 0 :axis 2)
    train-acc (np.take acc 0 :axis 2)
    test-loss (np.take loss 1 :axis 2)
    test-acc (np.take acc 1 :axis 2)
    test-loss-mean (np.mean test-loss :axis 0)
    test-loss-std (np.std test-loss :axis 0)
    test-acc-mean (np.mean test-acc :axis 0)
    test-acc-std (np.std test-acc :axis 0)
    perf-file (os.path.join attach-dir fname))
  ;; `\\pm` prints a literal "\pm"; the bare `\p` was an invalid escape sequence.
  (print f"Test loss {(get test-loss-mean -1) :.4f} \\pm ({(get test-loss-std -1) :.4f})")
  (print f"Test accuracy {(get test-acc-mean -1) :.4f} \\pm ({(get test-acc-std -1) :.4f})")
  ;; The arrays are appended to one file; read them back with repeated np.load calls.
  (with [f (open perf-file "wb")]
    (np.save f train-loss)
    (np.save f train-acc)
    (np.save f test-loss)
    (np.save f test-acc)
    (np.save f trace)
    (np.save f det)))

;; Configuration #1: a small CNN (two conv / ReLU / 2x2 average-pool stages,
;; then dense layers 120 -> 84 -> num-outputs) with no explicit penalty.
(setv num-outputs 10
      conv True)

(setv net [;; feature extractor
           (stax.Conv :out-chan 6 :filter-shape (, 5 5) :padding "SAME")
           stax.Relu
           (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
           (stax.Conv :out-chan 16 :filter-shape (, 5 5) :padding "SAME")
           stax.Relu
           (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
           ;; classifier head
           stax.Flatten
           (stax.Dense 120) stax.Relu
           (stax.Dense 84) stax.Relu
           (stax.Dense num-outputs)])

;; Separate shallow copies, so layers can be inserted into one list without
;; touching the other.
(setv train-net (list net)
      test-net (list net))

;; Penalty factory: ignores both arguments and returns a function that always yields 0.0.
(setv create-penalty (fn [net-apply calc-loss] (constantly 0.0)))


;; Add BatchNorm after each convolution, and Dropout in the dense head.
;; Each index refers to the list as it stands after the preceding inserts,
;; so the order of these calls matters. The final training list is
;; Conv BN Relu Pool Conv BN Relu Pool Flatten Dense120 Relu Dropout
;; Dense84 Dropout Relu Dense10.
;; NOTE(review): the second Dropout (index 13) sits between (stax.Dense 84)
;; and its Relu, while the first follows a Relu. Confirm the placement is intended.
;; NOTE(review): in jax.experimental.stax, Dropout's `rate` appears to be the
;; probability of KEEPING a unit, so `:rate 0.4` would drop about 60% -- confirm.
(.insert train-net 1 (stax.BatchNorm))
(.insert train-net 5 (stax.BatchNorm))
(.insert train-net 11 (stax.Dropout :rate 0.4))
(.insert train-net 13 (stax.Dropout :rate 0.4))

;; The evaluation network gets layers at the same positions, with Dropout in
;; "test" mode (presumably a pass-through -- confirm against stax).
(.insert test-net 1 (stax.BatchNorm))
(.insert test-net 5 (stax.BatchNorm))
(.insert test-net 11 (stax.Dropout :rate 0.0 :mode "test"))
(.insert test-net 13 (stax.Dropout :rate 0.0 :mode "test"))

;; Sweep: size of the training subsample.
;; Baseline hyper-parameters (training-size = 3000 / 60e3 = 0.05).
(setv epochs 100
      batch-size 32
      label-noise 0.5
      training-size (/ 3000 60e3)
      optimizer '(optimizers.sgd :step-size (optimizers.exponential-decay
                                                  :step-size 1e-1
                                                  :decay-rate 0.99995
                                                  :decay-steps 1))
      jax-rng (jax.random.PRNGKey 0))

;; Fractions (of 60e3) passed to mnist-eval-simple as its :train-size.
(setv train-size (/ (np.array [10000 3000 1500 500]) 60e3))

(for [t train-size]
  ;; One result file per fraction.
  (mnist-eval-simple :train-size t
                     :fname f"robustness_training-size_{t :.2f}.npy"
                     :train-net train-net
                     :test-net test-net
                     :create-penalty create-penalty
                     :epochs epochs
                     :batch-size batch-size
                     :label-noise label-noise
                     :optimizer optimizer
                     :jax-rng jax-rng
                     :attach-dir attach-dir
                     :conv conv
                     :n-splits (default n-splits 10)
                     :train-set (default train-set "motion_blur")
                     :test-set (default test-set "vanilla")))
;; Sweep: label noise.
;; Baseline hyper-parameters; label-noise is replaced by the sweep values below.
(setv epochs 100
      batch-size 32
      label-noise 0.5
      training-size (/ 3000 60e3)
      optimizer '(optimizers.sgd :step-size (optimizers.exponential-decay
                                                  :step-size 1e-1
                                                  :decay-rate 0.99995
                                                  :decay-steps 1))
      jax-rng (jax.random.PRNGKey 0))

;; Noise levels forwarded to nn-utils.add-label-noise inside mnist-eval-simple.
(setv label-noise (np.array [0.0 0.3 0.5 0.8]))

(for [l label-noise]
  ;; One result file per noise level.
  (mnist-eval-simple :label-noise l
                     :fname f"robustness_label-noise_{l :.1f}.npy"
                     :train-net train-net
                     :test-net test-net
                     :create-penalty create-penalty
                     :epochs epochs
                     :batch-size batch-size
                     :train-size training-size
                     :optimizer optimizer
                     :jax-rng jax-rng
                     :attach-dir attach-dir
                     :conv conv
                     :n-splits (default n-splits 10)
                     :train-set (default train-set "motion_blur")
                     :test-set (default test-set "vanilla")))
;; Configuration #2: the same small CNN, used as-is (no extra layers inserted
;; here), with a penalty from nn-utils.create-mgs-penalty-trace.
(setv num-outputs 10
      conv True)

(setv net [;; feature extractor
           (stax.Conv :out-chan 6 :filter-shape (, 5 5) :padding "SAME")
           stax.Relu
           (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
           (stax.Conv :out-chan 16 :filter-shape (, 5 5) :padding "SAME")
           stax.Relu
           (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
           ;; classifier head
           stax.Flatten
           (stax.Dense 120) stax.Relu
           (stax.Dense 84) stax.Relu
           (stax.Dense num-outputs)])

;; Separate shallow copies of the layer list.
(setv train-net (list net)
      test-net (list net))

;; Penalty factory: builds the penalty from the training apply function with
;; coefficient 1e-2; calc-loss is accepted but unused.
(setv create-penalty (fn [net-apply calc-loss]
                       (nn-utils.create-mgs-penalty-trace net-apply 1e-2)))
;; Sweep: mini-batch size.
;; Baseline hyper-parameters; batch-size is replaced by the sweep values below.
(setv epochs 100
      batch-size 32
      label-noise 0.5
      training-size (/ 3000 60e3)
      optimizer '(optimizers.sgd :step-size (optimizers.exponential-decay
                                                  :step-size 1e-1
                                                  :decay-rate 0.99995
                                                  :decay-steps 1))
      jax-rng (jax.random.PRNGKey 0))

;; Batch sizes forwarded to nn-utils.train-model.
(setv batch-size (np.array [16 32 64 128]))

(for [b batch-size]
  ;; One result file per batch size.
  (mnist-eval-simple :batch-size b
                     :fname f"robustness_batch-size_{b}.npy"
                     :train-net train-net
                     :test-net test-net
                     :create-penalty create-penalty
                     :epochs epochs
                     :label-noise label-noise
                     :train-size training-size
                     :optimizer optimizer
                     :jax-rng jax-rng
                     :attach-dir attach-dir
                     :conv conv
                     :n-splits (default n-splits 10)
                     :train-set (default train-set "motion_blur")
                     :test-set (default test-set "vanilla")))
;; Sweep: initial learning rate of the SGD optimizer.
(setv training-size (/ 3000 60e3)
      label-noise 0.5
      epochs 100
      batch-size 32
      optimizer '(optimizers.sgd :step-size (optimizers.exponential-decay
                                                  :step-size 1e-1
                                                  :decay-rate 0.99995
                                                  :decay-steps 1))

      jax-rng (jax.random.PRNGKey 0))

;; Initial step sizes for optimizers.exponential-decay.
(setv learning-rate (np.array [5e-1 1e-1 1e-2 1e-3]))

(for [lr learning-rate]
  (mnist-eval-simple :train-net train-net
                     :test-net test-net
                     :create-penalty create-penalty
                     :epochs epochs
                     :batch-size batch-size
                     :label-noise label-noise
                     :train-size training-size
                     ;; Quasiquote and unquote so the form carries this iteration's
                     ;; value, not the bare symbol `lr`. With a plain quote, `hy.eval`
                     ;; in mnist-train-net can resolve `lr` only while it happens to
                     ;; be a module global. `float` turns the NumPy scalar into a
                     ;; plain Python float that Hy can embed in the form.
                     :optimizer `(optimizers.sgd :step-size (optimizers.exponential-decay
                                                              :step-size ~(float lr)
                                                              :decay-rate 0.99995
                                                              :decay-steps 1))
                     :jax-rng jax-rng
                     :attach-dir attach-dir
                     :conv conv
                     :fname f"robustness_learning-rate_{lr :.3f}.npy"
                     :n-splits (default n-splits 10)
                     :train-set (default train-set "motion_blur")
                     :test-set (default test-set "vanilla")))
;; Sweep: number of training epochs.
;; Baseline hyper-parameters; epochs is replaced by the sweep values below.
(setv epochs 100
      batch-size 32
      label-noise 0.5
      training-size (/ 3000 60e3)
      optimizer '(optimizers.sgd :step-size (optimizers.exponential-decay
                                                  :step-size 1e-1
                                                  :decay-rate 0.99995
                                                  :decay-steps 1))
      jax-rng (jax.random.PRNGKey 0))

;; Epoch counts forwarded to nn-utils.train-model.
(setv epochs (np.array [25 50 100 150]))

(for [e epochs]
  ;; One result file per epoch count.
  (mnist-eval-simple :epochs e
                     :fname f"robustness_epochs_{e}.npy"
                     :train-net train-net
                     :test-net test-net
                     :create-penalty create-penalty
                     :batch-size batch-size
                     :label-noise label-noise
                     :train-size training-size
                     :optimizer optimizer
                     :jax-rng jax-rng
                     :attach-dir attach-dir
                     :conv conv
                     :n-splits (default n-splits 10)
                     :train-set (default train-set "motion_blur")
                     :test-set (default test-set "vanilla")))
