import os
from datetime import datetime, timedelta

import tensorflow as tf
import numpy as np

import settings
from data.tf_datasets import OmniglotDataset, MiniImagenetDataset, AirCraftDataset
from meta_learning import ModelAgnosticMetaLearning
from models import SimpleModel, MAMLMiniImagenetModel, MAMLMiniImagenetGrayScaleModel
from vae import CVAEOmniglot, CVAEMiniImageNet


class Configuration(object):
    """Hold the settings of one meta-learning experiment and run its phases.

    The instance stores the model class, the datasets and the hyper-parameters, resolves the log and
    checkpoint locations, and dispatches on `task_type` in `meta_learning` and `evaluate`.
    """

    def __init__(
            self,
            model,
            meta_learning_dataset,
            meta_learning_val_dataset,
            meta_learning_iterations,
            meta_learning_save_after_iterations,
            meta_learning_log_after_iterations,
            meta_learning_summary_after_iterations,
            training_dataset,
            num_updates_meta_learning,
            num_updates_training,
            meta_batch_size,
            meta_learning_rate,
            learning_rate_meta_learning,
            learning_rate_training,
            n,
            k_meta_learning,
            k_training,
            meta_learning_log_dir,
            evaluation_log_dir,
            meta_learning_saving_path,
            training_saving_path=None,
            task_type='MAML_Supervised',
            num_tasks_evaluation=1000,
            summary_evaluation_task_number=100,
    ):
        """Store the experiment settings and resolve the output locations.

        Args:
            model: Model class (not an instance). It is handed to ModelAgnosticMetaLearning, called as
                `model(output_dimension=..., update_lr=...)` during evaluation, and asked for its name
                through `model.get_name(None)` when a location is derived from the configuration.
            meta_learning_dataset: Provider of the meta-training tasks. Depending on `task_type` its
                `get_supervised_meta_learning_tasks`, `get_umtra_tasks` or `get_multi_task_tasks`
                method is called.
            meta_learning_val_dataset: Provider of the validation tasks
                (`get_supervised_meta_learning_tasks` is called on it).
            meta_learning_iterations: Forwarded to ModelAgnosticMetaLearning.
            meta_learning_save_after_iterations: Forwarded to ModelAgnosticMetaLearning.
            meta_learning_log_after_iterations: Forwarded to ModelAgnosticMetaLearning.
            meta_learning_summary_after_iterations: Forwarded to ModelAgnosticMetaLearning.
            training_dataset: Provider of the evaluation tasks
                (`get_supervised_meta_learning_tasks` is called on it with `meta_batch_size=1`).
            num_updates_meta_learning: Forwarded to ModelAgnosticMetaLearning as `num_updates`.
            num_updates_training: Number of update steps run on every evaluation task.
            meta_batch_size: Meta batch size used when requesting meta-training and validation tasks.
            meta_learning_rate: Forwarded to ModelAgnosticMetaLearning as `meta_lr`.
            learning_rate_meta_learning: Forwarded to ModelAgnosticMetaLearning as `update_lr`.
            learning_rate_training: `update_lr` of the model built for evaluation.
            n: Passed as `n` to the task getters and used as the model output dimension.
            k_meta_learning: Passed as `k` when requesting supervised meta-learning/validation tasks.
            k_training: Passed as `k` when requesting evaluation tasks.
            meta_learning_log_dir: Summary directory for meta-learning, or the string 'config' to use
                `<get_log_dir()>/meta-learning`.
            evaluation_log_dir: Summary directory for evaluation, or 'config' to use
                `<get_log_dir()>/adaptation`.
            meta_learning_saving_path: Checkpoint location, or 'config' to use
                `<get_log_dir()>/saved_models`.
            training_saving_path: Optional location for adapted models, or 'config' to use
                `<get_log_dir()>/saved_adaptation_models`. Stored only; nothing in this class reads it.
            task_type: One of 'MAML_Supervised', 'UMTRA' or 'TEST'; any other value makes
                `meta_learning` and `evaluate` do nothing.
            num_tasks_evaluation: Number of tasks sampled by `evaluate_supervised_maml`.
            summary_evaluation_task_number: Summaries and console output are produced for every
                task whose index is a multiple of this value; a falsy value disables them.
        """
        self.meta_learning_iterations = meta_learning_iterations
        self.meta_learning_save_after_iterations = meta_learning_save_after_iterations
        self.meta_learning_log_after_iterations = meta_learning_log_after_iterations
        self.meta_learning_summary_after_iterations = meta_learning_summary_after_iterations
        self.model = model
        self.meta_learning_dataset = meta_learning_dataset
        self.meta_learning_val_dataset = meta_learning_val_dataset
        self.training_dataset = training_dataset
        self.num_updates_meta_learning = num_updates_meta_learning
        self.num_updates_training = num_updates_training
        self.meta_batch_size = meta_batch_size
        self.meta_learning_rate = meta_learning_rate
        self.learning_rate_meta_learning = learning_rate_meta_learning
        self.learning_rate_training = learning_rate_training
        self.n = n
        self.k_meta_learning = k_meta_learning
        self.k_training = k_training
        self.task_type = task_type
        self.num_tasks_evaluation = num_tasks_evaluation
        self.summary_evaluation_task_number = summary_evaluation_task_number

        # The locations are resolved last because get_log_dir() reads the attributes assigned above.
        self.meta_learning_log_dir = self._resolve_location(meta_learning_log_dir, 'meta-learning')
        self.evaluation_log_dir = self._resolve_location(evaluation_log_dir, 'adaptation')
        self.meta_learning_saving_path = self._resolve_location(meta_learning_saving_path, 'saved_models')
        self.training_saving_path = self._resolve_location(training_saving_path, 'saved_adaptation_models')

    def _resolve_location(self, location, default_subdir):
        """Return `location`, or `<get_log_dir()>/<default_subdir>` when it is the string 'config'."""
        if location == 'config':
            return os.path.join(self.get_log_dir(), default_subdir)
        # Any other value, including None, is kept exactly as given.
        return location

    def get_log_dir(self):
        """Return the experiment directory under `settings.DEFAULT_LOG_DIR`.

        The directory name encodes the task type, the model name, the meta-learning dataset and the
        meta-learning hyper-parameters.
        """
        #  The model is not instantiated yet and does not have a name. We call the method with None as self in order to
        #  return the name of the model.
        name_parts = (
            'task-' + str(self.task_type),
            'model-' + self.model.get_name(None),
            'mldata-' + str(self.meta_learning_dataset),
            'mlnupdates-' + str(self.num_updates_meta_learning),
            'mbs-' + str(self.meta_batch_size),
            'mlr-' + str(self.meta_learning_rate),
            'mllr-' + str(self.learning_rate_meta_learning),
            'n-' + str(self.n),
            'mlk-' + str(self.k_meta_learning),
        )
        # The trailing underscore is kept on purpose so previously written directories are still found.
        log_dir = '_'.join(name_parts) + '_'
        return os.path.join(settings.DEFAULT_LOG_DIR, log_dir)

    def _get_validation_tasks(self):
        """Request the supervised validation tasks shared by every meta-learning variant."""
        val_tr_task, val_val_task, val_tr_labels, val_val_labels = \
            self.meta_learning_val_dataset.get_supervised_meta_learning_tasks(
                meta_batch_size=self.meta_batch_size,
                n=self.n,
                k=self.k_meta_learning,
            )
        return val_tr_task, val_val_task, val_tr_labels, val_val_labels

    def _run_meta_learning(self, training_tasks, validation_tasks, **extra_maml_kwargs):
        """Build ModelAgnosticMetaLearning for the given tasks and run `meta_learn`.

        Args:
            training_tasks: The four values `(train_task, val_task, train_labels, val_labels)`.
            validation_tasks: Tuple passed to ModelAgnosticMetaLearning as `val_dataset`.
            **extra_maml_kwargs: Additional keyword arguments for ModelAgnosticMetaLearning.
        """
        train_task, val_task, train_labels, val_labels = training_tasks
        maml = ModelAgnosticMetaLearning(
            self.model,
            train_task,
            val_task,
            train_labels,
            val_labels,
            val_dataset=validation_tasks,
            output_dimension=self.n,
            meta_learning_iterations=self.meta_learning_iterations,
            meta_learning_log_after_iterations=self.meta_learning_log_after_iterations,
            meta_learning_save_after_iterations=self.meta_learning_save_after_iterations,
            meta_learning_summary_after_iterations=self.meta_learning_summary_after_iterations,
            update_lr=self.learning_rate_meta_learning,
            meta_lr=self.meta_learning_rate,
            meta_batch_size=self.meta_batch_size,
            num_updates=self.num_updates_meta_learning,
            **extra_maml_kwargs
        )

        # The main writer is created after the graph is built so that it records the complete graph.
        file_writer = tf.summary.FileWriter(self.meta_learning_log_dir, tf.get_default_graph())
        validation_file_writer = tf.summary.FileWriter(os.path.join(self.meta_learning_log_dir, 'validation'))
        maml.meta_learn(file_writer, validation_file_writer, saving_path=self.meta_learning_saving_path)

    def supervised_maml(self):
        """Meta-learn on supervised tasks drawn from `meta_learning_dataset`."""
        # Validation tasks are requested before the training tasks, as the graph has always been built.
        validation_tasks = self._get_validation_tasks()
        training_tasks = self.meta_learning_dataset.get_supervised_meta_learning_tasks(
            meta_batch_size=self.meta_batch_size,
            n=self.n,
            k=self.k_meta_learning
        )
        self._run_meta_learning(training_tasks, validation_tasks)

    def evaluate_supervised_maml(self):
        """Adapt the saved model to tasks from `training_dataset` and print the average metrics.

        For each of `num_tasks_evaluation` tasks the variables are re-initialised, the latest model
        stored under `meta_learning_saving_path` is loaded, and `num_updates_training` update steps
        are run. Before every step the loss and accuracy on the task's second split are recorded.
        The per-step averages over all tasks are printed at the end.
        """
        model_address = os.path.join(self.meta_learning_saving_path)
        training_model = self.model(output_dimension=self.n, update_lr=self.learning_rate_training)
        train_task, val_task, train_labels, val_labels = self.training_dataset.get_supervised_meta_learning_tasks(
            meta_batch_size=1,
            n=self.n,
            k=self.k_training
        )

        # remove the 1 dimension of meta batch size
        # NOTE(review): tf.squeeze without `axis` drops every size-1 axis, not only the meta-batch
        # one - confirm the label tensors never contain another singleton dimension.
        train_labels = tf.squeeze(train_labels)
        val_labels = tf.squeeze(val_labels)

        tf.summary.image('task', tf.reshape(train_task, training_model.get_input_shape()), max_outputs=12)

        with tf.variable_scope('update/model'):
            training_model.forward(train_task)
            training_model.define_update_op(train_labels, with_batch_norm_dependency=True)
            training_model.define_accuracy(train_labels)

        for item in tf.global_variables():
            tf.summary.histogram(item.name, item)

        merged_summary = tf.summary.merge_all()
        # Built once: creating the initializer inside the task loop adds a new op to the graph per task.
        init_op = tf.global_variables_initializer()

        task_losses = np.zeros(shape=(self.num_tasks_evaluation, self.num_updates_training))
        task_accuracies = np.zeros(shape=(self.num_tasks_evaluation, self.num_updates_training))

        with tf.Session() as sess:
            for task in range(self.num_tasks_evaluation):
                # Every task starts again from the stored meta-learned weights.
                sess.run(init_op)
                training_model.load(sess, model_address, load_last=True)

                train_task_np, val_task_np, train_labels_np, val_labels_np = sess.run(
                    (train_task, val_task, train_labels, val_labels)
                )

                log_this_task = (
                    bool(self.summary_evaluation_task_number)
                    and task % self.summary_evaluation_task_number == 0
                )
                train_writer = None
                test_writer = None
                try:
                    if log_this_task:
                        print('task: {}'.format(task))

                        # The two directory patterns differ ('task-num_{}' vs 'task-num{}'); both are
                        # kept unchanged so that existing log directories keep their names.
                        train_writer = tf.summary.FileWriter(
                            os.path.join(self.evaluation_log_dir, 'training', 'task-num_{}'.format(task)),
                        )
                        test_writer = tf.summary.FileWriter(
                            os.path.join(self.evaluation_log_dir, 'test', 'task-num{}'.format(task))
                        )

                    for it in range(self.num_updates_training):
                        # Measure on the second split first; it is fed through the training placeholders.
                        output, loss, acc, summ = sess.run(
                            (training_model.out, training_model.loss, training_model.accuracy, merged_summary),
                            feed_dict={
                                training_model.is_training: False,
                                train_task: val_task_np,
                                train_labels: val_labels_np,
                            }
                        )
                        if log_this_task:
                            test_writer.add_summary(summ, global_step=it)

                        # Then take one update step on the first split of the task.
                        _, summ = sess.run((training_model.op, merged_summary), feed_dict={
                            training_model.is_training: True,
                            train_task: train_task_np,
                            train_labels: train_labels_np
                        })
                        if log_this_task:
                            train_writer.add_summary(summ, global_step=it)
                            print(np.argmax(output, 1))
                            print(loss)
                            print(acc)

                        task_losses[task, it] = loss
                        task_accuracies[task, it] = acc
                finally:
                    # Close the per-task writers so their event files are flushed and released.
                    for writer in (train_writer, test_writer):
                        if writer is not None:
                            writer.close()

            print('done')

            print('average loss:')
            print(np.mean(task_losses, axis=0))

            print('average accuracy:')
            print(np.mean(task_accuracies, axis=0))

    def umtra(self, _augment_function, model=None):
        """Meta-learn on tasks produced by `meta_learning_dataset.get_umtra_tasks`.

        Args:
            _augment_function: Passed to `get_umtra_tasks` as `augment_function`.
            model: Passed to ModelAgnosticMetaLearning as `umtra_model`.
        """
        validation_tasks = self._get_validation_tasks()
        training_tasks = self.meta_learning_dataset.get_umtra_tasks(
            meta_batch_size=self.meta_batch_size,
            n=self.n,
            augment_function=_augment_function,
        )
        self._run_meta_learning(
            training_tasks,
            validation_tasks,
            remember_train_task=False,
            umtra_model=model,
        )

    def multi_task(self):
        """Meta-learn on tasks produced by `meta_learning_dataset.get_multi_task_tasks`."""
        validation_tasks = self._get_validation_tasks()
        training_tasks = self.meta_learning_dataset.get_multi_task_tasks(
            meta_batch_size=self.meta_batch_size,
            n=self.n,
        )
        self._run_meta_learning(training_tasks, validation_tasks)

    def meta_learning(self, *args, **kwargs):
        """Run the meta-learning phase selected by `task_type`.

        For 'UMTRA' the keyword argument `augmentation_function` is required (KeyError otherwise)
        and `model` is optional. Positional arguments are accepted but unused. An unrecognised
        `task_type` does nothing.
        """
        if self.task_type == 'MAML_Supervised':
            self.supervised_maml()
        elif self.task_type == 'UMTRA':
            self.umtra(kwargs['augmentation_function'], kwargs.get('model', None))
        elif self.task_type == 'TEST':
            self.multi_task()

    def evaluate(self):
        """Reset the default graph and evaluate the saved model for the known task types."""
        tf.reset_default_graph()
        if self.task_type in ('MAML_Supervised', 'UMTRA', 'TEST'):
            self.evaluate_supervised_maml()

    def execute(self, *args, **kwargs):
        """Run `meta_learning` followed by `evaluate`.

        Arguments are forwarded to `meta_learning`, so that 'UMTRA' runs can receive their
        `augmentation_function` (and optional `model`) through this entry point as well.
        """
        self.meta_learning(*args, **kwargs)
        self.evaluate()


