import os
import sys
import linear_relax as LP_relax_file
import ip_model_whole as ip_model_whole_file
from ip_model_whole import IPOfunc
import numpy as np
import random
import pandas as pd
import math, time
import itertools
from sklearn import preprocessing
from sklearn.preprocessing import MinMaxScaler
import datetime
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.utils.data as data_utils
from torch.utils.data.dataset import Dataset
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
import gurobipy as gp
import logging
import copy
from collections import defaultdict
import joblib
import gurobipy as gp
from gurobipy import GRB

# ---------------------------------------------------------------------------
# Problem dimensions, all read from the LP-relaxation module.
# ---------------------------------------------------------------------------
item_num = LP_relax_file.item_num
month_num = LP_relax_file.month_num
total_month_num = LP_relax_file.month_num  # same source attribute as month_num
x_num = LP_relax_file.x_num
y_num = LP_relax_file.y_num
z_num = LP_relax_file.z_num
var_num = LP_relax_file.var_num

# ---------------------------------------------------------------------------
# Command-line arguments (no validation: a missing or malformed argument
# raises IndexError / ValueError at import time).
#   argv[1]  capacity, forwarded to LP_relax_file.set_capacity
#   argv[2]  transaction-fee fraction, forwarded to
#            LP_relax_file.set_trans_fee_percent
#   argv[3]  first data-split index processed by the driver loop (inclusive)
#   argv[4]  last data-split index processed by the driver loop (exclusive)
# ---------------------------------------------------------------------------
cap = int(sys.argv[1])
trans_fee_percent = float(sys.argv[2])
LP_relax_file.set_capacity(cap)
LP_relax_file.set_trans_fee_percent(trans_fee_percent)
startmark = int(sys.argv[3])
endmark = int(sys.argv[4])

# ---------------------------------------------------------------------------
# Training hyper-parameters.
# ---------------------------------------------------------------------------
featureNum = 4096            # passed as n_features when Intopt is constructed
train_set_size = 70          # not referenced elsewhere in the visible source
target_num = 2               # default output width of make_fc / Intopt
warm_start_epoch = 15        # epochs < this value use the L1 loss in Intopt.fit
stop_epoch_criterion = 20    # passed as `epochs` when Intopt is constructed
log_regularizer = 1e-8       # not referenced elsewhere in the visible source
warm_start_value = 300       # loss threshold used by Intopt.fit's stopping checks

# dataset_path is the parent of the current working directory, so the script
# has to be launched from a sub-directory that is a sibling of 'data/'.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
# Output directory; built after set_capacity() above because it reads
# LP_relax_file.capacity.
default_path = os.path.join(dataset_path, 'data/trans_fee=' + str(trans_fee_percent) +'/v' + str(LP_relax_file.version_num) + '(item_num=' + str(item_num) + ',month_num=' + str(total_month_num) + ',trans_fee=' + str(trans_fee_percent) + ',cap=' + str(LP_relax_file.capacity) + ')/')
# presumably creates the sub-directory the driver later writes predictions
# into — confirm against LP_relax_file.mkdir
LP_relax_file.mkdir(default_path, 'MS, warm_start=' + str(warm_start_epoch))
    
