;; Output directory for results: run-testbench passes it to mnist-eval, which
;; writes each perf_*.npy file to (os.path.join attach-dir fname).
;; "." means the current working directory of the process.
(setv attach-dir ".")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

;; Expands to True when evaluating `x` completes without raising NameError
;; (e.g. `x` is a bound name), and to False when it raises NameError.
;; Any other exception propagates unchanged.
(defmacro bound? [x]
  `(try (do ~x True)
        (except [NameError] False)))

;; Expands to the value of `x`, or to `d` when evaluating `x` raises NameError
;; (typically: `x` is an unbound name).
;; `x` is evaluated exactly once. A `(if (bound? x) x d)` expansion would
;; evaluate it twice and repeat any side effects. `d` is evaluated only when
;; needed, and a NameError raised by `d` itself still propagates.
(defmacro default [x d]
  `(try ~x
        (except [NameError] ~d)))

(import [sklearn.preprocessing [normalize]])

(defn partial-flatten [x]
  "Reshape `x` to 2-D, keeping the first axis and merging all the others.

  For example, shape (n, 28, 28, 1) becomes (n, 784)."
  (setv rows (get (np.shape x) 0))
  (np.reshape x (, rows -1)))

(defn mnist-data [[train-set "vanilla"] [test-set "vanilla"] [conv False] [data-root "../../mnist_c"]]
  "Load one MNIST-C train/test pair as numpy arrays.

  Reads `<data-root>/<train-set>/train_images.npy`, `train_labels.npy` and
  `<data-root>/<test-set>/test_images.npy`, `test_labels.npy`.
  `data-root` defaults to the previously hard-coded `../../mnist_c`.

  Images are squeezed and given a trailing axis of size 1 (axis 3).
  Unless `conv` is true, they are then flattened to (n, features) with
  `partial-flatten`. Finally they are divided by `(np.float32 255)`.
  Labels are returned as stored.

  Returns `(, train-images test-images train-labels test-labels)`."
  (defn load-split [subset name]
    ;; Open/close the file explicitly and hand the handle to np.load.
    (with [f (open (os.path.join data-root subset (+ name ".npy")) "rb")]
      (np.load f)))

  (defn prepare-images [images]
    ;; NOTE(review): np.squeeze drops every size-1 axis, so a subset holding a
    ;; single image would lose its batch axis here. Assumes n > 1.
    (setv images (np.expand-dims (np.squeeze images) 3))
    ;; Division by 255 presumably maps 0-255 pixel values to [0, 1]. Confirm the stored dtype.
    (/ (if conv images (partial-flatten images)) (np.float32 255)))

  (, (prepare-images (load-split train-set "train_images"))
     (prepare-images (load-split test-set "test_images"))
     (load-split train-set "train_labels")
     (load-split test-set "test_labels")))