if __name__ == '__main__':
    # tensorflow_hub is imported only when this file is run as a script.
    import tensorflow_hub as hub
    # Image-augmentation module called by _miniimagenet_augment_function below. It is instantiated
    # on every script run, regardless of which augmentation function is later passed to meta_learning.
    augmentation_module = hub.Module('https://tfhub.dev/google/image_augmentation/nas_cifar/1')
    
    def _miniimagenet_augment_function(image, num_imgs=20):
        """Return a randomly augmented copy of a batch of 84x84 images.

        The batch is passed through the hub augmentation module, randomly flipped left-right,
        cropped to random boxes and resized back to 84x84, rotated by a random angle and finally
        translated by a random offset.

        Args:
            image: Batch of decoded images.
                NOTE(review): presumably shaped (num_imgs, 84, 84, channels) - confirm with the caller.
            num_imgs: Number of images in the batch. The random boxes, angles and shifts are created
                with this leading dimension, so it has to agree with the batch size of `image`.

        Returns:
            The augmented image batch.
        """
        image = augmentation_module(
            {
                'images': image,
                'image_size': (84, 84),
                'augmentation': True
            },
            signature='from_decoded_images'
        )

        image = tf.image.random_flip_left_right(image)

        # One random box per image; box i is taken from image i.
        image = tf.image.crop_and_resize(
            image,
            boxes=tf.random.uniform([num_imgs, 4], minval=0, maxval=1),
            box_ind=list(range(num_imgs)),
            crop_size=(84, 84),
            method='bilinear',
            extrapolation_value=0
        )

        # Disabled experiment, kept for reference:
        # image = tf.contrib.image.dense_image_warp(
        #    image,
        #    flow=tf.random.normal([num_imgs, 84, 84, 2], mean=0, stddev=5)
        # )

        # One random angle per image, drawn from [-1.5, 1.5).
        image = tf.contrib.image.rotate(
            image,
            tf.random.uniform(shape=(num_imgs,), minval=-1.5, maxval=1.5),
            interpolation='NEAREST'
        )

        # Disabled experiment, kept for reference:
        # image = (image * 2.0) - 1

        # Identity projective transform per image; the mask selects the two translation entries,
        # so only they receive a random offset drawn from [-18, 18).
        base_ = tf.convert_to_tensor(np.tile([1, 0, 0, 0, 1, 0, 0, 0], [num_imgs, 1]), dtype=tf.float32)
        mask_ = tf.convert_to_tensor(np.tile([0, 0, 1, 0, 0, 1, 0, 0], [num_imgs, 1]), dtype=tf.float32)

        random_shift_ = tf.random.uniform([num_imgs, 8], minval=-18., maxval=18., dtype=tf.float32)

        transforms_ = base_ + random_shift_ * mask_
        augmented_data = tf.contrib.image.transform(images=image, transforms=transforms_)
        return augmented_data

    # def _omniglot_augment_function(image):
    #     num_imgs = 5
    #     random_map = tf.random_uniform(shape=tf.shape(image), minval=0, maxval=2, dtype=tf.int32)
    #     random_map = tf.cast(random_map, tf.float32)
    #     image = tf.minimum(image, random_map)
    #
    #     base_ = tf.convert_to_tensor(np.tile([1, 0, 0, 0, 1, 0, 0, 0], [num_imgs, 1]), dtype=tf.float32)
    #     mask_ = tf.convert_to_tensor(np.tile([0, 0, 1, 0, 0, 1, 0, 0], [num_imgs, 1]), dtype=tf.float32)
    #     negativ_random_shift_ = tf.random_uniform([num_imgs, 8], minval=-6., maxval=-3., dtype=tf.float32)
    #     positive_random_shift_ = tf.random_uniform([num_imgs, 8], minval=3., maxval=6., dtype=tf.float32)
    #
    #     random_shift_ = tf.random_uniform([num_imgs, 8], minval=-3., maxval=3., dtype=tf.float32)
    #     transforms_ = base_ + random_shift_ * mask_
    #     augmented_data = tf.contrib.image.transform(images=image, transforms=transforms_)
    #     return augmented_data


    def _omniglot_augment_function(image, num_imgs=20):
        """Return a randomly augmented copy of a batch of images.

        Each pixel is first combined, through an element-wise minimum, with a random 0/1 map, which
        zeroes roughly half of the pixels when the pixel values lie in [0, 1] (assumption - confirm
        against the dataset). The result is then translated by a random offset per image.

        Args:
            image: Batch of images.
                NOTE(review): presumably float32 with `num_imgs` as leading dimension - confirm.
            num_imgs: Number of images in the batch; the per-image transforms are created with this
                leading dimension, so it has to agree with the batch size of `image`.

        Returns:
            The augmented image batch.
        """
        # Integer sampling with maxval=2 yields only 0 or 1.
        random_map = tf.random.uniform(shape=tf.shape(image), minval=0, maxval=2, dtype=tf.int32)
        random_map = tf.cast(random_map, tf.float32)
        image = tf.minimum(image, random_map)

        # Identity projective transform per image; the mask selects the two translation entries,
        # so only they receive a random offset drawn from [-6, 6).
        base_ = tf.convert_to_tensor(np.tile([1, 0, 0, 0, 1, 0, 0, 0], [num_imgs, 1]), dtype=tf.float32)
        mask_ = tf.convert_to_tensor(np.tile([0, 0, 1, 0, 0, 1, 0, 0], [num_imgs, 1]), dtype=tf.float32)

        # Disabled experiment, kept for reference: choose between a purely negative and a purely
        # positive shift range with tf.cond.
        # negativ_random_shift_ = tf.random_uniform([num_imgs, 8], minval=-9., maxval=0., dtype=tf.float32)
        # positive_random_shift_ = tf.random_uniform([num_imgs, 8], minval=0., maxval=9., dtype=tf.float32)
        # random_shift_ = tf.cond(
        #    tf.random.uniform(minval=0, maxval=1.0, shape=()) > 0.5,
        #    lambda: positive_random_shift_,
        #    lambda: negativ_random_shift_
        # )
        random_shift_ = tf.random.uniform([num_imgs, 8], minval=-6., maxval=6., dtype=tf.float32)

        transforms_ = base_ + random_shift_ * mask_
        augmented_data = tf.contrib.image.transform(images=image, transforms=transforms_)
        return augmented_data


    #latent_dim = 512
    # model = CVAEOmniglot(latent_dim)

    # model = CVAEMiniImageNet(latent_dim)

    #def miniimagenet_augment_vae(image):
    #    batch_size = 5
    #    vector = model.encode(image)
    #    mean, logvar = vector[0], vector[1]
