import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time

# ---------------------------------------------------------------------------
# Module-level problem dimensions (rebound by change_month_num /
# reset_month_num and the setters below).
#
# Decision-vector layout used by the builders in this module:
#   [ x (x_num) | y (y_num) | z (z_num) ]              -> var_num entries
#   [ x | y | z | m (y_num) ]                          -> var_num_intOpt entries
# x has one entry per (month, item), indexed month * item_num + item.
# y links consecutive months (x_{i,j} = x_{i-1,j} + y_{i,j}), z bounds |y|
# (y - z <= 0 and -y - z <= 0), and m is the extra non-negative block used
# by the *_intOpt builders.
#
# NOTE: trans_fee_percent, capacity and version_num are not initialised here;
# they only exist after set_trans_fee_percent() / set_capacity() is called.
# ---------------------------------------------------------------------------
item_num, total_month_num = 5, 4
month_num = total_month_num                  # current horizon length
x_num = item_num * month_num
y_num = z_num = item_num * (month_num - 1)   # one block per month transition
var_num = x_num + y_num + z_num
var_num_intOpt = var_num + y_num             # adds the m block
lower_bound = -100000                        # passed as lb= to Gurobi addVars
relax_val = 1e-5                             # NOTE(review): only referenced in commented-out code here



def mkdir(default_path, folder_name):
    """Create ``default_path/folder_name`` and any missing parents.

    Calling it repeatedly is safe: an existing directory is left untouched.
    ``os.makedirs(..., exist_ok=True)`` replaces the former
    ``exists()``-then-``makedirs()`` sequence. That sequence could fail with
    ``FileExistsError`` if another process created the directory in between
    the two calls.

    Args:
        default_path: parent directory.
        folder_name: name (or relative path) of the directory to create.

    Returns:
        The joined path. The previous version returned ``None``, so callers
        that ignore the result are unaffected.

    Raises:
        FileExistsError: if the path exists but is not a directory.
    """
    path = os.path.join(default_path, folder_name)
    os.makedirs(path, exist_ok=True)
    return path

def set_capacity(cap):
    """Store the module-wide ``capacity`` and select the matching ``version_num``.

    Known capacities map to versions 25 -> 1, 50 -> 2, 75 -> 3 and 100 -> 4.
    Any other value still sets ``capacity`` but leaves ``version_num``
    unchanged: it keeps what an earlier call assigned, or stays undefined if
    no earlier call matched.
    """
    global capacity
    global version_num
    capacity = cap
    for known_cap, version in ((25, 1), (50, 2), (75, 3), (100, 4)):
        if cap == known_cap:
            version_num = version
            break

def set_trans_fee_percent(trans_fee):
    """Set the module-wide ``trans_fee_percent``.

    The objective and capacity builders multiply the weights of the z
    variables by this value.
    """
    global trans_fee_percent
    trans_fee_percent = trans_fee
  
def change_month_num(cur_month_num):
    """Resize the module-level dimensions for a ``cur_month_num``-month horizon.

    Unlike the full-horizon layout (see ``reset_month_num``), y and z here get
    one block per month, ``item_num * cur_month_num`` entries each. This
    matches the ``*_latter_month*`` builders, which also link the first
    month to a previous solution.
    """
    global month_num, x_num, y_num, z_num, var_num, var_num_intOpt
    per_block = item_num * cur_month_num
    month_num = cur_month_num
    x_num = per_block
    y_num = per_block
    z_num = per_block
    var_num = x_num + y_num + z_num
    var_num_intOpt = var_num + y_num


def reset_month_num():
    """Restore the full-horizon dimensions (``total_month_num`` months).

    y and z get ``month_num - 1`` blocks, one per transition between
    consecutive months.
    """
    global month_num, x_num, y_num, z_num, var_num, var_num_intOpt
    month_num = total_month_num
    x_num = item_num * month_num
    transitions = item_num * (month_num - 1)
    y_num = transitions
    z_num = transitions
    var_num = x_num + y_num + z_num
    var_num_intOpt = var_num + y_num


def gen_constraints(cur_stage, x_prev_sol=None):
    """Build the structural constraints of the full-horizon (non-intOpt) model.

    Variables are laid out as ``v = [x (x_num) | y (y_num) | z (z_num)]``
    using the module-level sizes; x is indexed ``month * item_num + item``.

    Equalities ``A @ v == b``:
      * One row per month ``m = 2..month_num`` and item ``j``:
        ``x[m, j] - x[m-1, j] - y[m-1, j] = 0``. Rows beyond
        ``item_num * (month_num - 1)`` stay all-zero.
      * If ``cur_stage != 0``, the first ``cur_stage * item_num`` entries of
        x are also pinned to ``x_prev_sol[:cur_stage * item_num]``.

    Inequalities ``G_fix @ v <= h_fix`` (``h_fix`` is all zeros):
      * ``y - z <= 0`` and ``-y - z <= 0``, i.e. ``z >= |y|``;
      * ``-x <= 0`` and ``-z <= 0``.

    Both stage cases used to rebuild the same rows in two copies of the code;
    they now share one path and the stage-specific rows are appended last.

    Args:
        cur_stage: number of leading months whose x values are fixed
            (0 = nothing fixed).
        x_prev_sol: previous solution. Its first ``cur_stage * item_num``
            entries are used; it is required when ``cur_stage != 0``.

    Returns:
        Tuple ``(A, b, G_fix, h_fix)`` of numpy arrays.

    Raises:
        ValueError: if ``cur_stage != 0`` and ``x_prev_sol`` is None.
    """
    n_links = item_num * (month_num - 1)   # rows that are actually populated
    r = np.arange(n_links)
    z_off = x_num + y_num                  # first column of the z block

    # x[m] - x[m-1] - y[m-1] = 0
    A = np.zeros((y_num, var_num))
    A[r, r + item_num] = 1
    A[r, r] = -1
    A[r, x_num + r] = -1
    b = np.zeros(y_num)

    # y - z <= 0
    G1 = np.zeros((z_num, var_num))
    G1[r, x_num + r] = 1
    G1[r, z_off + r] = -1

    # -y - z <= 0
    G2 = np.zeros((z_num, var_num))
    G2[r, x_num + r] = -1
    G2[r, z_off + r] = -1

    # x >= 0 and z >= 0, written as -x <= 0 and -z <= 0 (y stays free).
    G3 = np.zeros((x_num + z_num, var_num))
    rx = np.arange(x_num)
    G3[rx, rx] = -1
    rz = np.arange(var_num - z_off)
    G3[x_num + rz, z_off + rz] = -1

    G_fix = np.concatenate([G1, G2, G3], axis=0)
    h_fix = np.zeros(G_fix.shape[0])

    if cur_stage != 0:
        if x_prev_sol is None:
            raise ValueError("x_prev_sol is required when cur_stage != 0")
        # x_{ij} = x^{t-1}_{ij} for the months already decided.
        n_fixed = cur_stage * item_num
        A_fix = np.zeros((n_fixed, var_num))
        rf = np.arange(n_fixed)
        A_fix[rf, rf] = 1
        A = np.concatenate([A, A_fix], axis=0)
        b = np.concatenate([b, x_prev_sol[:n_fixed]], axis=0)

    return A, b, G_fix, h_fix
    

