import numpy as np
import tensorflow as tf
#import tensorflow_datasets as tfds
from tensorflow.keras.datasets import fashion_mnist,mnist
from deel.datasets.util_generator_dataset import simple_generator, otp_generator,simple_dataset_generator
import matplotlib
import matplotlib.pyplot as plt
import random


def plot_10_by_10_images(images, filename = None):
    """Plot a 10 by 10 grid of randomly drawn 28x28 images.

    The five left-hand columns are sampled (with replacement) from the first
    half of ``images`` and the five right-hand columns from the second half.
    Every picked element is reshaped to 28x28 and drawn in grayscale with
    both axes hidden.

    Args:
        images: indexable collection supporting ``len()``; each element must
            provide ``reshape(28, 28)``.
        filename: if not None, the figure is saved to this path with a tight
            bounding box and then closed; otherwise the figure is shown with
            ``plt.show()``.
    """
    fig = plt.figure()
    nb = len(images)
    half = nb // 2
    # random.randint includes both bounds, so the first half must stop at
    # half - 1; otherwise index `half` could be drawn on both sides of the
    # grid. Clamp to 0 so a single-image input keeps working.
    last_of_first_half = max(half - 1, 0)
    for x in range(10):
        for y in range(10):
            # Subplot positions are 1-based and row-major: row y, column x.
            ax = fig.add_subplot(10, 10, 10 * y + x + 1)
            if x < 5:
                ind = random.randint(0, last_of_first_half)
            else:
                ind = random.randint(half, nb - 1)
            plt.imshow(images[ind].reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    if filename is not None:
        plt.savefig(
            filename,
            bbox_inches='tight'
        )
        # Release the figure once it has been written to disk.
        plt.close(fig)
    else:
        plt.show()


def fashion_mnist_dataset_oneclass(selected_classes, mnist_data = False):
    """Build a one-class split of Fashion-MNIST (or MNIST).

    The training images are restricted to those whose label equals
    ``selected_classes``; the test images are all kept and receive a binary
    label telling whether they belong to that class. Images are reshaped to
    ``(-1, 28, 28, 1)``, cast to float32 and divided by 255.

    Args:
        selected_classes: label value compared (with ``==``) against the
            dataset labels to pick the retained class.
        mnist_data: load MNIST when True, Fashion-MNIST otherwise.

    Returns:
        dict with the keys
        ``'X_train'`` (training images of the selected class),
        ``'X_test'`` (all test images),
        ``'Y_test'`` (1 where the test label equals ``selected_classes``,
        0 elsewhere) and
        ``'Y_test_true'`` (the original test labels).
    """
    loader = mnist if mnist_data else fashion_mnist
    (X_all, y_all), (X_test, y_test) = loader.load_data()
    X_all = X_all.reshape((-1, 28, 28, 1))
    X_test = X_test.reshape((-1, 28, 28, 1))

    # Vectorised label comparison instead of a Python-level loop over every
    # label; yields the same boolean mask.
    train_mask = y_all == selected_classes
    test_mask = y_test == selected_classes

    # Only the selected class is kept for training.
    X_all = X_all[train_mask]

    # Binary test target: 1 for the selected class, 0 for everything else.
    y_b_test = np.zeros(y_test.shape)
    y_b_test[test_mask] = 1

    X_all = X_all.astype('float32')
    X_test = X_test.astype('float32')
    X_all /= 255
    X_test /= 255
    dtset = {'X_train' : X_all,
            'X_test' : X_test,
            'Y_test' : y_b_test,
            'Y_test_true' :y_test}
    return dtset




def fashion_mnist_dataset(batch_size,to_categorical, selected_classes=None,gtValues=None, mnist_data = False):
    """Load Fashion-MNIST (or MNIST), keep some classes and wrap the splits.

    Samples whose label is not in ``selected_classes`` are dropped, the
    remaining labels are remapped, and the images are reshaped to
    ``(-1, 28, 28, 1)``, cast to float32 and divided by 255.

    Args:
        batch_size: forwarded to ``simple_dataset_generator`` and echoed in
            the returned dict.
        to_categorical: when truthy, labels are one-hot encoded with
            ``tf.keras.utils.to_categorical`` using ``len(selected_classes)``
            classes.
        selected_classes: sequence of original labels to keep; defaults to
            ``range(10)``.
        gtValues: optional sequence giving, for each entry of
            ``selected_classes`` (same order), the target value to use.
            When None, each class is mapped to its position in
            ``selected_classes``.
        mnist_data: load MNIST when True, Fashion-MNIST otherwise.

    Returns:
        dict with the keys ``'train'``, ``'valid'``, ``'test'`` (objects built
        by ``simple_dataset_generator``), ``'trainSize'``, ``'validSize'``,
        ``'testSize'`` (sample counts), ``'test_XY'`` (tuple
        ``(X_test, Y_test)``) and ``'batch_size'``.

    Raises:
        ValueError: if ``gtValues`` is given and its length differs from the
            number of selected classes.
    """
    if selected_classes is None:
        selected_classes = range(10)
    nb_classes = len(selected_classes)
    if nb_classes == 2:
        nb_classes = 1 ## binary
    if gtValues is None:
        index_selected_class = {label: position for position, label in enumerate(selected_classes)}
    else:
        # Compare against the real number of selected classes, not the
        # collapsed binary count: one target value is needed per class.
        if len(gtValues) != len(selected_classes):
            raise ValueError(
                "gtValues must provide one value per selected class: got "
                + str(len(gtValues)) + " values for "
                + str(len(selected_classes)) + " classes"
            )
        index_selected_class = dict(zip(selected_classes, gtValues))

    print(index_selected_class)

    # the data, shuffled and split between train and test sets
    loader = mnist if mnist_data else fashion_mnist
    (X_all, y_all), (X_test, y_test) = loader.load_data()
    X_all = X_all.reshape((-1, 28, 28, 1))
    X_test = X_test.reshape((-1, 28, 28, 1))

    print("Select only "+str(nb_classes)+" classes:"+str(selected_classes))

    # Vectorised membership test instead of a Python-level `in` per label.
    wanted_labels = list(selected_classes)

    select_all = np.isin(y_all, wanted_labels)
    X_all = X_all[select_all]
    y_all = np.asarray([index_selected_class[y] for y in y_all[select_all]])
    # The whole filtered training set is used for training, so the
    # validation slices below are always empty.
    max_train = int(len(X_all))

    select_test = np.isin(y_test, wanted_labels)
    X_test = X_test[select_test]
    y_test = np.asarray([index_selected_class[y] for y in y_test[select_test]])

    print(y_test.shape)

    # Scale pixel values by 1/255 as float32.
    X_all = X_all.astype('float32')
    X_test = X_test.astype('float32')
    X_all /= 255
    X_test /= 255

    X_train = X_all[:max_train]
    X_valid = X_all[max_train:]
    Y_train = y_all[:max_train]
    Y_valid = y_all[max_train:]
    Y_test = y_test
    if to_categorical:
        Y_test = tf.keras.utils.to_categorical(Y_test,len(selected_classes))
        Y_train = tf.keras.utils.to_categorical(Y_train,len(selected_classes))
        Y_valid = tf.keras.utils.to_categorical(Y_valid,len(selected_classes))


    print(X_train.shape[0], 'train samples')
    print(X_valid.shape[0], 'valid samples')
    print(X_test.shape[0], 'test samples')

    dtset = {'train' : simple_dataset_generator(batch_size,X_train,Y_train) , 'trainSize': X_train.shape[0],
            'valid' : simple_dataset_generator(batch_size,X_valid,Y_valid), 'validSize': X_valid.shape[0],
            'test' : simple_dataset_generator(batch_size,X_test,Y_test,shuffle=False), 'testSize': X_test.shape[0],
            'test_XY' :(X_test,Y_test),
            'batch_size': batch_size }
    return dtset