def weight_init(m):
    """Initialise the parameters of a single module in place.

    Intended for ``model.apply(weight_init)``; modules of any other type are
    left untouched.

    * ``nn.Linear``      -- Xavier-normal weight, zero bias.
    * ``nn.Conv2d``      -- Kaiming-normal weight (``fan_out``, ReLU gain);
                            the bias is not modified.
    * ``nn.BatchNorm2d`` -- weight set to 1, bias set to 0.

    Parameters that do not exist (``bias=False`` layers, ``affine=False``
    batch norm) are skipped rather than passed to ``nn.init``, which would
    fail on ``None``.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    elif isinstance(m, nn.BatchNorm2d):
        # Both parameters are None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        
def make_fc(num_layers, num_features, num_targets=target_num,
            activation_fn = nn.ReLU,intermediate_size=512, regularizers = True):
    """Build a fully connected network as an ``nn.Sequential``.

    The stack is ``num_features -> intermediate_size`` followed by
    ``num_layers - 2`` square hidden layers and a final
    ``intermediate_size -> num_targets`` layer.  Every linear layer,
    including the output layer, is followed by a fresh ``activation_fn()``.
    For ``num_layers <= 2`` no hidden layer is added, so the result always
    contains at least the input and output layers.

    Args:
        num_layers: Requested number of linear layers.
        num_features: Width of the input.
        num_targets: Width of the output.
        activation_fn: Zero-argument callable returning an activation module.
        intermediate_size: Width of every hidden layer.
        regularizers: Accepted but not used by this function.

    Returns:
        The assembled ``nn.Sequential``.
    """
    extra_hidden = max(num_layers - 2, 0)
    # Consecutive pairs of this list are the (in, out) sizes of each layer.
    widths = [num_features] + [intermediate_size] * (extra_hidden + 1) + [num_targets]

    stack = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        stack.append(nn.Linear(fan_in, fan_out))
        stack.append(activation_fn())
    return nn.Sequential(*stack)
        

class MyCustomDataset():
    """Pair two index-aligned containers so they can be read row by row.

    Provides ``__len__`` and ``__getitem__``; ``Intopt`` hands instances to
    ``torch.utils.data.DataLoader``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, feature, value):
        # feature: indexable container of inputs
        # value:   indexable container of targets; also defines the length
        self.feature, self.value = feature, value

    def __len__(self):
        """Return the number of rows, taken from ``value``."""
        targets = self.value
        return len(targets)

    def __getitem__(self, idx):
        """Return the ``(feature[idx], value[idx])`` pair."""
        row_feature = self.feature[idx]
        row_value = self.value[idx]
        return row_feature, row_value