def gen_KS_obj_cap_cons(cur_stage, value, weight):
    """Build the objective and capacity constraints of the full-horizon model.

    Objective ``c`` over ``v = [x | y | z]`` (callers maximise it):
      * ``c_x = value``, minus ``weight`` on the first month's x entries and
        plus ``weight`` on the last month's x entries;
      * ``c_y[t] = -weight[item_num + t]`` and
        ``c_z[t] = -trans_fee_percent * weight[item_num + t]`` for each
        month-to-month link ``t``.

    Capacity ``cap_G @ v <= cap_h`` (every bound equals ``capacity``) has one
    row per month ``k`` in ``[cur_stage, month_num)``. Each row sums the first
    month's ``weight * x`` plus ``weight * y + trans_fee_percent * weight * z``
    over the links into months ``2..k+1`` (1-based).

    The former ``cur_stage == 0`` branch was exactly the ``cur_stage = 0``
    case of the general branch, so both now share one code path.

    Args:
        cur_stage: index of the first month that still gets a capacity row.
        value, weight: indexable sequences with at least ``x_num`` entries,
            laid out ``month * item_num + item``.

    Returns:
        ``(c, cap_G, cap_h)`` numpy arrays.
    """
    # Objective, x block.
    c_x = np.zeros(x_num)
    for i in range(x_num):
        c_x[i] = value[i]
    for i in range(item_num):
        c_x[i] = c_x[i] - weight[i]
    for i in range(item_num * (month_num - 1), x_num):
        c_x[i] = c_x[i] + weight[i]

    # Objective, y and z blocks: one entry per link into months 2..month_num.
    c_y = np.zeros(y_num)
    c_z = np.zeros(z_num)
    for k in range(item_num, item_num * month_num):
        c_y[k - item_num] = -weight[k]
        c_z[k - item_num] = -trans_fee_percent*weight[k]

    c = np.concatenate([c_x, c_y, c_z], axis=0)

    # Capacity rows for months cur_stage .. month_num - 1.
    n_rows = month_num - cur_stage
    cap_h = np.ones(n_rows) * capacity
    cap_G = np.zeros((n_rows, var_num))
    for row, k in enumerate(range(cur_stage, month_num)):
        for i in range(item_num):
            cap_G[row][i] = weight[i]
        for t in range(item_num, item_num * (k + 1)):
            col = t - item_num
            cap_G[row][x_num + col] = weight[t]
            cap_G[row][x_num + y_num + col] = trans_fee_percent * weight[t]

    return c, cap_G, cap_h


def gen_constraints_latter_months(cur_month_num, x_prev_sol):
    """Structural constraints for a sub-horizon that continues a previous plan.

    Rows are filled for ``item_num * month_num`` y/z entries, so this assumes
    the per-month layout set by ``change_month_num`` (TODO confirm callers).
    ``cur_month_num`` is accepted but not used.

    Equalities ``A @ v == b`` over ``v = [x | y | z]``:
      * first month: ``x[1, j] - y[1, j] = x_prev_sol[j]``;
      * later months: ``x[m, j] - x[m-1, j] - y[m, j] = 0``.
    Inequalities ``G_fix @ v <= h_fix`` (all-zero RHS): ``y - z <= 0``,
    ``-y - z <= 0``, ``-x <= 0`` and ``-z <= 0``.
    """
    idx = np.arange(item_num * month_num)
    later = idx[item_num:]
    z_off = x_num + y_num

    A = np.zeros((y_num, var_num))
    A[idx, idx] = 1
    A[later, later - item_num] = -1
    A[idx, x_num + idx] = -1
    b = np.zeros(y_num)
    for j in range(item_num):
        b[j] = x_prev_sol[j]

    # y - z <= 0
    G1 = np.zeros((z_num, var_num))
    G1[idx, x_num + idx] = 1
    G1[idx, z_off + idx] = -1

    # -y - z <= 0
    G2 = np.zeros((z_num, var_num))
    G2[idx, x_num + idx] = -1
    G2[idx, z_off + idx] = -1

    # -x <= 0, -z <= 0
    G3 = np.zeros((x_num + z_num, var_num))
    xs = np.arange(x_num)
    G3[xs, xs] = -1
    zs = np.arange(var_num - z_off)
    G3[x_num + zs, z_off + zs] = -1

    G_fix = np.concatenate([G1, G2, G3], axis=0)
    h_fix = np.zeros(G_fix.shape[0])
    return A, b, G_fix, h_fix


