from os import listdir
from os.path import isfile, join

import numpy as np
import pickle
from PIL import Image
import matplotlib.pyplot as plt

def load_images_from_directory(dir_path, stride=1):
    """Load the image frames stored in ``dir_path`` as 128x128 NumPy arrays.

    Only regular files are considered, and any file whose name contains
    "Store" (e.g. ``.DS_Store``) is skipped. The second "-"-separated field
    of every remaining file name is parsed as the integer frame index.

    Args:
        dir_path: Directory that contains the image files.
        stride: Keep only frames whose index is a multiple of ``stride``.
            The default of 1 keeps every frame. Must be non-zero.

    Returns:
        A list with one array per kept frame, ordered by ascending frame
        index (ties broken by file name).

    Raises:
        IndexError: If a considered file name contains no "-".
        ValueError: If the second "-"-separated field is not an integer.
    """
    # Sort on the parsed integer index rather than on the raw name: a plain
    # string sort places "...-10-..." before "...-2-..." when indices are not
    # zero-padded, which would scramble the frame sequence.
    indexed_names = sorted(
        (int(f_name.split("-")[1]), f_name)
        for f_name in listdir(dir_path)
        if isfile(join(dir_path, f_name)) and "Store" not in f_name
    )

    dir_images = []
    for f_idx, f_name in indexed_names:
        if f_idx % stride != 0:
            continue
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once the resized copy has been produced.
        with Image.open(join(dir_path, f_name)) as img:
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            resized = img.resize((128, 128), Image.LANCZOS)
        dir_images.append(np.asarray(resized))
    return dir_images

def load_two_images_from_directory(dir_path, endpoints=(0, 100)):
    """Load only the endpoint frames of ``dir_path`` as 128x128 NumPy arrays.

    Only regular files are considered, and any file whose name contains
    "Store" (e.g. ``.DS_Store``) is skipped. The second "-"-separated field
    of every remaining file name is parsed as the integer frame index.

    Args:
        dir_path: Directory that contains the image files.
        endpoints: Frame indices to keep. Defaults to ``(0, 100)``.

    Returns:
        A list with one array per file whose index is in ``endpoints``,
        ordered by ascending frame index (ties broken by file name).

    Raises:
        IndexError: If a considered file name contains no "-".
        ValueError: If the second "-"-separated field is not an integer.
    """
    # Order by the parsed integer index so the result does not depend on
    # whether the indices in the file names are zero-padded.
    indexed_names = sorted(
        (int(f_name.split("-")[1]), f_name)
        for f_name in listdir(dir_path)
        if isfile(join(dir_path, f_name)) and "Store" not in f_name
    )

    dir_images = []
    for f_idx, f_name in indexed_names:
        if f_idx not in endpoints:
            continue
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once the resized copy has been produced.
        with Image.open(join(dir_path, f_name)) as img:
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            resized = img.resize((128, 128), Image.LANCZOS)
        dir_images.append(np.asarray(resized))
    return dir_images

def get_curve_length(images_list):
    """Accumulate the per-step change along a sequence of images.

    For every pair of consecutive images the element-wise difference is
    summed over all elements and the absolute value of that sum is added to
    the running total.

    Args:
        images_list: Sequence of array-likes with the same number of
            elements each.

    Returns:
        The accumulated total as a float; 0.0 when fewer than two images
        are given.
    """
    # NOTE(review): abs(sum(diff)) lets positive and negative element changes
    # cancel, so this is not a norm of the difference. Kept as-is; confirm
    # whether an L1/L2 distance was intended.
    d = 0.0
    for prev_img, next_img in zip(images_list, images_list[1:]):
        # Cast before subtracting: image arrays are typically uint8, where
        # `next - prev` wraps around modulo 256 instead of going negative.
        prev_flat = np.asarray(prev_img, dtype=np.float64).ravel()
        next_flat = np.asarray(next_img, dtype=np.float64).ravel()
        d += np.abs(np.sum(next_flat - prev_flat))
    return d

