#!/usr/bin/env python
# coding: utf-8

import pandas as pd
import random as rn
import os
# os.environ['PYTHONHASHSEED'] = '0'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import numpy as np

import sys
sys.path.append('./src/')
from Utils import makeDir,readModelConfigurations

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# np.random.seed(37)
# rn.seed(1254)
# tf.set_random_seed(89)

# Module-level state shared with plot_history(): the two MSE lists keep
# growing across calls so successive training rounds plot as one curve.
all_history = pd.DataFrame()   # not referenced in the code visible here
mean_squared_error = []        # training MSE, one entry per epoch seen
val_mean_squared_error = []    # validation MSE, one entry per epoch seen


# In[3]:


def make_cv_data(dataset_path, n_folds, data_dir, fold_data, random_state=None):
    """Load the Boston-housing TSV and build (or reload) ``n_folds`` train/test splits.

    Each fold ``i`` is cached as ``fold_data + '<i>/train_data.csv'`` and
    ``fold_data + '<i>/test_data.csv'``.  If the train CSV already exists, both
    files are read back.  Otherwise a fresh 20% test sample is drawn from the
    cleaned dataset, the rest becomes the training set, and both are saved.

    Note: every fold is an independent random 80/20 split of the whole dataset,
    so the test sets of different folds can overlap.  This is repeated random
    sub-sampling, not disjoint k-fold cross-validation.

    Args:
        dataset_path: tab-separated file with the 21 columns listed below.
            Column names are supplied via ``names=``, so the first line is
            read as data.
        n_folds: number of splits to produce.
        data_dir: prefix under which a ``'<i>/'`` directory is created with
            ``makeDir`` for every fold.
        fold_data: prefix under which the per-fold CSVs are read/written.
        random_state: optional seed for sampling newly created folds.  ``None``
            (the default) keeps the previous unseeded behaviour.  A single
            RandomState is shared across folds so seeded folds still differ.

    Returns:
        (train_dataset, train_labels, test_dataset, test_labels, min_max_dict):
        four lists with one entry per fold, plus a dict.
            - train_dataset / test_dataset hold feature DataFrames with the
              'CMEDV' column removed.
            - train_labels / test_labels hold the matching 'CMEDV' Series.
            - min_max_dict is the per-feature ``[min, max]`` dict from
              ``getMinMaxRangeOfFeatures``, computed on the full cleaned
              dataset.
    """
    column_names = ['ID','TOWN','Town No','Tract','Lon','Lat','Medv','CMEDV','CRIM','ZN'
                   ,'INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO','B','LSTAT']
    # NOTE(review): comment='\t' is the same character as sep; presumably a
    # leftover from another loader -- kept unchanged, confirm before relying on it.
    raw_dataset = pd.read_csv(dataset_path, names=column_names,
                              na_values="?", comment='\t',
                              sep="\t", skipinitialspace=True)

    # Drop incomplete rows, then the identifier/location columns and 'Medv';
    # 'CMEDV' stays in the frame and is split off as the label below.
    dataset = raw_dataset.dropna().drop(
        ['ID', 'TOWN', 'Town No', 'Tract', 'Lon', 'Lat', 'Medv'], axis=1)

    min_max_dict = getMinMaxRangeOfFeatures(dataset)

    # One shared generator: passing the same int seed to every sample() call
    # would make all freshly drawn folds identical.
    rng = np.random.RandomState(random_state) if random_state is not None else None

    train_dataset, train_labels = [], []
    test_dataset, test_labels = [], []
    for i in range(n_folds):
        print(data_dir)
        makeDir(data_dir + '%d/' % i)
        train_path = fold_data + '%d/train_data.csv' % i
        test_path = fold_data + '%d/test_data.csv' % i

        if os.path.isfile(train_path):
            trainX = pd.read_csv(train_path, index_col=0)
            testX = pd.read_csv(test_path, index_col=0)
        else:
            testX = dataset.sample(frac=0.20, random_state=rng)
            trainX = dataset.drop(testX.index)
            # fold_data may differ from data_dir, in which case makeDir above
            # did not create the directory the CSVs are written into.
            os.makedirs(os.path.dirname(train_path), exist_ok=True)
            trainX.to_csv(train_path)
            testX.to_csv(test_path)

        # pop() removes the target in place, so the stored frames are features only.
        train_labels.append(trainX.pop('CMEDV'))
        test_labels.append(testX.pop('CMEDV'))
        train_dataset.append(trainX)
        test_dataset.append(testX)

    return train_dataset, train_labels, test_dataset, test_labels, min_max_dict


