'''Script for packaging the COOS-7 images as HDF5 files (written with h5py), in the format loaded by DeepLoc.'''

import os
import h5py
import numpy as np
from PIL import Image
from sklearn.preprocessing import LabelEncoder, LabelBinarizer
from sklearn.model_selection import StratifiedShuffleSplit

# Base directory where the raw tif files are stored, with each dataset as a subdirectory
basedir = "./COOS7/"

# Encode classes as integer labels using the training dataset. Each class is a
# subdirectory of the dataset directory, so only directories are kept: stray
# files (e.g. ".DS_Store", "Thumbs.db") would otherwise be fitted as an extra
# class and shift every integer label. Sorting makes the listing deterministic.
traindir = basedir + "train/"
classes = sorted(
    entry for entry in os.listdir(traindir)
    if os.path.isdir(os.path.join(traindir, entry))
)
le = LabelEncoder().fit(classes)
# Two-column table pairing each class name with its integer label, stored as
# variable-length strings so it can be written with h5py.
le_name_mapping = (np.stack((le.classes_, le.transform(le.classes_)), axis=-1)).astype(dtype=h5py.special_dtype(vlen=str))
# Binarizer fitted on the label values 0..6 (COOS-7 has seven classes); it is
# used later to turn integer labels into indicator rows.
enc = LabelBinarizer().fit(range(7))
# Axis along which the protein and nucleus images are stacked into one sample
# (axis 0 puts the channel dimension first).
axis = 0

def _load_split(split_dir):
    """Load every protein/nucleus image pair stored under ``split_dir``.

    ``split_dir`` is expected to contain one subdirectory per class. Inside
    each class directory, every file ending in ``_protein.tif`` is paired with
    the file of the same stem ending in ``_nucleus.tif``; the two images are
    stacked along the module-level ``axis`` to form one sample.

    Entries of ``split_dir`` that are not directories are skipped, and
    listings are sorted so the resulting order does not depend on the
    platform. The per-class image count is printed as each class is read.

    Args:
        split_dir: Path of the dataset directory, ending with a path separator.

    Returns:
        A tuple ``(images, labels)`` of numpy arrays, with one entry per
        image pair; labels are the integers assigned by the module-level
        ``le`` encoder.

    Raises:
        ValueError: If a class directory name was not seen when ``le`` was
            fitted (raised by ``le.transform``).
    """
    protein_suffix = "_protein.tif"
    nucleus_suffix = "_nucleus.tif"
    images = []
    labels = []
    for class_folder in sorted(os.listdir(split_dir)):
        class_path = os.path.join(split_dir, class_folder)
        if not os.path.isdir(class_path):
            # Stray files next to the class folders are not classes.
            continue
        label = le.transform([class_folder])[0]
        counter = 0
        for image_name in sorted(os.listdir(class_path)):
            # Match on the suffix only, so sidecar files such as
            # "x_protein.tif.bak" are not mistaken for images.
            if not image_name.endswith(protein_suffix):
                continue
            counter += 1
            nucleus_name = image_name[:-len(protein_suffix)] + nucleus_suffix
            # Context managers close the underlying file handles promptly.
            with Image.open(os.path.join(class_path, image_name)) as protein_img:
                gfp = np.array(protein_img)
            with Image.open(os.path.join(class_path, nucleus_name)) as nucleus_img:
                nuc = np.array(nucleus_img)
            images.append(np.stack((gfp, nuc), axis=axis))
            labels.append(label)
        print(class_folder, counter)
    return np.array(images), np.array(labels)


# Load training dataset and labels
train_images, train_labels = _load_split(traindir)
print(train_images.shape, train_labels.shape)

# Load test1 dataset and labels
test1dir = basedir + "test1/"
test1_images, test1_labels = _load_split(test1dir)
print(test1_images.shape, test1_labels.shape)

# Load test2 dataset and labels
test2dir = basedir + "test2/"
test2_images, test2_labels = _load_split(test2dir)
print(test2_images.shape, test2_labels.shape)

# Load test3 dataset and labels
test3dir = basedir + "test3/"
test3_images, test3_labels = _load_split(test3dir)
print(test3_images.shape, test3_labels.shape)

# Load test4 dataset and labels
test4dir = basedir + "test4/"
test4_images, test4_labels = _load_split(test4dir)
print(test4_images.shape, test4_labels.shape)

def _shuffle_together(images, labels, rng):
    """Return ``images`` and ``labels`` reordered by one shared random permutation.

    Args:
        images: Array whose first axis indexes samples.
        labels: Array with one label per sample, aligned with ``images``.
        rng: A ``numpy.random.RandomState`` used to draw the permutation.

    Returns:
        A tuple ``(images, labels)`` of reordered copies; sample/label pairs
        stay aligned because both are indexed with the same permutation.
    """
    order = rng.permutation(images.shape[0])
    return images[order], labels[order]


# Shuffle the samples of each dataset. A dedicated, seeded generator is used
# instead of numpy's global state: the train/validation split below fixes its
# own random_state, which only yields a reproducible split if the order of the
# samples fed into it is reproducible as well.
shuffle_rng = np.random.RandomState(42)
train_images, train_labels = _shuffle_together(train_images, train_labels, shuffle_rng)
test1_images, test1_labels = _shuffle_together(test1_images, test1_labels, shuffle_rng)
test2_images, test2_labels = _shuffle_together(test2_images, test2_labels, shuffle_rng)
test3_images, test3_labels = _shuffle_together(test3_images, test3_labels, shuffle_rng)
test4_images, test4_labels = _shuffle_together(test4_images, test4_labels, shuffle_rng)

# Split training dataset into train and validation dataset (stratified 80/20
# split) and save each part next to the other output files. The validation
# part used to be written to a hard-coded "D:/" location, unlike every other
# output of this script.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
for train_index, test_index in sss.split(train_images, train_labels):
    for out_path, subset_index in (("./COOS7_train_80.hdf5", train_index),
                                   ("./COOS7_train_20.hdf5", test_index)):
        # Index once per subset; fancy indexing copies the selected samples.
        subset_images = train_images[subset_index]
        with h5py.File(out_path, "w") as f:
            # "data1": one row per sample, all remaining axes flattened.
            f.create_dataset("data1", data=subset_images.reshape(subset_images.shape[0], -1))
            # "Index1": labels binarized by the module-level encoder.
            f.create_dataset("Index1", data=enc.transform(train_labels[subset_index]))

# Save all test datasets: each one goes to its own file, holding the flattened
# images under "data1" and the binarized labels under "Index1".
for out_path, images, labels in (
    ("./COOS7_test1.hdf5", test1_images, test1_labels),
    ("./COOS7_test2.hdf5", test2_images, test2_labels),
    ("./COOS7_test3.hdf5", test3_images, test3_labels),
    ("./COOS7_test4.hdf5", test4_images, test4_labels),
):
    with h5py.File(out_path, "w") as f:
        # Collapse axes 1..3 of every sample into a single row.
        row_width = images.shape[1] * images.shape[2] * images.shape[3]
        flat_images = images.reshape(images.shape[0], row_width)
        f.create_dataset("data1", data=flat_images)
        f.create_dataset("Index1", data=enc.transform(labels))