def _load_normalized_boundaries(data_dir, lengths):
    """Group length-normalised boundary counts found in ``data_dir`` by epoch.

    Every regular file in ``data_dir`` whose name does not contain "Store" is
    read as a pickle whose payload unpacks into exactly two items; the first
    item is divided by ``lengths[folder_name]``. The file name is split on
    "_": field 3 is parsed as the float epoch and field 4 is the key used to
    look up ``lengths``.

    Args:
        data_dir: Directory holding the pickle files.
        lengths: Mapping from folder name to the length used as divisor.

    Returns:
        A dict mapping each epoch to the list of normalised values read for it.

    Raises:
        KeyError: If a file refers to a folder name missing from ``lengths``.
    """
    grouped = {}
    for f_name in listdir(data_dir):
        f_path = join(data_dir, f_name)
        if not isfile(f_path) or "Store" in f_name:
            continue
        f_arr = f_name.split("_")
        folder_name = f_arr[4]
        epoch = float(f_arr[3])
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(f_path, "rb") as f_in:
            num_boundaries, _ = pickle.load(f_in)
        grouped.setdefault(epoch, []).append(num_boundaries / lengths[folder_name])
    return grouped


dir_path = "<path-to-root>/cifar10-out"
# Keep the bare sub-directory names: they are the keys the boundary pickles
# are matched against. Recovering them by splitting a joined path on "/"
# only works where the OS path separator is "/".
folder_names = [f for f in listdir(dir_path) if not isfile(join(dir_path, f)) and "Store" not in f]
notfiles = [join(dir_path, f) for f in folder_names]

curve_lengths = {}
line_lengths = {}

print("Starting curve lengths")

# Length accumulated over every frame of each folder.
for folder_name, folder in zip(folder_names, notfiles):
    curve_lengths[folder_name] = get_curve_length(load_images_from_directory(folder))

print("Done with curve lengths")

# Length between the two endpoint frames of each folder.
for folder_name, folder in zip(folder_names, notfiles):
    line_lengths[folder_name] = get_curve_length(load_two_images_from_directory(folder))

print("Done with line lengths")

epochs = [0.0, 0.1, 0.2, 0.4, 0.5, 0.8, 0.9, 1.0]

print("this")

dir_path = "boundaries_data_200_cifar_epoch_1/"
boundaries_dict_curve = _load_normalized_boundaries(dir_path, curve_lengths)

print("here2")
print(curve_lengths)
print(line_lengths)

dir_path = "boundaries_data_line_cifar_epoch_1/"
boundaries_dict_line = _load_normalized_boundaries(dir_path, line_lengths)

# Per-epoch statistics, sized from `epochs` itself instead of a hard-coded 8.
# An epoch listed above but absent from the data raises KeyError here.
linear_boundaries_mean = np.array([np.mean(np.array(boundaries_dict_curve[epoch])) for epoch in epochs])
linear_boundaries_std = np.array([np.std(np.array(boundaries_dict_curve[epoch])) for epoch in epochs])
linear_boundaries_mean_line = np.array([np.mean(np.array(boundaries_dict_line[epoch])) for epoch in epochs])
linear_boundaries_std_line = np.array([np.std(np.array(boundaries_dict_line[epoch])) for epoch in epochs])
print(linear_boundaries_mean)
print(linear_boundaries_mean_line)

# Blue: values normalised by curve length; green: values normalised by line length.
plt.plot(epochs, np.log(linear_boundaries_mean), color='blue')
plt.plot(epochs, np.log(linear_boundaries_mean_line), color='green')
plt.xlabel("Epoch Number")
plt.grid()
# NOTE(review): np.log(mean - std) is NaN/-inf wherever std >= mean, which
# leaves holes in the shaded band; confirm whether that is acceptable.
plt.fill_between(epochs, np.log(linear_boundaries_mean - linear_boundaries_std), np.log(linear_boundaries_mean + linear_boundaries_std), color='blue', alpha=0.2)
plt.fill_between(epochs, np.log(linear_boundaries_mean_line - linear_boundaries_std_line), np.log(linear_boundaries_mean_line + linear_boundaries_std_line), color='green', alpha=0.2)
plt.savefig("cifar_lin_regions_epoch_1.png")