# In[4]:


def get_bounds():
    """Return hard-coded ``(lower, upper)`` bound pairs, one per input feature.

    The pairs, in order, were written for these features: Cylinders,
    Displacement, Horsepower, Weight, Acceleration, Model Year, Origin.

    NOTE(review): those are Auto-MPG features (7 pairs), while the data loaded
    in this file has 13 Boston-housing features -- confirm that callers still
    expect this list.  TODO: derive the bounds from the data instead (e.g. via
    getMinMaxRangeOfFeatures).
    """
    cylinders = (0, 10)
    displacement = (0, 500)
    horsepower = (0, 300)
    weight = (0, 6000)
    acceleration = (0, 30)
    model_year = (0, 2500)
    origin = (1, 3)
    return [cylinders, displacement, horsepower, weight, acceleration, model_year, origin]


# In[5]:


def build(train_dataset, layer_size,hidden_size,learning_rate):
    """Create and compile a fully connected ReLU regressor.

    The network has ``layer_size - 1`` hidden Dense layers of width
    ``hidden_size`` with ReLU, followed by a single linear output unit.  Only
    the first hidden layer gets an explicit ``input_shape``, equal to the
    number of columns of ``train_dataset``.  The model is compiled with MSE
    loss, the Adam optimizer, and MSE/MAE metrics.

    Returns:
        (model, activation_types, layer_size).  ``activation_types`` names
        each layer's activation in order: 'relu' for the hidden layers and
        'linear' for the output layer.
    """
    hidden_layers = []
    for depth in range(layer_size - 1):
        # Only the first layer needs the input width; keys() is read only then,
        # so a single-layer model never touches train_dataset.
        shape_kwargs = {'input_shape': [len(train_dataset.keys())]} if depth == 0 else {}
        hidden_layers.append(layers.Dense(hidden_size, activation=tf.nn.relu, **shape_kwargs))
    model = keras.Sequential(hidden_layers + [layers.Dense(1)])

    # NOTE(review): `lr`, `decay` and `epsilon=None` are legacy Keras optimizer
    # arguments that newer TensorFlow releases may reject -- confirm the pinned
    # TF version before changing them.
    optimizer = keras.optimizers.Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
    model.compile(loss='mse', optimizer=optimizer, metrics=['mse', 'mae'])

    activation_types = ['relu'] * (layer_size - 1) + ['linear']
    return model, activation_types, layer_size


# In[6]:


# MIP_model is a list with one entry per layer; each entry is a tuple (activation, weight, bias), where activation is e.g. 'relu' or 'linear'.
def train(train_dataset, train_labels, fold, data_dir, layer_size,hidden_size,epoch,batch_size,learning_rate,isInitial=False,initialDir=""):
    """Build a network, then load cached weights or fit it, and export each layer.

    Args:
        train_dataset: training features (passed to ``build`` and ``model.fit``).
        train_labels: training targets.
        fold: not used here; kept for call-site compatibility.
        data_dir: path prefix, string-concatenated (so it should end with a
            separator), for model.h5, best_model.h5 and the per-layer CSVs.
        layer_size, hidden_size, learning_rate: forwarded to ``build``.
        epoch, batch_size: forwarded to ``model.fit``.
        isInitial: when true, ``initialDir`` replaces ``data_dir`` for every
            artifact read or written.
        initialDir: see ``isInitial``.

    Returns:
        (model, MIP_model).  MIP_model has one ``(activation, weight, bias)``
        tuple per layer.  Each layer's weight and bias are also written to
        ``weights_layer<i>.csv`` / ``bias_layer<i>.csv``.
    """
    if isInitial:
        data_dir = initialDir
    model_file = data_dir + "model.h5"

    model, activation_types, layer_size = build(train_dataset, layer_size, hidden_size, learning_rate)
    stop_early = EarlyStopping(monitor='val_loss', patience=1000)
    keep_best = ModelCheckpoint(filepath=data_dir + 'best_model.h5', monitor='val_loss', save_best_only=True)

    if os.path.isfile(model_file):
        print("Loading from model file from " + str(model_file))
        model.load_weights(model_file)
    else:
        model.fit(train_dataset, train_labels, epochs=epoch, batch_size=batch_size,
                  callbacks=[stop_early, keep_best], validation_split=0.2, verbose=0)
        # Cache the trained weights so the next call takes the load branch.
        model.save_weights(model_file)

    MIP_model = []
    for layer_idx in range(layer_size):
        params = model.layers[layer_idx].get_weights()
        weight, bias = params[0], params[1]
        np.savetxt(data_dir + "weights_layer%d.csv" % layer_idx, weight, delimiter=",")
        np.savetxt(data_dir + "bias_layer%d.csv" % layer_idx, bias, delimiter=",")
        MIP_model.append((activation_types[layer_idx], weight, bias))
    return model, MIP_model