def gen_KS_obj_cap_cons_latter_month(cur_month_num, prev_cost, value, weight):
    """Objective and capacity rows for a continuation sub-horizon.

    Objective ``c`` over ``v = [x | y | z]``:
      * ``c_x = value`` plus ``weight`` on the final month's x entries;
      * ``c_y = -weight`` and ``c_z = -trans_fee_percent * weight``, one entry
        per (month, item).

    Capacity ``cap_G @ v <= cap_h``: row ``r`` (0-based) sums
    ``weight * y + trans_fee_percent * weight * z`` over months ``0..r``,
    bounded by ``capacity - prev_cost``. ``cur_month_num`` is unused.
    """
    n_flat = item_num * month_num
    last_month_start = item_num * (month_num - 1)

    c_x = np.zeros(x_num)
    for pos in range(x_num):
        c_x[pos] = value[pos]
    for pos in range(last_month_start, x_num):
        c_x[pos] = c_x[pos] + weight[pos]

    c_y = np.zeros(y_num)
    c_z = np.zeros(z_num)
    for pos in range(n_flat):
        c_y[pos] = -weight[pos]
        c_z[pos] = -trans_fee_percent*weight[pos]

    c = np.concatenate([c_x, c_y, c_z], axis=0)

    cap_h = np.ones(month_num) * (capacity-prev_cost)
    cap_G = np.zeros((month_num, var_num))
    z_off = x_num + y_num
    for row in range(month_num):
        for pos in range(item_num * (row + 1)):
            cap_G[row, x_num + pos] = weight[pos]
            cap_G[row, z_off + pos] = trans_fee_percent * weight[pos]
    return c, cap_G, cap_h


def gen_constraints_latter_months_intOpt(cur_month_num, x_prev_sol):
    """Structural constraints of a continuation sub-horizon, intOpt variant.

    Variables are ``v = [x | y | z | m]``. In the intOpt formulation the free
    month-to-month change is written as ``y - m``, with m the non-negative
    "sell" block. ``cur_month_num`` is accepted but not used.

    Equalities ``A @ v == b``:
      * first month: ``x[1, j] - y[1, j] + m[1, j] = x_prev_sol[j]``;
      * later months: ``x[m, j] - x[m-1, j] - y[m, j] + m[m, j] = 0``.
    Inequalities ``G_fix @ v <= h_fix`` (all-zero RHS): ``y - z <= 0`` and
    ``m - z <= 0``. No explicit non-negativity rows are added here.
    """
    idx = np.arange(item_num * month_num)
    later = idx[item_num:]
    z_off = x_num + y_num

    A = np.zeros((y_num, var_num_intOpt))
    A[idx, idx] = 1
    A[later, later - item_num] = -1
    A[idx, x_num + idx] = -1
    A[idx, var_num + idx] = 1
    b = np.zeros(y_num)
    for j in range(item_num):
        b[j] = x_prev_sol[j]

    # y - z <= 0
    G1 = np.zeros((z_num, var_num_intOpt))
    G1[idx, x_num + idx] = 1
    G1[idx, z_off + idx] = -1

    # m - z <= 0 (replaces -y - z <= 0 of the signed model)
    G2 = np.zeros((z_num, var_num_intOpt))
    G2[idx, var_num + idx] = 1
    G2[idx, z_off + idx] = -1

    G_fix = np.concatenate([G1, G2], axis=0)
    h_fix = np.zeros(G_fix.shape[0])
    return A, b, G_fix, h_fix


def gen_KS_obj_cap_cons_latter_month_intOpt(cur_month_num, prev_cost, value, weight):
    """Objective and capacity rows of a continuation sub-horizon (intOpt).

    Objective ``c`` over ``v = [x | y | z | m]``:
      * ``c_x = value`` plus ``weight`` on the final month's x entries;
      * ``c_y = -weight``, ``c_z = -trans_fee_percent * weight`` and
        ``c_m = +weight`` (m holds the non-negative "sell" quantities).

    Capacity ``cap_G @ v <= cap_h``: row ``r`` (0-based) sums
    ``weight * y + trans_fee_percent * weight * z - weight * m`` over months
    ``0..r``, bounded by ``capacity - prev_cost``. ``cur_month_num`` is unused.
    """
    n_flat = item_num * month_num
    last_month_start = item_num * (month_num - 1)

    c_x = np.zeros(x_num)
    for pos in range(x_num):
        c_x[pos] = value[pos]
    for pos in range(last_month_start, x_num):
        c_x[pos] = c_x[pos] + weight[pos]

    c_y = np.zeros(y_num)
    c_z = np.zeros(z_num)
    c_m = np.zeros(y_num)
    for pos in range(n_flat):
        c_y[pos] = -weight[pos]
        c_z[pos] = -trans_fee_percent*weight[pos]
        c_m[pos] = weight[pos]

    c = np.concatenate([c_x, c_y, c_z, c_m], axis=0)

    cap_h = np.ones(month_num) * (capacity-prev_cost)
    cap_G = np.zeros((month_num, var_num_intOpt))
    z_off = x_num + y_num
    for row in range(month_num):
        for pos in range(item_num * (row + 1)):
            cap_G[row, x_num + pos] = weight[pos]
            cap_G[row, z_off + pos] = trans_fee_percent * weight[pos]
            cap_G[row, var_num + pos] = -weight[pos]
    return c, cap_G, cap_h