#
    #    eps = tf.random.normal(shape=(batch_size, latent_dim))
    #    eps = eps * tf.exp(logvar * .5) + mean
#
    #    # eps += tf.random.normal(shape=(batch_size, latent_dim), mean=0, stddev=0.8)
    #    result = model.sample(eps=eps)
#
    #    base_ = tf.convert_to_tensor(np.tile([1, 0, 0, 0, 1, 0, 0, 0], [5, 1]), dtype=tf.float32)
    #    mask_ = tf.convert_to_tensor(np.tile([0, 0, 1, 0, 0, 1, 0, 0], [5, 1]), dtype=tf.float32)
    #    random_shift_ = tf.random_uniform([5, 8], minval=-6., maxval=6., dtype=tf.float32)
    #    transforms_ = base_ + random_shift_ * mask_
#
    #    augmented_data = tf.contrib.image.transform(images=result, transforms=transforms_)
    #    augmented_data = tf.stop_gradient(augmented_data)
    #    return augmented_data


    def _parse_function(example_address):
        """Load the JPEG file at `example_address` as a float32 image resized to 84x84.

        The pixel values are divided by 255 before being returned.
        """
        # Older spelling of the read call: tf.read_file(example_address)
        encoded = tf.io.read_file(example_address)
        pixels = tf.cast(tf.image.decode_jpeg(encoded), tf.float32)

        # Alternatives that were tried: tf.image.rgb_to_grayscale(image) before resizing,
        # and a (28, 28) target size.
        resized = tf.image.resize(pixels, (84, 84))
        return resized / 255.

    #def omniglot_augment_vae(image):
    #    batch_size = 5
    #    vector = model.encode(image)
    #    # mean, logvar = sess.run(vector)
    #    mean, logvar = vector[0], vector[1]