class Intopt:
    """Trainer for a two-output predictor with a supervised warm start.

    ``fit`` runs in two phases, split at the module-level
    ``warm_start_epoch``:

    * epochs below it minimise an L1 loss between predictions and targets;
    * later epochs pass each batch of predictions through ``IPOfunc`` and
      minimise the negated objective evaluated with the true values.

    Data is consumed in unshuffled batches of ``batch_size`` rows.  The
    evaluation code reads the first ``x_num`` rows of each batch as one
    optimisation instance, so ``batch_size`` is expected to be at least
    ``x_num`` (it defaults to exactly ``x_num``).

    Column 0 of the targets/predictions is named "price" and column 1
    "weight" throughout.
    """

    # Epoch index whose recorded loss is compared against ``warm_start_value``
    # to abandon a run that failed to converge.
    _DIVERGENCE_CHECK_EPOCH = 6

    def __init__(self, n_features, num_layers=5, smoothing=False, thr=0.1, max_iter=None, method=1, mu0=None, damping=0.5, target_size=target_num, epochs=8, optimizer=optim.Adam, batch_size=x_num, **hyperparams):
        """Build the network and its optimizer.

        Args:
            n_features: Input width of the network.
            num_layers: Layer count forwarded to ``make_fc``.
            smoothing, thr, max_iter, damping: Forwarded to ``IPOfunc``.
            method, mu0: Stored only; not read by this class.
            target_size: Stored only; the network is built with ``make_fc``'s
                default output width.
            epochs: Maximum number of epochs run by ``fit``.
            optimizer: Callable ``optimizer(params, **hyperparams)``.
            batch_size: Rows per batch.
            **hyperparams: Extra keyword arguments for ``optimizer``
                (for example ``lr``).
        """
        self.target_size = target_size
        self.n_features = n_features
        self.damping = damping
        self.num_layers = num_layers

        self.smoothing = smoothing
        self.thr = thr
        self.max_iter = max_iter
        self.method = method
        self.mu0 = mu0

        self.batch_size = batch_size
        self.hyperparams = hyperparams
        self.epochs = epochs

        # Kept so the optimizer can be rebuilt if the model is re-created.
        self._optimizer_factory = optimizer

        self.model = make_fc(num_layers=self.num_layers,num_features=n_features)
        self.optimizer = optimizer(self.model.parameters(), **hyperparams)

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _instance_columns(batch):
        """Return columns 0 and 1 of the first ``x_num`` rows as float64 vectors.

        Raises:
            IndexError: If ``batch`` holds fewer than ``x_num`` rows.
        """
        if batch.shape[0] < x_num:
            raise IndexError(
                "batch has {} rows but one instance needs {}".format(batch.shape[0], x_num))
        rows = batch[:x_num].detach().cpu().numpy().astype(np.float64)
        # Copy so the solver receives contiguous 1-D arrays.
        return rows[:, 0].copy(), rows[:, 1].copy()

    def _make_loader(self, dataset):
        """Return an unshuffled ``DataLoader`` over ``dataset``."""
        return data_utils.DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def _reinitialise_model(self):
        """Replace the network with a fresh one and re-bind the optimizer.

        The optimizer has to be rebuilt as well: an optimizer created for the
        old network keeps updating the discarded parameters, which would
        leave the new network untrained.  The current learning rate is
        carried over.
        """
        current_lr = self.optimizer.param_groups[0]['lr']
        self.model = make_fc(num_layers=self.num_layers,num_features=self.n_features)
        self.optimizer = self._optimizer_factory(self.model.parameters(), **self.hyperparams)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = current_lr

    def _run_supervised_epoch(self, dataset):
        """Run one L1-loss epoch and return the summed per-batch loss."""
        criterion = nn.L1Loss(reduction='mean')
        total_loss = 0
        for batch_feature, batch_value in self._make_loader(dataset):
            self.optimizer.zero_grad()
            op = self.model(batch_feature).squeeze()
            loss = criterion(op, batch_value)
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        return total_loss

    @staticmethod
    def _decision_loss(true_price, true_weight, sol_cur):
        """Return the negated objective of ``sol_cur`` under the true values.

        ``sol_cur`` is split into ``x`` (first ``x_num`` entries), ``y``
        (next ``y_num``) and ``z`` (the rest).
        """
        x_sol_cur = sol_cur[:x_num]
        y_sol_cur = sol_cur[x_num:x_num+y_num]
        z_sol_cur = sol_cur[x_num+y_num:]
        last_month = item_num*(month_num-1)
        objective = (
            (true_price * x_sol_cur).sum()
            + (true_weight[last_month:] * x_sol_cur[last_month:]).sum()
            - (true_weight[:item_num] * x_sol_cur[:item_num]).sum()
            - (true_weight[item_num:] * y_sol_cur).sum()
            - (trans_fee_percent * true_weight[item_num:] * z_sol_cur).sum()
        )
        return - objective

    def _run_decision_epoch(self, dataset):
        """Run one epoch through ``IPOfunc`` and return the mean batch loss."""
        total_loss = 0
        num = 0
        for batch_feature, batch_value in self._make_loader(dataset):
            self.optimizer.zero_grad()
            op = self.model(batch_feature).squeeze()
            # Predictions that are negative, NaN or infinite are not passed to
            # IPOfunc: restart from a fresh network until they are usable.
            while torch.min(op) < 0 or torch.isnan(op).any() or torch.isinf(op).any():
                self._reinitialise_model()
                op = self.model(batch_feature).squeeze()

            true_price = batch_value[:, 0]
            true_weight = batch_value[:, 1]
            sol_cur = IPOfunc(true_price=true_price, true_weight=true_weight, max_iter=self.max_iter, thr=self.thr, damping=self.damping,
                    smoothing=self.smoothing)(op)

            newLoss = self._decision_loss(true_price, true_weight, sol_cur)
            total_loss += newLoss.item()
            newLoss.backward()
            self.optimizer.step()
            num = num + 1
        return total_loss / num

    def _decision_objectives(self, dataset):
        """Evaluate the current model batch by batch.

        Returns:
            ``(true_obj, corr_obj)`` as NumPy arrays with one entry per batch:
            the result of ``LP_relax_file.actual_obj`` on the true values and
            of ``LP_relax_file.correction_single_obj`` on the predictions.
        """
        true_obj_list = []
        corr_obj_list = []
        for batch_feature, batch_value in self._make_loader(dataset):
            with torch.no_grad():  # evaluation only, no graph needed
                op = self.model(batch_feature).squeeze()
            true_price, true_weight = self._instance_columns(batch_value)
            pred_price, pred_weight = self._instance_columns(op)

            true_obj_list.append(LP_relax_file.actual_obj(true_price, true_weight, n_instance=1))
            corr_obj_list.append(
                LP_relax_file.correction_single_obj(pred_price, pred_weight, true_price, true_weight))
        return np.array(true_obj_list), np.array(corr_obj_list)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def fit(self, feature, value):
        """Train the model on ``feature`` / ``value``.

        Side effects: prints progress to stdout and stores the index of the
        last epoch that finished training in the module-level ``stop_epoch``.

        Training stops early when any of the following holds after an epoch:

        * the loss recorded for epoch ``_DIVERGENCE_CHECK_EPOCH`` exceeds
          ``warm_start_value``;
        * the loss changed by at most 0.001 since the previous epoch;
        * (after warm start) the evaluated objective changed by at most 0.001;
        * (after warm start) the magnitude of the evaluated objective
          decreased.

        Args:
            feature: Tensor of inputs, one row per sample.
            value: Tensor of targets with price in column 0 and weight in
                column 1.
        """
        global stop_epoch

        logging.info("Intopt")
        train_df = MyCustomDataset(feature, value)

        grad_list = np.zeros(self.epochs)
        # inf marks "not evaluated in this epoch".
        IP_grad_list = np.full(self.epochs, float("inf"))

        for e in range(self.epochs):
            if e < warm_start_epoch:
                total_loss = self._run_supervised_epoch(train_df)
                grad_list[e] = total_loss
                stop_epoch = e

                if e < warm_start_epoch - 1:
                    print("{} ->".format(total_loss), end=" ")
                else:
                    print("{} ->".format(total_loss))

                # Report decision quality once, at the end of a successful
                # warm start.
                if e == warm_start_epoch - 1 and grad_list[e] < warm_start_value:
                    self.model.eval()
                    true_obj_np, corr_obj_np = self._decision_objectives(train_df)
                    IP_grad_list[e] = np.mean(corr_obj_np)
                    print("TOV: ", np.mean(true_obj_np), "EOV: ", np.mean(corr_obj_np), "PReg: ", np.mean(true_obj_np - corr_obj_np))

            else:
                if e == warm_start_epoch:
                    # Switch to a small fixed learning rate for this phase.
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = 1e-6

                grad_list[e] = self._run_decision_epoch(train_df)
                stop_epoch = e

                _, corr_obj_np = self._decision_objectives(train_df)
                IP_grad_list[e] = np.mean(corr_obj_np)

            logging.info("EPOCH Ends")
            if e >= warm_start_epoch:
                print("Epoch{} ::EOV {} ".format(e, IP_grad_list[e]))

            # Guarded so that runs with fewer epochs than the check index do
            # not raise IndexError.
            check = self._DIVERGENCE_CHECK_EPOCH
            if self.epochs > check and grad_list[check] > warm_start_value:
                break
            if e >= 1 and abs(grad_list[e] - grad_list[e-1]) <= 0.001:
                break
            if e >= warm_start_epoch and abs(IP_grad_list[e] - IP_grad_list[e-1]) <= 0.001:
                break
            if e >= warm_start_epoch and abs(IP_grad_list[e]) < abs(IP_grad_list[e-1]):
                break

    def val_loss(self, feature, value):
        """Evaluate the model on ``feature`` / ``value`` and print a summary.

        Args:
            feature: Tensor of inputs, one row per sample.
            value: Tensor of targets with price in column 0 and weight in
                column 1.  Its row count is expected to be a multiple of
                ``batch_size``.

        Returns:
            A pair ``(regret, predVal)``: ``regret`` is
            ``abs(corrected_objective - true_objective)`` with one entry per
            batch, and ``predVal`` is a ``(rows, 2)`` tensor holding the
            predicted price and weight for every row.
        """
        valueTemp = value.numpy()
        row_count = np.size(valueTemp, 0)
        instance_num = row_count / self.batch_size
        true_obj = LP_relax_file.actual_obj(valueTemp[:, 0], valueTemp[:, 1], n_instance=int(instance_num))

        self.model.eval()
        valid_dl = self._make_loader(MyCustomDataset(feature, value))

        corr_obj_list = []
        predVal = torch.zeros((row_count, 2))

        num = 0
        for batch_feature, batch_value in valid_dl:
            with torch.no_grad():  # evaluation only, no graph needed
                op = self.model(batch_feature).squeeze()

            true_price, true_weight = self._instance_columns(batch_value)
            pred_price, pred_weight = self._instance_columns(op)
            # Batch `num` fills rows [num * x_num, (num + 1) * x_num).
            predVal[num*x_num:(num+1)*x_num] = op[:x_num, :2]

            corrrlst = LP_relax_file.correction_single_obj(pred_price, pred_weight, true_price, true_weight)
            corr_obj_list.append(corrrlst)
            num = num + 1

        # NOTE(review): the printed PReg uses abs(true_obj) - corrected, while
        # the returned regret is abs(corrected - true_obj); kept as found.
        print("TOV: ", sum(true_obj)/num, "EOV: ", sum(corr_obj_list)/num, "PReg: ", sum(abs(true_obj) - np.array(corr_obj_list))/num)
        return abs(np.array(corr_obj_list) - true_obj), predVal