# cur_month_num: length of the current (remaining) horizon, presumably the value last passed to change_month_num; the related offset is cur_NN = total_month_num - cur_month_num.
def gen_constraints_latter_months_full_intOpt(cur_month_num, cur_stage, x_prev_sol, x_prev_stage=None):
    """Continuation-sub-horizon constraints (intOpt) with optional month pinning.

    The linking equalities and the ``y - z <= 0`` / ``m - z <= 0`` rows are
    exactly those built by ``gen_constraints_latter_months_intOpt``. This
    function used to carry a verbatim copy of that code, so it now calls the
    sibling instead.

    When ``cur_stage`` differs from ``cur_NN = total_month_num - cur_month_num``,
    the first ``(cur_stage - cur_NN) * item_num`` entries of x are also fixed
    to ``x_prev_stage``, rounded to 2 decimals.

    Args:
        cur_month_num: length of the current sub-horizon.
        cur_stage: stage index on the global timeline (compared with cur_NN).
        x_prev_sol: previous solution; its first ``item_num`` entries seed the
            first month's linking equality.
        x_prev_stage: x values to pin; required when any month is pinned.

    Returns:
        ``(A, b, G_fix, h_fix)`` numpy arrays.

    Raises:
        ValueError: if ``cur_stage < cur_NN``, or if months must be pinned
            but ``x_prev_stage`` is None.
    """
    A1, b1, G_fix, h_fix = gen_constraints_latter_months_intOpt(cur_month_num, x_prev_sol)

    if cur_stage == total_month_num - cur_month_num:
        return A1, b1, G_fix, h_fix

    # x_{ij} = x^{t-1}_{ij}: the already-decided months are pinned with an
    # exact equality (values rounded to 2 decimals). The commented-out
    # +/- relax_val bound variant is not used.
    fix_month_num = cur_stage - total_month_num + cur_month_num
    if fix_month_num < 0:
        raise ValueError(
            "cur_stage must be >= total_month_num - cur_month_num, got "
            f"cur_stage={cur_stage}, cur_month_num={cur_month_num}")
    n_fixed = fix_month_num * item_num
    if n_fixed > 0 and x_prev_stage is None:
        raise ValueError("x_prev_stage is required when months must be pinned")

    A2 = np.zeros((n_fixed, var_num_intOpt))
    diag = np.arange(n_fixed)
    A2[diag, diag] = 1
    b2 = np.zeros(n_fixed)
    for i in range(n_fixed):
        b2[i] = round(x_prev_stage[i], 2)

    A = np.concatenate([A1, A2], axis=0)
    b = np.concatenate([b1, b2], axis=0)
    return A, b, G_fix, h_fix


def gen_KS_obj_cap_cons_latter_month_full_intOpt(cur_month_num, cur_stage, prev_cost, value, weight):
    """Objective and capacity rows of a continuation sub-horizon (intOpt),
    with capacity rows only for the months from ``cur_stage`` on.

    The objective ``c`` is always identical to the one built by
    ``gen_KS_obj_cap_cons_latter_month_intOpt``. When
    ``cur_stage == cur_NN = total_month_num - cur_month_num``, the capacity
    rows are identical as well, so the sibling's result is returned
    unchanged. This function used to carry a verbatim copy of that code.

    Otherwise there are ``total_month_num - cur_stage`` capacity rows. The row
    for global month ``k = cur_stage + 1 + row`` sums
    ``weight * y + trans_fee_percent * weight * z - weight * m`` over the
    first ``k - cur_NN`` local months, bounded by ``capacity - prev_cost``.

    Returns:
        ``(c, cap_G, cap_h)`` numpy arrays.
    """
    c, cap_G, cap_h = gen_KS_obj_cap_cons_latter_month_intOpt(cur_month_num, prev_cost, value, weight)

    cur_NN = total_month_num - cur_month_num
    if cur_stage == cur_NN:
        return c, cap_G, cap_h

    n_rows = total_month_num - cur_stage
    cap_h = np.ones(n_rows) * (capacity-prev_cost)
    cap_G = np.zeros((n_rows, var_num_intOpt))
    z_off = x_num + y_num
    for row in range(n_rows):
        # Local (item, month) entries covered by this row; empty if <= 0.
        n_cover = item_num * (cur_stage + 1 + row - cur_NN)
        for pos in range(n_cover):
            cap_G[row][x_num + pos] = weight[pos]
            cap_G[row][z_off + pos] = trans_fee_percent * weight[pos]
            cap_G[row][var_num + pos] = -weight[pos]
    return c, cap_G, cap_h

    
def gen_constraints_intOpt(cur_stage, x_prev_sol=None):
    """Structural constraints of the full-horizon intOpt model.

    Variables are ``v = [x | y | z | m]``. The intOpt formulation only allows
    non-negative variables, so the month-to-month change is written ``y - m``.

    Equalities ``A @ v == b``:
      * for months ``m = 2..month_num`` and item ``j``:
        ``x[m, j] - x[m-1, j] - y[m-1, j] + m[m-1, j] = 0``;
      * if ``cur_stage != 0``, the first ``cur_stage * item_num`` entries of
        x are pinned with an exact equality to ``x_prev_sol``, rounded to
        2 decimals.
    Inequalities ``G_fix @ v <= h_fix`` (all-zero RHS): ``y - z <= 0`` and
    ``m - z <= 0``. No explicit non-negativity rows are added here.
    """
    r = np.arange(item_num * (month_num - 1))
    z_off = x_num + y_num

    A = np.zeros((y_num, var_num_intOpt))
    A[r, r + item_num] = 1
    A[r, r] = -1
    A[r, x_num + r] = -1
    A[r, var_num + r] = 1
    b = np.zeros(y_num)

    # y - z <= 0
    G1 = np.zeros((z_num, var_num_intOpt))
    G1[r, x_num + r] = 1
    G1[r, z_off + r] = -1

    # m - z <= 0
    G2 = np.zeros((z_num, var_num_intOpt))
    G2[r, var_num + r] = 1
    G2[r, z_off + r] = -1

    G_fix = np.concatenate([G1, G2], axis=0)
    h_fix = np.zeros(G_fix.shape[0])

    if cur_stage == 0:
        return A, b, G_fix, h_fix

    n_fixed = cur_stage * item_num
    A_pin = np.zeros((n_fixed, var_num_intOpt))
    diag = np.arange(n_fixed)
    A_pin[diag, diag] = 1
    b_pin = np.zeros(n_fixed)
    for k in range(n_fixed):
        b_pin[k] = round(x_prev_sol[k], 2)

    A = np.concatenate([A, A_pin], axis=0)
    b = np.concatenate([b, b_pin], axis=0)
    return A, b, G_fix, h_fix