#
    #    # eps = tf.ones(shape=(batch_size, latent_dim))
#
    #    eps = tf.random.normal(shape=(batch_size, latent_dim))
    #    eps = eps * tf.exp(logvar * .5) + mean
#
    #    # eps += tf.random.uniform(shape=(batch_size, latent_dim), minval=-1.6, maxval=1.6)
#
    #    # eps += tf.random.uniform(shape=(batch_size, latent_dim), minval=-0.8, maxval=0.8)
#
    #    result = model.sample(eps=eps)
    #    result = tf.round(result)
#
    #    # base_ = tf.convert_to_tensor(np.tile([1, 0, 0, 0, 1, 0, 0, 0], [5, 1]), dtype=tf.float32)
    #    # mask_ = tf.convert_to_tensor(np.tile([0, 0, 1, 0, 0, 1, 0, 0], [5, 1]), dtype=tf.float32)
    #    # random_shift_ = tf.random_uniform([5, 8], minval=-6., maxval=6., dtype=tf.float32)
    #    # transforms_ = base_ + random_shift_ * mask_
    #    # augmented_data = tf.contrib.image.transform(images=result, transforms=transforms_)
#
    #    augmented_data = result
    #    augmented_data = tf.stop_gradient(augmented_data)
    #    return augmented_data


    # Active experiment: UMTRA on Omniglot with SimpleModel (n=20, k=1).
    omniglot_dataset = OmniglotDataset()
    # miniimagenet_dataset = MiniImagenetDataset()

    # The splits are requested in the same order as before: train, validation, test.
    meta_train_split = omniglot_dataset.get_train_dataset()
    meta_validation_split = omniglot_dataset.get_validation_dataset()
    adaptation_split = omniglot_dataset.get_test_dataset()

    config = Configuration(
        # model=MAMLMiniImagenetGrayScaleModel,
        model=SimpleModel,
        task_type='UMTRA',
        # Data.
        meta_learning_dataset=meta_train_split,
        meta_learning_val_dataset=meta_validation_split,
        training_dataset=adaptation_split,
        # Schedule of the meta-learning loop.
        meta_learning_iterations=50001,
        meta_learning_save_after_iterations=2000,
        meta_learning_log_after_iterations=100,
        meta_learning_summary_after_iterations=100,
        # Task sizes and update counts.
        n=20,
        k_meta_learning=1,
        k_training=1,
        meta_batch_size=25,
        num_updates_meta_learning=1,
        num_updates_training=50,
        # Learning rates.
        meta_learning_rate=0.001,
        learning_rate_meta_learning=0.05,
        learning_rate_training=0.05,
        # 'config' derives each location from the configuration itself.
        meta_learning_log_dir='config',
        evaluation_log_dir='config',
        meta_learning_saving_path='config',
        training_saving_path=None,
    )

    # miniimagenet_dataset = MiniImagenetDataset()
    # config = Configuration(
    #     model=MAMLMiniImagenetModel,
    #     meta_learning_dataset=miniimagenet_dataset.get_train_dataset(),
    #     meta_learning_val_dataset=miniimagenet_dataset.get_validation_dataset(),
    #     meta_learning_iterations=2001,
    #     meta_learning_save_after_iterations=1000,
    #     meta_learning_log_after_iterations=5,
    #     meta_learning_summary_after_iterations=1,
    #     training_dataset=miniimagenet_dataset.get_test_dataset(),
    #     num_updates_meta_learning=1,
    #     num_updates_training=15,
    #     meta_batch_size=3,
    #     meta_learning_rate=0.001,
    #     learning_rate_meta_learning=0.1,
    #     learning_rate_training=0.1,
    #     n=5,
    #     k_meta_learning=1,
    #     k_training=1,
    #     meta_learning_log_dir='config',
    #     evaluation_log_dir='config',
    #     meta_learning_saving_path='config',
    #     training_saving_path=None,
    #     task_type='MAML_Supervised',
    # )

    # aircraft_dataset = AirCraftDataset()
    # config = Configuration(
    #     model=MAMLMiniImagenetModel,
    #     meta_learning_dataset=aircraft_dataset.get_train_dataset(),
    #     meta_learning_val_dataset=aircraft_dataset.get_validation_dataset(),
    #     meta_learning_iterations=40001,
    #     meta_learning_save_after_iterations=5000,
    #     meta_learning_log_after_iterations=100,
    #     meta_learning_summary_after_iterations=100,
    #     training_dataset=aircraft_dataset.get_test_dataset(),
    #     num_updates_meta_learning=5,
    #     num_updates_training=15,
    #     meta_batch_size=4,
    #     meta_learning_rate=0.001,
    #     learning_rate_meta_learning=0.05,
    #     learning_rate_training=0.05,
    #     n=5,
    #     k_meta_learning=1,
    #     k_training=1,
    #     meta_learning_log_dir='config',
    #     evaluation_log_dir='config',
    #     meta_learning_saving_path='config',
    #     training_saving_path=None,
    #     task_type='MAML_Supervised',
    # )

    # Only the meta-learning phase is timed; evaluation runs after the second timestamp is taken.
    started_at = datetime.now()

    # Other invocations that were used:
    # config.meta_learning()
    # config.meta_learning(augmentation_function=miniimagenet_augment_vae, model=model)
    config.meta_learning(augmentation_function=_omniglot_augment_function, model=None)

    finished_at = datetime.now()
    config.evaluate()

    # Report: label, start, end and elapsed time, one per line.
    for report_line in ('meta-learning time', started_at, finished_at, finished_at - started_at):
        print(report_line)
