import os
import sys
import linear_relax as LP_relax_file
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time
from numpy import linalg as LA

# Problem dimensions are owned by linear_relax; mirror them as module globals
# so the evaluation loop below can size and slice per-instance arrays.
item_num, month_num, x_num, y_num, z_num, var_num = (
    LP_relax_file.item_num,
    LP_relax_file.month_num,
    LP_relax_file.x_num,
    LP_relax_file.y_num,
    LP_relax_file.z_num,
    LP_relax_file.var_num,
)


# Command line: CAPACITY TRANS_FEE_PERCENT STARTMARK ENDMARK
#   CAPACITY / TRANS_FEE_PERCENT are pushed into linear_relax before any solve;
#   STARTMARK / ENDMARK bound the test-split indices, evaluated as
#   range(STARTMARK, ENDMARK) (ENDMARK exclusive).
if len(sys.argv) < 5:
    # Fail with a usage line instead of a bare IndexError from sys.argv[k].
    sys.exit(
        "usage: "
        + (os.path.basename(sys.argv[0]) if sys.argv else "script")
        + " CAPACITY TRANS_FEE_PERCENT STARTMARK ENDMARK"
    )

# Parse everything first so a malformed argument aborts before linear_relax's
# global settings are modified.
cap = int(sys.argv[1])
trans_fee_percent = float(sys.argv[2])
startmark = int(sys.argv[3])
endmark = int(sys.argv[4])

LP_relax_file.set_capacity(cap)
LP_relax_file.set_trans_fee_percent(trans_fee_percent)


# Number of instances per prediction file; the loop below slices each file's
# columns into instance_num chunks of x_num rows.
instance_num = 30
# Predictors to evaluate; each one's files are read from <default_path>/<name>/.
methodList = ['Ridge', 'knn', 'CART', 'RF']

# Datasets are looked up relative to the parent of the current working directory.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
# NOTE(review): the directory name hard-codes item_num=5 while month_num comes
# from linear_relax; confirm this is intended whenever item_num != 5.
default_path = os.path.join(
    dataset_path,
    f'data/item_num=5, month_num={month_num}/',
)

print(
    "item_num: ", item_num,
    " month_num: ", month_num,
    " trans_fee_percent: ", trans_fee_percent,
    " capacity: ", LP_relax_file.capacity,
)

# For every predictor and every test split in range(startmark, endmark):
#   * load the price and weight files (column 0 used as truth, column 1 as prediction),
#   * compute the true objectives with LP_relax_file.actual_obj,
#   * evaluate each instance with LP_relax_file.correction_single_obj,
#   * print the pooled prediction MSE, the mean true objective (TOV), the mean
#     corrected objective (EOV), the mean absolute gap between the two (PReg)
#     and the wall-clock runtime of the split.
for methodName in methodList:
    print(methodName)
    method_dir = os.path.join(default_path, methodName)
    for testmark in range(startmark, endmark):
        start_time = time.time()
        price_temp = np.loadtxt(os.path.join(method_dir, f'{methodName}_prices({testmark}).txt'))
        weight_temp = np.loadtxt(os.path.join(method_dir, f'{methodName}_weights({testmark}).txt'))

        true_price_total = price_temp[:, 0]
        pred_price_total = price_temp[:, 1]
        true_weight_total = weight_temp[:, 0]
        pred_weight_total = weight_temp[:, 1]

        # Each instance occupies x_num consecutive rows. Fail with a clear
        # message instead of an IndexError halfway through the instance loop.
        rows_needed = instance_num * x_num
        rows_found = min(len(true_price_total), len(true_weight_total))
        if rows_found < rows_needed:
            raise ValueError(
                f"{methodName} testmark {testmark}: expected at least "
                f"{rows_needed} rows ({instance_num} instances x {x_num}), "
                f"found {rows_found}"
            )

        true_price_weight = np.vstack((true_price_total, true_weight_total))
        pred_price_weight = np.vstack((pred_price_total, pred_weight_total))

        true_obj = LP_relax_file.actual_obj(true_price_total, true_weight_total, n_instance=instance_num)

        corr_obj_list = []
        for testNum in range(instance_num):
            rows = slice(testNum * x_num, (testNum + 1) * x_num)
            # np.array(..., dtype=float) copies, so correction_single_obj gets
            # independent float64 buffers rather than views into the loaded columns.
            true_price = np.array(true_price_total[rows], dtype=float)
            pred_price = np.array(pred_price_total[rows], dtype=float)
            true_weight = np.array(true_weight_total[rows], dtype=float)
            pred_weight = np.array(pred_weight_total[rows], dtype=float)
            corr_obj_list.append(
                LP_relax_file.correction_single_obj(pred_price, pred_weight, true_price, true_weight)
            )

        runtime = time.time() - start_time

        # sklearn reads the (2, N) stacks as 2 samples x N outputs; with the default
        # uniform averaging this equals the MSE over all 2N entries, i.e. prices and
        # weights are pooled (and their scales mixed) into a single number.
        mse = mean_squared_error(true_price_weight, pred_price_weight)
        tov = sum(true_obj) / instance_num
        eov = sum(corr_obj_list) / instance_num
        preg = sum(abs(true_obj - np.array(corr_obj_list))) / instance_num

        print(testmark, "MSE: ", mse, end=" ")
        print("TOV: ", tov, "EOV: ", eov, "PReg: ", preg, "runtime: ", runtime)
    