def gen_KS_obj_cap_cons_intOpt(cur_stage, value, weight):
    """Objective and capacity rows of the full-horizon intOpt model.

    Same structure as the non-intOpt builder, over ``v = [x | y | z | m]``.
    The m block (non-negative "sell" quantities) gets objective coefficient
    ``+weight`` and capacity coefficient ``-weight``.

    Objective:
      * ``c_x = value``, minus ``weight`` on the first month's x entries and
        plus ``weight`` on the last month's x entries;
      * for each link ``t`` into months 2..month_num:
        ``c_y[t] = -w``, ``c_z[t] = -trans_fee_percent * w``, ``c_m[t] = w``
        with ``w = weight[item_num + t]``.

    Capacity ``cap_G @ v <= cap_h`` (bounds equal ``capacity``) has one row
    per month ``k`` in ``[cur_stage, month_num)``. Each row sums the first
    month's ``weight * x`` plus
    ``weight * y + trans_fee_percent * weight * z - weight * m`` over the
    links into months ``2..k+1`` (1-based).

    The former ``cur_stage == 0`` branch was exactly the ``cur_stage = 0``
    case of the general branch, so both now share one code path.

    Returns:
        ``(c, cap_G, cap_h)`` numpy arrays.
    """
    c_x = np.zeros(x_num)
    for i in range(x_num):
        c_x[i] = value[i]
    for i in range(item_num):
        c_x[i] = c_x[i] - weight[i]
    for i in range(item_num * (month_num - 1), x_num):
        c_x[i] = c_x[i] + weight[i]

    c_y = np.zeros(y_num)
    c_z = np.zeros(z_num)
    c_m = np.zeros(y_num)
    for k in range(item_num, item_num * month_num):
        c_y[k - item_num] = -weight[k]
        c_z[k - item_num] = -trans_fee_percent*weight[k]
        c_m[k - item_num] = weight[k]

    c = np.concatenate([c_x, c_y, c_z, c_m], axis=0)

    n_rows = month_num - cur_stage
    cap_h = np.ones(n_rows) * capacity
    cap_G = np.zeros((n_rows, var_num_intOpt))
    for row, k in enumerate(range(cur_stage, month_num)):
        for i in range(item_num):
            cap_G[row][i] = weight[i]
        for t in range(item_num, item_num * (k + 1)):
            col = t - item_num
            cap_G[row][x_num + col] = weight[t]
            cap_G[row][x_num + y_num + col] = trans_fee_percent * weight[t]
            cap_G[row][var_num + col] = -weight[t]

    return c, cap_G, cap_h

    
def actual_obj(valueTemp, weightTemp, n_instance):
    """Solve the full-horizon integer model per instance and collect objectives.

    ``valueTemp`` and ``weightTemp`` are flat sequences holding ``n_instance``
    consecutive chunks of ``item_num * month_num`` entries. Chunk ``num``
    supplies the values/weights of instance ``num``. Each instance is solved
    with Gurobi using the stage-0 constraints (nothing fixed) and its
    objective value is recorded.

    Changes from the previous version:
      * The per-variable solution readback, whose result was never used, has
        been removed.
      * Each Gurobi model is disposed as soon as its objective has been read,
        instead of waiting for garbage collection.

    Returns:
        numpy array of length ``n_instance`` with the objective values.

    Raises:
        Whatever gurobipy raises when ``objVal`` is unavailable (e.g. the
        model has no solution).
    """
    A, b, G_fix, h_fix = gen_constraints(0)
    A_rows = A.tolist()
    b_vals = b.tolist()
    chunk = item_num * month_num
    obj_list = []
    for num in range(n_instance):
        start = num * chunk
        weight = np.zeros(chunk)
        value = np.zeros(chunk)
        for i in range(chunk):
            weight[i] = weightTemp[start + i]
            value[i] = valueTemp[start + i]

        c, cap_G, cap_h = gen_KS_obj_cap_cons(0, value, weight)
        G_rows = np.concatenate([cap_G, G_fix], axis=0).tolist()
        h_vals = np.concatenate([cap_h, h_fix], axis=0).tolist()
        c_list = c.tolist()

        m = gp.Model()
        try:
            m.setParam('OutputFlag', 0)
            x = m.addVars(var_num, vtype=GRB.INTEGER, lb=lower_bound, name='x')
            m.setObjective(x.prod(c_list), GRB.MAXIMIZE)
            for i, row in enumerate(A_rows):
                m.addConstr(x.prod(row) == b_vals[i])
            for i, row in enumerate(G_rows):
                m.addConstr(x.prod(row) <= h_vals[i])
            m.optimize()
            obj_list.append(m.objVal)
        finally:
            # One native model is created per instance; free it right away.
            m.dispose()

    return np.array(obj_list)


