import os
# Hide all CUDA devices; set before TensorFlow is imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from data.tf_datasets import OmniglotDataset, MiniImagenetDataset
from vae import CVAEOmniglot, CVAEMiniImageNet, _parse_function_imagenet, _parse_function_omniglot

# Hyper-parameters shared by the model constructor and the input pipeline.
latent_dim = 10
batch_size = 100

# Active configuration: Omniglot.
# `size` is the 2-D shape each image is reshaped to before plotting.
_parse_function = _parse_function_omniglot
dataset = OmniglotDataset()
model = CVAEOmniglot(latent_dim)
size = (28, 28)

# Alternative configuration: Mini-ImageNet. To switch, comment out the four
# Omniglot lines above and uncomment ALL four lines below (including `size`).
# _parse_function = _parse_function_imagenet
# dataset = MiniImagenetDataset()
# model = CVAEMiniImageNet(latent_dim)
# size = (84, 84)


# Input pipeline: every training instance is shuffled, parsed with the
# configured `_parse_function`, batched, and exposed through a one-shot
# iterator. `image` is the graph tensor yielding the next batch.
train_instances = dataset.get_train_dataset().get_all_instances()
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_instances)
    .shuffle(38400)
    .map(_parse_function, num_parallel_calls=8)
    .batch(batch_size)
    .prefetch(1)
)
iterator = train_dataset.make_one_shot_iterator()
image = iterator.get_next()


# Encode one batch, decode it again, and sweep a single latent dimension to
# visualise its effect on the decoded images.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Debug aid: uncomment to look at the raw input images before encoding.
    # image_np = sess.run(image)
    # for i in range(100):
    #     plt.imshow(image_np[i, ...].reshape((84, 84)), cmap='gray')
    #     plt.show()

    model.load_weights(filepath='vae_models/model')

    # The first two outputs of encode() are used as (mean, logvar).
    vector = model.encode(image)
    # mean, logvar = sess.run(vector)
    mean, logvar = vector[0], vector[1]

    # Decode straight from `mean`. Alternative latent codes tried during
    # experimentation are kept below for reference.
    # eps = tf.ones(shape=(batch_size, latent_dim))
    # eps = tf.random.normal(shape=(batch_size, latent_dim))
    # eps = eps * tf.exp(logvar * .5) + mean
    # eps = tf.random.uniform(shape=(), minval=-10, maxval=10) * eps * tf.exp(logvar * .5) + mean
    eps = mean
    # eps += tf.random.uniform(shape=(batch_size, latent_dim), minval=-3.2, maxval=3.2)
    # eps += tf.random.normal(shape=(batch_size, latent_dim), mean=0, stddev=0.8)

    # Latent traversal: shift one latent dimension by start, start + delta, ...
    # and decode once per step. The step count is tied to the plot grid so
    # every decoded step gets exactly one subplot.
    grid_rows, grid_cols = 4, 5
    traversal_steps = grid_rows * grid_cols
    traversal_dim = 1
    traversal_start = -1.0
    traversal_delta = 0.1

    results = []
    for step in range(traversal_steps):
        # A fresh offset per step: no shared array is mutated, and the value
        # is computed directly instead of accumulated, so it does not drift.
        offset = np.zeros(shape=(1, latent_dim))
        offset[0, traversal_dim] = traversal_start + step * traversal_delta
        results.append(model.sample(eps=eps + offset))

    # Further perturbations of `eps` that were tried before the final decode:
    # eps += tf.random.normal(shape=(batch_size, latent_dim), mean=0, stddev=2)
    # x = np.zeros(shape=mean.shape)
    # x[0] = 1
    # eps += x
    # eps = tf.zeros(shape=mean.shape)

    result = model.sample(eps=eps)
    # result = tf.round(result)

    # Fetch everything in a single run so all arrays come from the same batch.
    result_np, image_np, mean_np, logvar_np, results_np = sess.run(
        (result, image, mean, logvar, results)
    )

    # Iterate over the batch that was actually fetched rather than a
    # hard-coded count, so a short batch cannot index out of range.
    for i in range(len(image_np)):
        print(mean_np[i, ...])
        print(logvar_np[i, ...])

        # Original input image.
        plt.imshow(image_np[i, ...].reshape(size), cmap='gray')
        plt.show()

        # Value ranges of the input and of its decoded counterpart.
        print(np.min(image_np[i, ...]))
        print(np.max(image_np[i, ...]))
        print(np.min(result_np[i, ...]))
        print(np.max(result_np[i, ...]))

        # Image decoded from the unshifted latent code.
        plt.imshow(result_np[i, ...].reshape(size), cmap='gray')
        plt.show()

        # Traversal grid: `axs.flat` walks the axes row by row, matching the
        # order in which the traversal steps were generated.
        fig, axs = plt.subplots(grid_rows, grid_cols)
        for ax, traversal_np in zip(axs.flat, results_np):
            ax.imshow(traversal_np[i, ...].reshape(size), cmap='gray')

        plt.show()


