# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Training Denoisers by Sun et al. """
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# objax
import objax

# custom
from utils.datasets import load_dataset, load_test_batch
from utils.models import load_denoiser, save_network_parameters
from utils.learner import train_denoiser, valid_denoiser


"""
    Global configurations
"""
# Global run settings: the configured seed value, and the starting "best"
# validation loss (set high so that the first measured loss replaces it).
_seed, _best_loss = 215, 1e6


"""
    Dataset specific configurations
"""
_dataset    = 'cifar10'

# Per-dataset hyper-parameters. Exactly one branch must match: an unsupported
# dataset name fails loudly here instead of surfacing later as a NameError on
# one of the variables below.
#   _num_batchs : mini-batch size passed to train_denoiser
#   _schedulelr : epochs at which the learning rate is multiplied by _schedratio
#   _decay_rate : coefficient of the L2 penalty on '.w' variables in the loss

## SVHN
if 'svhn' == _dataset:
    _num_batchs = 256
    _num_epochs = 90
    _optimizer  = 'Adam'
    _learn_rate = 0.001
    _decay_rate = 1e-4
    _schedulelr = [30, 60, 90]
    _schedratio = 0.2
    _noise_level= 0.25

## CIFAR-10
elif 'cifar10' == _dataset:
    _num_batchs = 256
    _num_epochs = 120
    _optimizer  = 'Adam'
    _learn_rate = 0.001
    _decay_rate = 1e-4
    _schedulelr = [40, 80, 120]
    _schedratio = 0.2
    _noise_level= 0.25

else:
    raise ValueError('unsupported dataset: {}'.format(_dataset))


if __name__ == "__main__":

    # numpy is imported locally so the file's top-level import block is left
    # unchanged; it is only needed to seed the host-side RNG below.
    import numpy as np

    # Apply the configured seed so runs are reproducible (previously `_seed`
    # was declared but never used). The numpy seed presumably covers host-side
    # shuffling / noise sampling in utils.learner -- TODO confirm.
    np.random.seed(_seed)
    objax.random.DEFAULT_GENERATOR.seed(_seed)

    # load the dataset
    (X_train, Y_train), (X_test, Y_test) = load_dataset(_dataset)
    print (' : load the dataset - {}'.format(_dataset))

    # load the network
    model = load_denoiser(_dataset)
    print (' : use the network - {}'.format(type(model).__name__))
    print (model.vars())

    # Set the model savefile
    storedir = os.path.join('models', _dataset, type(model).__name__)
    os.makedirs(storedir, exist_ok=True)
    storefile = os.path.join(storedir, 'best_model_denoiser.npz')

    def reconstruction_mse(xc, xa, training):
        """Return the mean MSE between clean `xc` and the model's output on `xa`.

        `training` is forwarded to the model so the same computation can run
        in training mode (for the objective) or inference mode (for validation).
        """
        x_recon = model(xa, training=training)
        return objax.functional.loss.mean_squared_error(xc, x_recon).mean()

    # loss function: reconstruction MSE (training mode) plus an L2 penalty on
    # every variable whose name ends with '.w', scaled by _decay_rate.
    def loss(xc, xa):
        loss_mse = reconstruction_mse(xc, xa, training=True)
        loss_l2  = 0.5 * sum((v.value ** 2).sum() for k, v in model.vars().items() if k.endswith('.w'))
        return loss_mse + _decay_rate * loss_l2

    # validation function: pure reconstruction error in inference mode, so
    # validating does not run training-mode layers on the test data and the
    # reported value excludes the weight-decay regularizer.
    reconst_loss = objax.Jit(
        lambda xc, xa: reconstruction_mse(xc, xa, training=False), model.vars())


    # load the optimizers
    gv  = objax.GradValues(loss, model.vars())
    if _optimizer == 'Adam':
        opt = objax.optimizer.Adam(model.vars())
    elif _optimizer == 'SGD':
        opt = objax.optimizer.SGD(model.vars())
    else:
        # without this, `opt` would be unbound and fail with a NameError below
        raise ValueError('unsupported optimizer: {}'.format(_optimizer))
    print (' : Use the optimizer - {} [noise {:.4f}]'.format(type(opt).__name__, _noise_level))


    # denoiser train: one gradient step; returns the loss value(s) from gv
    def train_op(xc, xa, lr):
        g, v = gv(xc, xa)
        opt(lr=lr, grads=g)
        return v

    # gv.vars() contains the model variables.
    train_op = objax.Jit(train_op, gv.vars() + opt.vars())

    # local data-loader
    learning_rate = _learn_rate

    # the validation batch setting depends only on the dataset, so query it
    # once instead of on every epoch
    valid_batch = load_test_batch(_dataset)


    # training...
    for epoch in range(_num_epochs):
        # : update the learning rate (step decay at the scheduled epochs)
        if epoch in _schedulelr:
            learning_rate = learning_rate * _schedratio
            print (' : Update the learning rate to [{:.6f}]'.format(learning_rate))

        # : train for one epoch
        train_loss = train_denoiser(epoch, X_train, _num_batchs, train_op, learning_rate, noise=_noise_level)

        # : validation
        valid_loss = valid_denoiser(epoch, X_test, valid_batch, reconst_loss, noise=_noise_level)
        print(' : Loss [train {:.6f} / valid {:.6f}] w. [lr: {:.6f}]'.format(train_loss, valid_loss, learning_rate))

        # : store the checkpoint whenever validation improves
        if valid_loss < _best_loss:
            _best_loss = valid_loss
            save_network_parameters(model, storefile)
            print (' : Store the model, to [{}]'.format(storefile))

    print (' : Done')
    # Fin.