def get_init_plan(pred_value, pred_weight, true_weight, A, b, G_fix, h_fix):
    """Solve the stage-0 integer model with predicted data and return the plan.

    The first month's weights come from ``true_weight`` and the remaining
    months' weights from ``pred_weight``; all values come from ``pred_value``.
    ``A, b, G_fix, h_fix`` are the structural constraints, presumably from
    ``gen_constraints(0)`` (TODO confirm callers).

    Changes from the previous version:
      * The solution is read only when Gurobi reports one (``SolCount > 0``).
        The former bare ``except:`` also hid unrelated errors, including
        KeyboardInterrupt.
      * The model is disposed once the solution has been copied out.

    Returns:
        ``(x_sol, y_sol, z_sol)`` numpy arrays of lengths x_num, y_num and
        z_num. If no solution is available, all three are zero vectors.
    """
    weight = np.concatenate([true_weight[:item_num], pred_weight[item_num:]], axis=0)
    c, cap_G, cap_h = gen_KS_obj_cap_cons(0, pred_value, weight)
    G_rows = np.concatenate([cap_G, G_fix], axis=0).tolist()
    h_vals = np.concatenate([cap_h, h_fix], axis=0).tolist()
    A_rows = A.tolist()
    b_vals = b.tolist()
    c_list = c.tolist()

    m = gp.Model()
    try:
        m.setParam('OutputFlag', 0)
        x = m.addVars(var_num, vtype=GRB.INTEGER, lb=lower_bound, name='x')
        m.setObjective(x.prod(c_list), GRB.MAXIMIZE)
        for i, row in enumerate(A_rows):
            m.addConstr(x.prod(row) == b_vals[i])
        for i, row in enumerate(G_rows):
            m.addConstr(x.prod(row) <= h_vals[i])
        m.optimize()

        if m.SolCount > 0:
            sol = np.array([x[i].x for i in range(var_num)], dtype=float)
        else:
            # Infeasible / no incumbent: fall back to an all-zero plan.
            sol = np.zeros(var_num)
    finally:
        m.dispose()

    x_sol = sol[:x_num]
    y_sol = sol[x_num:x_num + y_num]
    z_sol = sol[x_num + y_num:]
    return x_sol, y_sol, z_sol

def get_t_updated_plan(t, x_prev_sol, y_prev_sol, z_prev_sol, prev_prof, prev_cost, pred_value, pred_weight, true_value, true_weight, A, b, G_fix, h_fix):
    """Re-solve the integer plan at stage ``t``, keeping earlier decisions fixed.

    The objective is built from realised values for the first ``item_num*t``
    entries and predicted values afterwards; the capacity rows use realised
    weights for the first ``item_num*(t+1)`` entries and predicted weights
    afterwards.  The first ``(t-1)*item_num`` columns of each of the x / y / z
    variable groups are dropped from the model, and the matching entries of
    the previous plan are spliced back in front of the new solution.

    Args:
        t: current stage (``correction_single_obj`` passes ``1 .. month_num-1``).
        x_prev_sol, y_prev_sol, z_prev_sol: previous plan; only the first
            ``(t-1)*item_num`` entries of each are reused.
        prev_prof, prev_cost: already-realised profit and cost.  They enter the
            objective as the constant ``prev_prof - prev_cost``, and
            ``prev_cost`` is also subtracted from the right-hand side of every
            capacity row.
        pred_value, pred_weight, true_value, true_weight: predicted / realised
            item values and weights.
        A, b: equality constraints for this stage.  ``b`` is copied before any
            entry is zeroed, so the caller's array is left untouched.
        G_fix, h_fix: extra ``G x <= h`` rows appended after the capacity rows.

    Returns:
        ``(x_sol, y_sol, z_sol)`` numpy arrays.  When Gurobi has no solution to
        report (``SolCount == 0``: infeasible, unbounded, no incumbent) three
        zero vectors of length ``x_num``, ``y_num`` and ``z_num`` are returned.
    """
    # Realised data for the part already observed, predictions for the rest.
    value = np.concatenate([true_value[:item_num*t], pred_value[item_num*t:]], axis=0)
    weight = np.concatenate([true_weight[:item_num*(t+1)], pred_weight[item_num*(t+1):]], axis=0)
    c, cap_G, cap_h = gen_KS_obj_cap_cons(t, value, weight)
    G = np.concatenate([cap_G, G_fix], axis=0)
    h = np.concatenate([cap_h, h_fix], axis=0)
    cap_G_row_size = cap_G.shape[0]
    G_row_size = G.shape[0]
    A_row_size = A.shape[0]

    # Keep only the still-free columns of each variable group (x | y | z).
    fixed = (t-1)*item_num
    G_use = np.concatenate([G[:, fixed:x_num],
                            G[:, x_num+fixed:x_num+y_num],
                            G[:, x_num+y_num+fixed:]], axis=1)
    A_use = np.concatenate([A[:, fixed:x_num],
                            A[:, x_num+fixed:x_num+y_num],
                            A[:, x_num+y_num+fixed:]], axis=1)
    c_use = np.concatenate([c[fixed:x_num],
                            c[x_num+fixed:x_num+y_num],
                            c[x_num+y_num+fixed:]], axis=0)

    # Zero the first ``fixed`` rows of A_use and b[y_num : y_num+fixed].
    # NOTE(review): presumably these are the rows that only involve the
    # dropped columns -- confirm against gen_constraints.  ``b`` is copied so
    # the caller's array is not mutated as a side effect.
    b = b.copy()
    for i in range(fixed):
        b[y_num+i] = 0
        A_use[i] = 0

    # gurobipy's tupledict.prod is fed plain Python lists.
    c = c_use.tolist()
    A = A_use.tolist()
    b = b.tolist()
    G = G_use.tolist()
    h = h.tolist()

    n_free = var_num - 3*fixed
    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(n_free, vtype=GRB.INTEGER, lb=lower_bound, name='x')
    m.setObjective(prev_prof-prev_cost+x.prod(c), GRB.MAXIMIZE)
    for i in range(A_row_size):
        m.addConstr((x.prod(A[i])) == b[i])
    for i in range(cap_G_row_size):
        # Capacity already consumed by realised decisions is taken off the RHS.
        m.addConstr((x.prod(G[i])) <= h[i] - prev_cost)
    for i in range(cap_G_row_size, G_row_size):
        m.addConstr((x.prod(G[i])) <= h[i])

    m.optimize()

    # Check for a solution explicitly rather than catching every exception
    # around the ``.x`` reads, so genuine programming errors are not masked.
    if m.SolCount == 0:
        return np.zeros(x_num), np.zeros(y_num), np.zeros(z_num)

    sol = np.array([x[i].x for i in range(n_free)], dtype=float)
    new_x_num = x_num - fixed
    new_y_num = y_num - fixed
    x_sol = np.concatenate([x_prev_sol[:fixed], sol[:new_x_num]], axis=0)
    y_sol = np.concatenate([y_prev_sol[:fixed], sol[new_x_num:new_x_num+new_y_num]], axis=0)
    z_sol = np.concatenate([z_prev_sol[:fixed], sol[new_x_num+new_y_num:]], axis=0)
    return x_sol, y_sol, z_sol