# In[7]:


def evaluate(model, test_dataset,test_labels):
    """Return the loss that ``model.evaluate`` reports on the test set.

    The first entry of Keras' score list is the compiled loss.  For models
    built by ``build`` (``loss='mse'``, no regularizers) that is the test MSE.
    """
    test_loss = model.evaluate(test_dataset, test_labels, verbose=0)[0]
    return test_loss


# In[8]:

def update_batch(model, mip_model,batch_data,batch_label,fold,data_dir,batch_size):
    """Fine-tune ``model`` for one epoch on a batch, then save and export its layers.

    Saves the weights to ``<data_dir>/model.h5``.  Writes each layer's weight
    and bias to ``<data_dir>/weights_layer<i>.csv`` and
    ``<data_dir>/bias_layer<i>.csv``.  All paths are built with
    ``os.path.join``, so ``data_dir`` works with or without a trailing
    separator.

    Args:
        model: compiled Keras model to update in place.
        mip_model: list of (activation, weight, bias) tuples, one per layer.
            Only the activation names are reused.
        batch_data, batch_label: data passed to ``model.fit``.
        fold: not used.
        data_dir: output directory.
        batch_size: NOTE(review) not used -- ``fit`` always uses a mini-batch
            size of 64.  Kept as-is to preserve training behaviour; confirm intent.

    Returns:
        (model, updated_MIP_model), with the layers' new weights and biases.
    """
    model.fit(batch_data, batch_label, epochs=1, batch_size=64, validation_split=0.2, verbose=0)
    print("Saving model file to "+data_dir)
    # Previously "model.h5" was concatenated without a separator while the CSVs
    # used "/", so a data_dir without trailing slash put model.h5 outside it.
    model.save_weights(os.path.join(data_dir, "model.h5"))

    updated_MIP_model = []
    for i, layer_entry in enumerate(mip_model):
        params = model.layers[i].get_weights()
        weight, bias = params[0], params[1]
        np.savetxt(os.path.join(data_dir, "weights_layer%d.csv" % i), weight, delimiter=",")
        np.savetxt(os.path.join(data_dir, "bias_layer%d.csv" % i), bias, delimiter=",")
        updated_MIP_model.append((layer_entry[0], weight, bias))
    return model, updated_MIP_model

def update(model, mip_model, batch_data, batch_label,fold, data_dir=None):
    """Fine-tune ``model`` for 10 epochs, plot the running MSE, and export its layers.

    Args:
        model: compiled Keras model to update in place.
        mip_model: list of (activation, weight, bias) tuples, one per layer.
            Only the activation names are reused.
        batch_data, batch_label: data passed to ``model.fit``.
        fold: fold number; selects the default output directory.
        data_dir: directory for ``weights_layer<i>.csv`` / ``bias_layer<i>.csv``.
            Defaults to the previously hard-coded ``./Data/boston_r/<fold>/``.

    Returns:
        (model, updated_MIP_model), with the layers' new weights and biases.
    """
    if data_dir is None:
        data_dir = "./Data/boston_r/%d/" % fold
    history = model.fit(batch_data, batch_label, epochs=10, batch_size=64, validation_split=0.2, verbose=1)
    plot_history(history)

    updated_MIP_model = []
    for i, layer_entry in enumerate(mip_model):
        params = model.layers[i].get_weights()
        weight, bias = params[0], params[1]
        np.savetxt(os.path.join(data_dir, "weights_layer%d.csv" % i), weight, delimiter=",")
        np.savetxt(os.path.join(data_dir, "bias_layer%d.csv" % i), bias, delimiter=",")
        updated_MIP_model.append((layer_entry[0], weight, bias))
    return model, updated_MIP_model


# In[9]:


def output(model, datapoint):
    """Return ``model.predict`` for a single data point.

    ``datapoint`` is wrapped in a DataFrame (one column) and transposed, so a
    flat sequence of feature values becomes one row with one column per value.
    """
    x_point = pd.DataFrame(datapoint).transpose()
    return model.predict(x_point)


# In[10]:


def plot_history(history):
    """Append one fit() run's MSE curves to the module-level lists and plot them.

    The module-level ``mean_squared_error`` / ``val_mean_squared_error`` lists
    accumulate across calls.  Repeated training rounds therefore draw as one
    continuous curve whose x-axis counts every epoch recorded so far.

    Args:
        history: a Keras ``History`` whose ``.history`` dict holds the training
            MSE and its ``val_`` counterpart.
    """
    global mean_squared_error
    global val_mean_squared_error

    hist = pd.DataFrame(history.history)
    # metrics=['mse'] is logged as 'mean_squared_error' by older tf.keras and as
    # 'mse' by newer releases; accept either key.
    train_key = 'mean_squared_error' if 'mean_squared_error' in hist.columns else 'mse'
    mean_squared_error.extend(hist[train_key].tolist())
    val_mean_squared_error.extend(hist['val_' + train_key].tolist())
    print(mean_squared_error)
    epoch = list(range(len(mean_squared_error)))
    print(epoch)

    plt.figure()
    plt.xlabel('Epoch')
    # The regression target in this file is CMEDV (the label was copied from an MPG example).
    plt.ylabel('Mean Square Error [$CMEDV^2$]')
    plt.plot(epoch, mean_squared_error, label='Train Error')
    plt.plot(epoch, val_mean_squared_error, label='Val Error')
    plt.ylim([0, 60])
    plt.legend()
    plt.show()


# In[11]:

def getMinMaxRangeOfFeatures(dataset, column_names=None):
    """Return ``{feature: [min, max]}`` over the given feature columns.

    Args:
        dataset: DataFrame containing every column in ``column_names``.
        column_names: features to summarise.  Defaults to the 13 Boston-housing
            feature columns used throughout this file.

    Returns:
        dict mapping each feature name to ``[min, max]`` over its whole column.
    """
    if column_names is None:
        column_names = ['CRIM','ZN','INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO','B','LSTAT']
    min_max = {}
    for name in column_names:
        values = dataset[name]
        # Both bounds use the full column.  max previously skipped the first
        # row (values[1:]), so a maximum in row 0 was missed, and the result
        # disagreed with getMinMaxRange.
        column_max = max(values)
        print(column_max)
        min_max[name] = [min(values), column_max]
    return min_max

def getMinMaxRange(dataset,index):
    """Return ``(min, max, name)`` for the feature at position ``index``.

    ``index`` selects from the fixed 13-feature Boston-housing column list, and
    the min/max are taken over that whole column of ``dataset``.
    """
    feature_order = ['CRIM','ZN','INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO','B','LSTAT']
    feature = feature_order[index]
    column = dataset[feature]
    lowest = min(column)
    highest = max(column)
    return lowest, highest, feature

def generateModel(data_path, data_dir, n_folds, fold_no, layer_size, hidden_size, epoch,batch_size,learning_rate, isInitial, initialDir):
    """Prepare the fold splits and train (or reload) the model for one fold.

    Calls ``make_cv_data`` with ``data_dir`` as both the directory prefix and
    the fold-CSV prefix, then calls ``train`` on fold ``fold_no``.  When
    ``isInitial`` is true, ``makeDir(initialDir, True)`` is called first and
    ``train`` reads/writes its artifacts under ``initialDir``.

    NOTE(review): this definition may continue beyond the part shown here;
    any use of ``model`` / ``MIP_model`` would happen after these lines.
    NOTE(review): ``train`` receives ``data_dir`` itself rather than the
    per-fold ``data_dir + '<fold>/'`` directory.  model.h5 and the layer CSVs
    therefore appear to be shared by all folds -- confirm this is intended.
    NOTE(review): the meaning of makeDir's second argument is defined in Utils
    (not shown); presumably it recreates/clears the directory -- verify.
    """
    train_dataset, train_labels, test_dataset, test_labels,min_max_dict = make_cv_data(data_path,n_folds,data_dir,data_dir)
    if isInitial :
        makeDir(initialDir,True)
    model, MIP_model = train(train_dataset[fold_no],train_labels[fold_no],fold_no,data_dir,layer_size, hidden_size,epoch,batch_size,learning_rate,isInitial,initialDir)