(defn mnist-train-net [train-net input-shape [conv False] [optimizer None]]
  "Assemble the training-side functions for the stax layer list `train-net`.

  `input-shape` is used as the side length of square single-channel inputs,
  i.e. an init shape of (-1, s, s, 1), when `conv` is true. Otherwise it is
  used as the flat feature count, an init shape of (-1, s).

  `optimizer`, when not None, is a quoted form that `hy.eval` turns into an
  `(init update get-params)` triple. When None, `optimizers.sgd` is used with
  an `exponential-decay` step size (initial 1e-1, decay rate 0.99995, decay
  steps 1).

  Returns `(, net-apply calc-loss opt-update opt-get new-opt-state)`:
  - the jitted network apply function;
  - a jitted `(calc-loss p x y [rng None])` computing
    `nn-utils.ce-with-logits-loss` of the network output against `y`;
  - the optimizer update and get-params functions;
  - `new-opt-state`, which maps a PRNG key to a freshly initialised optimizer
    state."
  (setv [init-fn apply-fn] (stax.serial #* train-net)
        apply-fn (jax.jit apply-fn)
        [opt-init opt-update opt-get]
        (if (is-not optimizer None)
            (hy.eval optimizer)
            ;; Alternative schedule kept for reference:
            ;; (optimizers.piecewise-constant [1500] [1e-1 1e-2])
            (optimizers.sgd :step-size (optimizers.exponential-decay
                                         :step-size 1e-1
                                         :decay-rate 0.99995
                                         :decay-steps 1)))
        init-shape (if conv
                       (, -1 input-shape input-shape 1)
                       (, -1 input-shape)))

  (defn loss-fn [p x y [rng None]]
    (nn-utils.ce-with-logits-loss (apply-fn p x :rng rng) y))

  (defn fresh-opt-state [rng]
    ;; stax init returns (output-shape, params). Only the params seed the optimizer.
    (opt-init (get (init-fn rng init-shape) 1)))

  (, apply-fn (jax.jit loss-fn) opt-update opt-get fresh-opt-state))

(defn mnist-test-net [test-net]
  "Assemble the evaluation-side functions for the stax layer list `test-net`.

  Only the apply function of `(stax.serial #* test-net)` is used; its init
  function is discarded. Returns `(, net-apply calc-loss)`, both jitted.
  `calc-loss` has the signature `(p x y [rng None])` and computes
  `nn-utils.ce-with-logits-loss` of the network output against `y`."
  (setv [_ apply-fn] (stax.serial #* test-net)
        apply-fn (jax.jit apply-fn))

  (defn loss-fn [p x y [rng None]]
    (nn-utils.ce-with-logits-loss (apply-fn p x :rng rng) y))

  (, apply-fn (jax.jit loss-fn)))
(defn mnist-eval [train-images test-images train-labels test-labels train-apply test-apply
                  calc-loss-test epochs opt-get opt-step new-opt-state jax-rng batch-size attach-dir
                  [fname "perf.npy"] [splitter None] [label-noise 0.0]
                  [show-progress True]]
  "Train a fresh model on every split from `splitter` and save the tracked metrics.

  Training labels first go through `nn-utils.add-label-noise` with
  `label-noise`. Both label sets are then one-hot encoded with 10 classes.
  `splitter.split(X=train-images, y=...)` produces the splits, and only the
  training indices of each split are used.

  Each split gets its own PRNG key and a fresh optimizer state from
  `new-opt-state`. It is trained with `nn-utils.train-model` for `epochs`
  epochs. The `:metric` entry of the `:state` of each of the four returned
  trackers (loss, accuracy, trace, determinant) is collected per split.

  Printed output: the last entry of the mean and std over splits of the test
  loss and the test accuracy.

  Saved output: six arrays written to `attach-dir/fname` by consecutive
  `np.save` calls, in this order: train loss, train accuracy, test loss,
  test accuracy, trace, determinant. Read them back with six `np.load`
  calls on one open file.

  Raises ValueError when `splitter` is None. An object with `get_n_splits`
  and `split` methods is required."
  (when (is splitter None)
    (raise (ValueError "mnist-eval needs a `splitter`, e.g. StratifiedShuffleSplit")))

  (setv train-labels-one-hot (jax.nn.one-hot (nn-utils.add-label-noise train-labels label-noise) 10)
        test-labels-one-hot (jax.nn.one-hot test-labels 10)
        ;; NOTE(review): the trailing 100 passed to every tracker looks like an
        ;; update/logging interval. Confirm in nn_utilities.
        trackers [(nn-utils.setup-loss-tracker
                    calc-loss-test test-images test-labels-one-hot
                    opt-get 100)
                  (nn-utils.setup-accuracy-tracker
                    test-images test-labels-one-hot test-apply
                    opt-get True 100)
                  (nn-utils.setup-trace-tracker
                    train-apply opt-get
                    100)
                  (nn-utils.setup-determinant-tracker
                    train-apply opt-get
                    100)])

  (defn run-split [key train-idx]
    ;; Train from a fresh optimizer state on the rows selected by `train-idx`.
    ;; Return the metric of each of the four trackers.
    ;; The split's held-out indices are unused; evaluation uses test-images.
    (setv x-train (jnp.take train-images train-idx :axis 0)
          y-train (jnp.take train-labels-one-hot train-idx :axis 0)
          [_ trained] (nn-utils.train-model
                        epochs (new-opt-state key)
                        opt-step x-train y-train
                        :batch-size batch-size
                        :metrics trackers
                        :show-progress show-progress
                        :progress-pos 2
                        :jax-rng key))
    (lfor k (range 4) (:metric (:state (get trained k)))))

  ;; One key per split. The first key produced by jax.random.split is discarded.
  (setv [_ #* split-keys] (jax.random.split jax-rng (inc (.get-n-splits splitter)))
        perf (lfor [i split] (enumerate (.split splitter :X train-images :y train-labels-one-hot))
                   (run-split (get split-keys i) (get split 0)))
        [loss acc trace det] (lfor k (range 4) (np.array (lfor p perf (get p k))))
        train-loss (np.take loss 0 :axis 2)
        train-acc (np.take acc 0 :axis 2)
        test-loss (np.take loss 1 :axis 2)
        test-acc (np.take acc 1 :axis 2)
        test-loss-mean (np.mean test-loss :axis 0)
        test-loss-std (np.std test-loss :axis 0)
        test-acc-mean (np.mean test-acc :axis 0)
        test-acc-std (np.std test-acc :axis 0))

  ;; Escaped backslash: prints a literal "\pm" without relying on an invalid escape sequence.
  (print f"Test loss {(get test-loss-mean -1) :.4f} \\pm ({(get test-loss-std -1) :.4f})")
  (print f"Test accuracy {(get test-acc-mean -1) :.4f} \\pm ({(get test-acc-std -1) :.4f})")
  (with [f (open (os.path.join attach-dir fname) "wb")]
    (for [arr [train-loss train-acc test-loss test-acc trace det]]
      (np.save f arr))))

;; Experiment configuration.
;; `process-params` is a quasiquoted `setv` form and is NOT evaluated here.
;; Each worker runs `(hy.eval process-params (globals))` in run-testbench,
;; which defines the names below as that worker's module globals.
;; `label-noise` and `train-size` are ordinary values evaluated at load time.
;; defmain combines them with np.meshgrid into the job grid.
(setv process-params `(setv num-outputs 10
                            ;; Layers for mnist-train-net (combined with stax.serial).
                            ;; LeNet-5-like: two 5x5 conv + 2x2 avg-pool stages,
                            ;; then dense 120 -> 84 -> num-outputs.
                            train-net
                            [(stax.Conv :out-chan 6 :filter-shape (, 5 5) :padding "SAME")
                             stax.Relu
                             (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                             (stax.Conv :out-chan 16 :filter-shape (, 5 5) :padding "SAME")
                             stax.Relu
                             (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                             stax.Flatten
                             (stax.Dense 120)
                             stax.Relu
                             (stax.Dense 84)
                             stax.Relu
                             (stax.Dense num-outputs)]
                            ;; Layers for mnist-test-net (currently identical to train-net).
                            test-net
                            [(stax.Conv :out-chan 6 :filter-shape (, 5 5) :padding "SAME")
                             stax.Relu
                             (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                             (stax.Conv :out-chan 16 :filter-shape (, 5 5) :padding "SAME")
                             stax.Relu
                             (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
                             stax.Flatten
                             (stax.Dense 120)
                             stax.Relu
                             (stax.Dense 84)
                             stax.Relu
                             (stax.Dense num-outputs)]
                            ;; Penalty factory passed to nn-utils.create-opt-step. It ignores
                            ;; its arguments and returns a function that always yields 0.0.
                            create-penalty (fn [net-apply calc-loss] (constantly 0.0))

                            ;; MNIST-C subset used for training (mnist-data :train-set).
                            ;; The test set keeps mnist-data's default.
                            train-set "motion_blur"
                            ;; None selects mnist-train-net's default SGD schedule.
                            optimizer None
                            ;; True keeps images 4-D (no flattening) for the conv layers.
                            conv True
                            epochs 100
                            batch-size 32
                            ;; Fixed seed: every job starts from the same key.
                            jax-rng (jax.random.PRNGKey 0))
      ;; Label-noise levels 0.0, 0.1, ..., 1.0 (float arange; the 1.1 stop is exclusive).
      label-noise (np.arange 0 1.1 0.1)
      ;; Training-set sizes: 250, 500, 750, 1000, 1500, 2000, 3000, ..., 9000.
      train-size (np.concatenate [(np.arange 250 1e3 250)
                                  (np.arange 1e3 2e3 500)
                                  (np.arange 2e3 1e4 1e3)]))


(import [sklearn.model_selection [StratifiedShuffleSplit]]
        [multiprocessing :as mp])

(defn run-testbench [d q]
  "Worker entry point: evaluate queued (label-noise, train-size) jobs on GPU `d`.

  Steps:
  1. Set CUDA_VISIBLE_DEVICES to `d`.
  2. Evaluate the quasiquoted `process-params` into this module's globals.
     This defines `train-net`, `test-net`, `train-set`, `conv`, `optimizer`,
     `epochs`, `batch-size`, `jax-rng` and `create-penalty`, among others.
  3. Load the data and build the train/test functions once.
  4. Take `(label-noise, train-size)` jobs from `q` until none arrives within
     the poll timeout.

  Each job runs `mnist-eval` with a 10-split StratifiedShuffleSplit whose
  training fraction is train-size / 60000. It writes
  perf_<label-noise>_<train-size>.npy into `attach-dir`."
  (import queue)
  (assoc os.environ "CUDA_VISIBLE_DEVICES" (str d))

  (hy.eval process-params (globals))

  (print (jax.devices "gpu"))
  (setv [train-images test-images train-labels test-labels] (mnist-data :train-set train-set :conv conv)
        ;; dev (get (jax.devices "gpu") d)
        ;; train-images (jax.device-put train-images dev)
        ;; test-images (jax.device-put test-images dev)
        ;; train-labels (jax.device-put train-labels dev)
        ;; test-labels (jax.device-put test-labels dev)
        input-shape (get (np.shape train-images) 1)
        [train-apply calc-loss-train opt-update opt-get new-opt-state] (mnist-train-net train-net input-shape :conv conv :optimizer optimizer)
        opt-step (nn-utils.create-opt-step calc-loss-train (create-penalty train-apply calc-loss-train) opt-update opt-get)
        [test-apply calc-loss-test] (mnist-test-net test-net))

  (while True
    ;; Checking `.empty` and then calling a blocking `.get` races with the
    ;; sibling workers. If one takes the last job in between, this worker
    ;; would block forever. A timed `.get` ends the loop instead.
    ;; `get_nowait` is avoided because it also gives up while another worker
    ;; briefly holds the queue's read lock.
    (try
      (setv [l t] (.get q :timeout 1))
      (except [queue.Empty]
        (break)))
    (setv splitter (StratifiedShuffleSplit :n-splits 10
                                          :train-size (/ t 60e3)
                                          :random-state 2))
    (mnist-eval train-images test-images train-labels test-labels train-apply test-apply
                calc-loss-test epochs opt-get opt-step new-opt-state jax-rng
                batch-size attach-dir f"perf_{l :.1f}_{t :.0f}.npy"
                splitter l :show-progress False)))

(defmain [#* args]
  "Distribute the (label-noise x train-size) grid over worker processes, one per GPU.

  Usage: pass the number of GPUs as the first command-line argument.

  Every grid point is queued as a (label-noise, train-size) pair. Worker `d`
  runs `run-testbench` with CUDA_VISIBLE_DEVICES=d. A progress bar counts the
  jobs taken off the queue. The function returns None, so the exit status is
  0, once every worker has exited."
  (import time)
  (when (< (len args) 2)
    (raise (SystemExit "usage: <script> NUM_GPUS")))
  (setv n-workers (int (get args 1)))

  (mp.set-start-method "spawn")
  (setv [g1 g2] (np.meshgrid label-noise train-size)
        grid (np.vstack [(np.reshape g1 (, -1)) (np.reshape g2 (, -1))])
        q (mp.Queue)
        procs (list))
  (for [i (range (get (np.shape grid) 1))]
    (.put q (np.take grid i :axis 1)))

  ;; Snapshot the job count before any worker can start consuming.
  (setv total (.qsize q)
        remaining total)

  (for [d (range n-workers)]
    (setv p (mp.Process :target run-testbench :args (, d q)))
    (.start p)
    (.append procs p))

  (with [t (trange total)]
    ;; Poll with a sleep rather than spinning a CPU core.
    ;; Stop early if every worker has died, since the queue would never drain.
    (while (and (> remaining 0) (any (gfor w procs (.is-alive w))))
      (time.sleep 0.5)
      (setv now (.qsize q))
      (when (> remaining now)
        (.update t (- remaining now))
        (setv remaining now))))

  ;; Join every worker. multiprocessing.connection.wait returns as soon as ANY
  ;; sentinel is ready, and its list result would become a non-zero exit status.
  (for [p procs]
    (.join p)))