def get_updated_plan_for_each_month(cur_month_num, prev_prof, prev_cost, pred_value, pred_weight,
    true_value, true_weight, A, b, G_fix, h_fix):
    """Solve the remaining-horizon integer plan for the current month.

    The first ``item_num`` weights are the realised ones and the rest are
    predictions; the objective and capacity rows come from
    ``gen_KS_obj_cap_cons_latter_month(cur_month_num, prev_cost, ...)``.

    Args:
        cur_month_num: number of months left in the horizon.
        prev_prof, prev_cost: already-realised profit and cost, added to the
            objective as the constant ``prev_prof - prev_cost``.
        pred_value: predicted values used for the objective.
        pred_weight, true_weight: predicted / realised weights.
        true_value: accepted for interface symmetry; not used here.
        A, b: equality constraints ``A x == b``.
        G_fix, h_fix: extra ``G x <= h`` rows appended after the capacity rows.

    Returns:
        ``(x_sol, y_sol, z_sol)`` slices of the solution vector, of length
        ``x_num``, ``y_num`` and the remainder respectively.  When Gurobi has
        no solution to report (``SolCount == 0``) three zero vectors of length
        ``x_num``, ``y_num`` and ``z_num`` are returned instead.
    """
    # Current month's weights are known; later months rely on predictions.
    weight = np.concatenate([true_weight[:item_num], pred_weight[item_num:]], axis=0)
    c, cap_G, cap_h = gen_KS_obj_cap_cons_latter_month(cur_month_num, prev_cost, pred_value, weight)
    G = np.concatenate([cap_G, G_fix], axis=0)
    h = np.concatenate([cap_h, h_fix], axis=0)
    G_row_size = G.shape[0]
    A_row_size = A.shape[0]

    # gurobipy's tupledict.prod is fed plain Python lists.
    c = c.tolist()
    A = A.tolist()
    b = b.tolist()
    G = G.tolist()
    h = h.tolist()

    m = gp.Model()
    m.setParam('OutputFlag', 0)
    x = m.addVars(var_num, vtype=GRB.INTEGER, lb=lower_bound, name='x')
    m.setObjective(prev_prof-prev_cost+x.prod(c), GRB.MAXIMIZE)
    for i in range(A_row_size):
        m.addConstr((x.prod(A[i])) == b[i])
    for i in range(G_row_size):
        m.addConstr((x.prod(G[i])) <= h[i])

    m.optimize()

    # Check for a solution explicitly rather than catching every exception
    # around the ``.x`` reads, so genuine programming errors are not masked.
    if m.SolCount == 0:
        return np.zeros(x_num), np.zeros(y_num), np.zeros(z_num)

    sol = np.array([x[i].x for i in range(var_num)], dtype=float)
    x_sol = sol[:x_num]
    y_sol = sol[x_num:x_num+y_num]
    z_sol = sol[x_num+y_num:]
    return x_sol, y_sol, z_sol
    
    
def correction_single_obj(pred_value, pred_weight, true_value, true_weight):
    """Evaluate the stage-wise re-planning strategy on realised data.

    An initial plan is built from predictions, then re-solved at every stage
    ``t = 1 .. month_num-1`` with the decisions before stage ``t-1`` frozen
    (see ``get_t_updated_plan``).  The final plan is scored with the
    objective vector returned by ``gen_KS_obj_cap_cons(0, true_value,
    true_weight)``.

    Returns:
        The dot product of that objective vector with the concatenated final
        ``[x, y, z]`` plan.
    """
    A_0, b_0, G_0, h_0 = gen_constraints(0)
    plan_x, plan_y, plan_z = get_init_plan(pred_value, pred_weight, true_weight, A_0, b_0, G_0, h_0)

    for t in range(1, month_num):
        fixed = (t-1)*item_num
        # Profit/cost realised by the frozen part of the current plan.
        prev_prof = np.dot(true_value[:fixed], plan_x[:fixed])
        if t > 1:
            prev_cost = (np.dot(true_weight[:item_num], plan_x[:item_num])
                         + trans_fee_percent*np.dot(true_weight[item_num:t*item_num], plan_z[:fixed])
                         + np.dot(true_weight[item_num:t*item_num], plan_y[:fixed]))
        else:
            # At t == 1 nothing is frozen: the whole problem is re-solved.
            prev_cost = 0

        A_t, b_t, G_t, h_t = gen_constraints(t, x_prev_sol=plan_x)
        plan_x, plan_y, plan_z = get_t_updated_plan(t, plan_x, plan_y, plan_z, prev_prof, prev_cost, pred_value, pred_weight, true_value, true_weight, A_t, b_t, G_t, h_t)

    # The previous version also built gen_KS_obj_cap_cons(month_num, ...) here
    # and discarded every result; only the stage-0 objective is needed.
    final_sol = np.concatenate([plan_x, plan_y, plan_z], axis=0)
    real_c, _, _ = gen_KS_obj_cap_cons(0, true_value, true_weight)
    investment_income = np.dot(real_c, final_sol)
    return investment_income


