(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

;; Expands to an expression that evaluates `x` and yields True when that
;; evaluation completes, or False when it raises NameError (for example when
;; `x` is a symbol that has not been assigned yet). Any other exception raised
;; while evaluating `x` propagates unchanged.
(defmacro bound? [x]
  `(try
     (do ~x True)
     (except [NameError]
       False)))

;; Expands to `x` when evaluating it succeeds and to `d` when evaluating `x`
;; raises NameError (typically: `x` names a variable that is not bound yet).
;; `x` is evaluated exactly once, so forms with side effects are safe to pass;
;; `d` is only evaluated on the fallback path.
(defmacro default [x d]
  `(try ~x
        (except [NameError] ~d)))

(import [sklearn.preprocessing [normalize]]
        [imagecorruptions [corrupt]])

(defn partial-flatten [x]
  "Collapse every axis of x after the first into a single axis.

  The size of axis 0 is kept, so an input of shape (n, ...) comes back with
  shape (n, m) where m is whatever np.reshape infers for the remaining axis."
  (setv leading (get (np.shape x) 0))
  (np.reshape x (, leading -1)))

(defn fk-data [[train-size 1500] [conv False]]
  "Load the facial-keypoints images and labels, keeping only complete samples.

  Reads face_images.npz and facial_keypoints.csv from ../../facial_keypoints/
  (relative to the working directory), moves the last image axis to the
  front, and drops every sample whose keypoint row contains a NaN.

  train-size is accepted for backward compatibility but is not used: no
  train/test split is performed here.

  Returns a tuple (images keypoints input-shape): the images cast to uint8,
  the matching keypoint rows, and input-shape, which is the size of image
  axis 1 when conv is true and the product of image axes 1 and 2 otherwise."
  ;; Open the archive as a context manager so its file handle is released as
  ;; soon as the array has been read.
  (with [archive (np.load "../../facial_keypoints/face_images.npz")]
    (setv images (np.transpose (get archive "face_images") (, 2 0 1))))
  (setv keypoints (np.genfromtxt "../../facial_keypoints/facial_keypoints.csv"
                                 :skip-header 1
                                 :delimiter ",")
        ;; flatnonzero always yields a 1-d index array. Squeezing the result of
        ;; np.where would give a 0-d index when exactly one row is complete,
        ;; and np.take would then drop the sample axis.
        non-nan (np.flatnonzero
                  (np.logical-not (np.any (np.isnan keypoints) :axis 1)))
        images (.astype (np.take images non-nan :axis 0) "uint8")
        keypoints (np.take keypoints non-nan :axis 0)
        side (get (np.shape images) 1))
  (, images
     keypoints
     (if conv
         side
         (* side (get (np.shape images) 2)))))

(defn fk-train-net [train-net input-shape [conv False] [optimizer None]]
  "Assemble the training network together with its loss and optimizer.

  train-net is a sequence of stax layers that is combined with stax.serial.
  input-shape is a single integer used to build the shape passed to the
  network initialiser: (-1 input-shape input-shape 1) when conv is true and
  (-1 input-shape) otherwise.
  optimizer, when not None, is a Hy form handed to hy.eval; its value is
  unpacked into three functions (init, update, get-params). When it is None,
  SGD with an exponentially decaying step size is used.

  Returns (net-apply calc-loss opt-update opt-get new-opt-state), where
  net-apply and calc-loss are jitted, calc-loss is nn-utils.mse-loss of the
  network output against y, and new-opt-state maps an rng to an optimizer
  state built from freshly initialised parameters."
  (setv [net-init net-apply] (stax.serial (unpack-iterable train-net))
        net-apply (jax.jit net-apply))
  ;; NOTE: hy.eval executes whatever form it is given; only pass trusted code.
  (setv [opt-init opt-update opt-get]
        (if (is-not optimizer None)
            (hy.eval optimizer)
            (optimizers.sgd
              :step-size (optimizers.exponential-decay :step-size 1e-4
                                                       :decay-rate 0.99995
                                                       :decay-steps 1))))

  (defn loss [p x y [rng None]]
    (nn-utils.mse-loss (net-apply p x :rng rng) y))

  ;; Shape passed to net-init; the leading -1 stands in for the batch axis.
  (setv init-shape (if conv
                       (, -1 input-shape input-shape 1)
                       (, -1 input-shape)))

  (defn new-opt-state [rng]
    ;; net-init's result is indexed at 1 to obtain the parameters.
    (opt-init (get (net-init rng init-shape) 1)))

  (, net-apply (jax.jit loss) opt-update opt-get new-opt-state))

(defn fk-test-net [test-net]
  "Assemble the evaluation network and its loss.

  test-net is a sequence of stax layers combined with stax.serial; only the
  apply function of the combined network is used.

  Returns (net-apply calc-loss), both jitted. calc-loss takes parameters,
  inputs, targets and an optional rng, and is nn-utils.mse-loss of the
  network output against the targets."
  (setv [net-init net-apply] (stax.serial (unpack-iterable test-net))
        net-apply (jax.jit net-apply))

  (defn loss [p x y [rng None]]
    (nn-utils.mse-loss (net-apply p x :rng rng) y))

  (, net-apply (jax.jit loss)))

;; Experiment configuration: number of regression outputs, whether inputs are
;; fed as images (conv) or flattened, the corruption applied to training
;; images, and the dropout rates explored by the grid search.
(setv num-outputs 30
      conv True
      corruption "defocus_blur"
      param-space {"rate" [0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]})

;; Base architecture shared by the train and test variants: two
;; conv/relu/avg-pool stages followed by three dense layers.
(setv net [(stax.Conv :out-chan 6 :filter-shape (, 5 5) :padding "SAME")
           stax.Relu
           (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
           (stax.Conv :out-chan 16 :filter-shape (, 5 5) :padding "SAME")
           stax.Relu
           (stax.AvgPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID")
           stax.Flatten
           (stax.Dense 120)
           stax.Relu
           (stax.Dense 84)
           stax.Relu
           (stax.Dense num-outputs)])

;; Independent copies so that layers can be inserted without touching `net`.
(setv train-net (.copy net)
      test-net (.copy net))

(defn create-opt-step [params]
  "Build the optimisation step for one grid-search candidate.

  params is a mapping with a rate entry, used as the rate of both dropout
  layers. A fresh copy of net gets batch-norm layers inserted at positions 1
  and 5 and dropout layers at positions 11 and 13, is turned into a training
  network with fk-train-net (reading the module-level input-shape and conv at
  call time), and the resulting loss and optimizer functions are handed to
  nn-utils.create-opt-step."
  ;; NOTE(review): in jax.experimental.stax the Dropout rate appears to be the
  ;; probability of keeping a unit rather than dropping it; confirm.
  (setv rate (get params "rate")
        layers (.copy net))
  (for [[position layer] [[1 (stax.BatchNorm)]
                          [5 (stax.BatchNorm)]
                          [11 (stax.Dropout :rate rate)]
                          [13 (stax.Dropout :rate rate)]]]
    (.insert layers position layer))
  (setv [_ loss-fn update-fn get-params _]
        (fk-train-net layers input-shape :conv conv))
  ;; The second argument is a function that always returns 0.0; presumably an
  ;; extra loss term that is disabled here; verify against nn_utilities.
  (nn-utils.create-opt-step loss-fn
                            (constantly 0.0)
                            update-fn
                            get-params))

;; Training variant: batch norm after each convolution (positions 1 and 5)
;; and dropout with rate 0.1 at positions 11 and 13. The insertions run in
;; order, so each position refers to the list as it stands at that point.
(doto train-net
  (.insert 1 (stax.BatchNorm))
  (.insert 5 (stax.BatchNorm))
  (.insert 11 (stax.Dropout :rate 0.1))
  (.insert 13 (stax.Dropout :rate 0.1)))

;; Evaluation variant: same layout, with the dropout layers built in "test"
;; mode.
(doto test-net
  (.insert 1 (stax.BatchNorm))
  (.insert 5 (stax.BatchNorm))
  (.insert 11 (stax.Dropout :rate 0.1 :mode "test"))
  (.insert 13 (stax.Dropout :rate 0.1 :mode "test")))

(defn fk-augment [ims conv np-rng corruption [severity 1] [corruption-prob 1]]
  "Optionally corrupt a batch of images, then shape and rescale it.

  ims is indexed along axis 0, one image per entry. np-rng must provide a
  binomial method (for example a numpy Generator); it decides, independently
  per image with probability corruption-prob, which images are corrupted.
  corruption is the corruption name passed to corrupt, or None to leave the
  images untouched; severity is forwarded to corrupt.

  Corrupted images are averaged over axis 2 of what corrupt returns. The
  batch then gets a trailing axis when conv is true, or is flattened per
  image otherwise, and is divided by 255.0."
  ;; Drawn even when corruption is None, so the generator is advanced the
  ;; same way on every call.
  (setv to-corrupt (.binomial np-rng 1 corruption-prob (get (np.shape ims) 0)))

  (defn corrupt-if-selected [index image]
    (if (= (get to-corrupt index) 1)
        (np.mean (corrupt image
                          :corruption-name corruption
                          :severity severity)
                 :axis 2)
        image))

  (setv batch (if (is-not corruption None)
                  (jnp.array (lfor [index image] (enumerate ims)
                                   (corrupt-if-selected index image)))
                  ims))
  (setv batch (if conv
                  (jnp.expand-dims batch 3)
                  (partial-flatten batch)))
  (/ batch 255.0))

(import [sklearn.model_selection [ShuffleSplit]])

;; itertools supplies the seed counter used below; it is not among the
;; modules imported at the top of this file, so bring it into scope here.
(import itertools)

;; Grid search over param-space: load the data, build the train/test networks
;; and evaluate every candidate on repeated shuffle splits.
(setv [images keypoints input-shape] (fk-data :conv conv)
      [train-apply calc-loss-train opt-update opt-get new-opt-state]
      (fk-train-net train-net input-shape :conv conv)
      [test-apply calc-loss-test] (fk-test-net test-net)

      ;; Each setting keeps a value that is already bound when this form runs
      ;; and falls back to the given default otherwise.
      batch-size (default batch-size 128)
      epochs (default epochs 50)
      noise-scale (default noise-scale 15)
      train-size (default train-size 0.3)
      corruption (default corruption None)

      ;; Shared across create-iterator calls, so every split ever produced
      ;; gets its own seed.
      np-rng-count (itertools.count 0)

      splitter (ShuffleSplit :n-splits 5
                             :train-size train-size
                             :random-state 62)
      ;; Yields [train-images test-images train-keypoints test-keypoints] for
      ;; each split. Only the training images are corrupted and only the
      ;; training keypoints are perturbed.
      create-iterator (fn []
                        (gfor [train test] (.split splitter :X images :y keypoints)
                              :do (setv np-rng (np.random.default-rng :seed (next np-rng-count))
                                        y-train (jnp.take keypoints train :axis 0))
                              [(fk-augment (jnp.take images train :axis 0) conv np-rng corruption) ; train images
                               (fk-augment (jnp.take images test :axis 0) conv np-rng None) ; test images
                               ;; Zero-mean Gaussian label noise with standard
                               ;; deviation noise-scale. The clip is applied to
                               ;; the noisy keypoints so they stay in [0, 96];
                               ;; clipping the noise alone would zero out every
                               ;; negative draw.
                               (jnp.clip (+ y-train
                                            (.normal np-rng
                                                     0.0
                                                     noise-scale
                                                     :size (jnp.shape y-train)))
                                         0 96) ; train keypoints
                               (jnp.take keypoints test :axis 0)])) ; test keypoints
      ;; Test loss of the parameters held in opt-state, with a fixed PRNG key.
      metric (fn [opt-state x y] (calc-loss-test (opt-get opt-state) x y :rng (jax.random.PRNGKey 0)))
      perf (np.array (nn-utils.grid-search-cv param-space epochs create-opt-step new-opt-state
                                              batch-size metric create-iterator))
      ;; Column 0 of perf is read as the parameter value and column 1 as its
      ;; score; the entry with the smallest score wins.
      best-param (get (np.take perf 0 :axis 1) (np.argmin (np.take perf 1 :axis 1))))

(print perf)
(print f"Results of grid search: {best-param}")