# ---------------------------------------------------------------------------
# Driver: for every data split in [startmark, endmark) train a model, evaluate
# it on the test split, and write the predictions next to the true values.
# ---------------------------------------------------------------------------
print("*** Baseline ****")

simulation_time = 30
# One slot per split index.
# NOTE(review): a split index >= simulation_time makes the assignment at the
# end of the loop raise IndexError — confirm endmark never exceeds 30.
recordBest = np.zeros((1, simulation_time))
print("item_num: ", item_num, "month_num: ", month_num, "trans_fee_percent: ", trans_fee_percent, "capacity: ", LP_relax_file.capacity, "warm_start_epoch： ", warm_start_epoch)

for testi in range(startmark, endmark):
    print(testi)
    # Intopt.fit overwrites this global with the last epoch it completed.
    stop_epoch = 0
    # NOTE(review): the data paths hard-code 'item_num=5' although item_num is
    # read from LP_relax_file — confirm they agree.
    x_train = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/train_features/train_features(' + str(testi) + ').txt'))
    y_train1 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/rescale_train_prices/rescale_train_prices(' + str(testi) + ').txt'))
    y_train2 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/train_weights/train_weights(' + str(testi) + ').txt'))

    # Targets: column 0 = price, column 1 = weight.
    y_train = np.zeros((y_train1.size, 2))
    for i in range(y_train1.size):
        y_train[i][0] = y_train1[i]
        y_train[i][1] = y_train2[i]
    feature_train = torch.from_numpy(x_train).float()
    value_train = torch.from_numpy(y_train).float()
    
    
    x_test = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/test_features/test_features(' + str(testi) + ').txt'))
    y_test1 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/rescale_test_prices/rescale_test_prices(' + str(testi) + ').txt'))
    y_test2 = np.loadtxt(os.path.join(dataset_path, 'data/item_num=5, month_num=' + str(total_month_num) + '/test_weights/test_weights(' + str(testi) + ').txt'))

    # Same two-column layout for the test targets.
    y_test = np.zeros((y_test1.size, 2))
    for i in range(y_test1.size):
        y_test[i][0] = y_test1[i]
        y_test[i][1] = y_test2[i]
    feature_test = torch.from_numpy(x_test).float()
    value_test = torch.from_numpy(y_test).float()
    
    start = time.time()
    damping = 1e-2
    thr = 1e-3
    lr = 1e-5
    bestTrainCorrReg = float("inf")
    # Retry with a freshly initialised model until a run gets past the
    # warm-start phase (fit stops early, leaving stop_epoch below
    # warm_start_epoch, when its stopping checks fire during warm start).
    # There is no retry limit.
    while stop_epoch < warm_start_epoch:
        clf = Intopt(damping=damping, lr=lr, n_features=featureNum, thr=thr, epochs=stop_epoch_criterion)
        clf.fit(feature_train, value_train)

        if stop_epoch >= warm_start_epoch:
            # `end` is set only here; the elapsed time printed below therefore
            # covers training (including retries) but not the test evaluation.
            end = time.time()
            train_rslt, predTrainVal = clf.val_loss(feature_train, value_train)
            avgTrainCorrReg = np.mean(train_rslt)
            trainHSD_rslt = 'train: ' + str(np.mean(train_rslt))
            bestTrainCorrReg = avgTrainCorrReg
            # The checkpoint name does not include testi, so each split
            # overwrites the previous split's file.
            torch.save(clf.model.state_dict(), 'MS_cap' + str(cap) + '_trans' + str(trans_fee_percent) + '_intOpt_model.pkl')
            print(trainHSD_rslt)


    # Reload the saved weights into a new Intopt and evaluate on the test split.
    clfBest = Intopt(damping=damping, lr=lr, n_features=featureNum, thr=thr, epochs=stop_epoch_criterion)
    clfBest.model.load_state_dict(torch.load('MS_cap' + str(cap) + '_trans' + str(trans_fee_percent) + '_intOpt_model.pkl'))

    val_rslt, predTestVal = clfBest.val_loss(feature_test, value_test)

    predTestVal = predTestVal.detach().numpy()
    predTestVal1 = predTestVal[:, 0]
    predTestVal2 = predTestVal[:, 1]
    # Price file: column 0 = true value, column 1 = prediction.
    predValuePrice = np.zeros((predTestVal1.size, 2))
    for i in range(predTestVal1.size):
        predValuePrice[i][0] = y_test1[i]
        predValuePrice[i][1] = predTestVal1[i]
    np.savetxt(os.path.join(default_path, 'MS, warm_start=' + str(warm_start_epoch) + '/MS_prices(' + str(testi) + ').txt'), predValuePrice, fmt="%.2f")
    
    # Weight file: column 0 = true value, column 1 = prediction.
    predValueWeight = np.zeros((predTestVal2.size, 2))
    for i in range(predTestVal2.size):
        predValueWeight[i][0] = y_test2[i]
        predValueWeight[i][1] = predTestVal2[i]
    np.savetxt(os.path.join(default_path, 'MS, warm_start=' + str(warm_start_epoch) + '/MS_weights(' + str(testi) + ').txt'), predValueWeight, fmt="%.2f")
    
    HSD_rslt = 'test: ' + str(np.mean(val_rslt)) + ' MSE: ' + str(mean_squared_error(y_test, predTestVal))
    print(HSD_rslt, end=" ")
    print ('Elapsed time: ' + str(end-start))
    # Stores the summed (not averaged) test regret for this split.
    recordBest[0][testi] = np.sum(val_rslt)

print(recordBest)