def correction_single_for_each_month(pred_value, pred_weight, true_value, true_weight):
    """Evaluate month-by-month re-planning with a shrinking horizon.

    After an initial plan is built from predictions, each month ``t`` commits
    the first ``item_num`` entries of the current x / y / z plan, accumulates
    the realised profit and cost they produce, and re-solves the remaining
    ``total_month_num - t`` months via ``get_updated_plan_for_each_month``.

    Returns:
        Final profit minus final cost.  The last month's holdings are valued
        at both ``true_value`` and ``true_weight`` of that month.

    NOTE(review): the module-level month counters are left as set by the
    last ``change_month_num`` call; callers presumably restore them -- confirm.
    """
    A_0, b_0, G_0, h_0 = gen_constraints(0)
    plan_x, plan_y, plan_z = get_init_plan(pred_value, pred_weight, true_weight, A_0, b_0, G_0, h_0)

    profit_so_far = 0
    cost_so_far = 0
    for t in range(1, total_month_num):
        horizon = total_month_num - t
        change_month_num(horizon)

        # Decisions committed for month t-1.
        x_now = plan_x[:item_num]
        y_now = plan_y[:item_num]
        z_now = plan_z[:item_num]
        lo, hi = (t-1)*item_num, t*item_num
        month_weight = true_weight[lo:hi]

        profit_so_far = profit_so_far + np.dot(true_value[lo:hi], x_now)
        if t != 1:
            cost_so_far = cost_so_far + trans_fee_percent*np.dot(month_weight, z_now) + np.dot(month_weight, y_now)
        else:
            # First month: only the initial holdings are charged.
            cost_so_far = cost_so_far + np.dot(month_weight, x_now)

        rest = t*item_num
        A_t, b_t, G_t, h_t = gen_constraints_latter_months(horizon, x_now)
        plan_x, plan_y, plan_z = get_updated_plan_for_each_month(
            horizon, profit_so_far, cost_so_far,
            pred_value[rest:], pred_weight[rest:],
            true_value[rest:], true_weight[rest:],
            A_t, b_t, G_t, h_t)

    last = (total_month_num-1)*item_num
    final_prof = profit_so_far + np.dot(true_value[last:], plan_x) + np.dot(true_weight[last:], plan_x)
    final_cost = cost_so_far + trans_fee_percent*np.dot(true_weight[last:], plan_z) + np.dot(true_weight[last:], plan_y)
    return final_prof - final_cost


def correction_single_for_latter_month(cur_NN, pred_value, pred_weight, true_value, true_weight, init_x, init_y, init_z, prev_cost, prev_prof):
    """Continue month-by-month re-planning starting from month ``cur_NN``.

    The existing plan ``init_x/init_y/init_z`` is cut at ``(cur_NN-1)*item_num``
    and then, for each ``t`` in ``cur_NN .. total_month_num-1``, re-solved over
    the remaining ``total_month_num - t`` months.  ``prev_prof`` and
    ``prev_cost`` are the incoming realised totals.  From the second iteration
    on they are increased by the committed decisions, using data offset by
    ``t - cur_NN`` month blocks.  NOTE(review): the value/weight arrays
    presumably start at the first re-planned month -- confirm with callers.

    Returns:
        Final profit minus final cost, valuing the last holdings at both
        ``true_value`` and ``true_weight`` of that month.  Before returning,
        the global month counters are set via
        ``change_month_num(total_month_num - cur_NN)``.
    """
    start = (cur_NN-1)*item_num
    plan_x = init_x[start:]
    plan_y = init_y[start:]
    plan_z = init_z[start:]

    for step, t in enumerate(range(cur_NN, total_month_num)):
        # ``step`` is ``t - cur_NN``: the month offset into the data arrays.
        horizon = total_month_num - t
        change_month_num(horizon)

        x_now = plan_x[:item_num]
        y_now = plan_y[:item_num]
        z_now = plan_z[:item_num]
        if step > 0:
            lo, hi = (step-1)*item_num, step*item_num
            month_weight = true_weight[lo:hi]
            prev_prof = prev_prof + np.dot(true_value[lo:hi], x_now)
            prev_cost = prev_cost + trans_fee_percent*np.dot(month_weight, z_now) + np.dot(month_weight, y_now)

        off = step*item_num
        A_t, b_t, G_t, h_t = gen_constraints_latter_months(horizon, x_now)
        plan_x, plan_y, plan_z = get_updated_plan_for_each_month(
            horizon, prev_prof, prev_cost,
            pred_value[off:], pred_weight[off:],
            true_value[off:], true_weight[off:],
            A_t, b_t, G_t, h_t)

    tail = (total_month_num-cur_NN-1)*item_num
    final_prof = prev_prof + np.dot(true_value[tail:], plan_x) + np.dot(true_weight[tail:], plan_x)
    final_cost = prev_cost + trans_fee_percent*np.dot(true_weight[tail:], plan_z) + np.dot(true_weight[tail:], plan_y)
    income = final_prof - final_cost
    change_month_num(total_month_num - cur_NN)
    return income
