import os
import sys
from collections import defaultdict
import numpy as np
import gurobipy as gp
from gurobipy import GRB
from sklearn.metrics import mean_squared_error
from numpy import inf
import torch
import math
import time
import linear_relax as LP_relax_file

# Problem dimensions, re-exported from the LP relaxation helper module so the
# rest of this script can refer to them as plain module-level names.
(total_month_num, month_num, x_num, y_num, var_num) = (
    LP_relax_file.total_month_num,
    LP_relax_file.month_num,
    LP_relax_file.x_num,
    LP_relax_file.y_num,
    LP_relax_file.var_num,
)

featureNum = 4096
train_set_size = 70
# Per-instance state tables indexed [test_num][month]: cumulative profit and
# cumulative stocking after each month (filled in by make_next_plan).
train_curr_profit = np.zeros((train_set_size, total_month_num))
train_curr_stocking = np.zeros_like(train_curr_profit)

# Command-line arguments: <small_or_large> <startmark> <endmark>
#   small_or_large: 0 selects the small-price data set, 1 the large-price one.
#   startmark, endmark: instances range(startmark, endmark) are processed.
if len(sys.argv) < 4:
    sys.exit("usage: " + os.path.basename(sys.argv[0])
             + " <small_or_large 0|1> <startmark> <endmark>")
small_or_large = int(sys.argv[1])
startmark = int(sys.argv[2])
endmark = int(sys.argv[3])
if small_or_large not in (0, 1):
    # Fail fast: otherwise TOV_path and the price data are never set and the
    # script only crashes much later with a NameError.
    sys.exit("small_or_large must be 0 (small price) or 1 (large price), got "
             + str(small_or_large))

simulation_time = 10
# Data lives under the parent directory of the current working directory.
dataset_path = os.path.abspath(os.path.dirname(os.getcwd()))
default_path = os.path.join(dataset_path, 'data/month_num=' + str(month_num) + '/')

if small_or_large == 0:
    print("small price", end=', ')
else:
    print("large price", end=' ')

# Output tree: <default_path>/<small|large>/{true_profit,true_stocking}/
# (LP_relax_file.mkdir(parent, name) presumably creates parent/name — confirm).
price_variant = 'small' if small_or_large == 0 else 'large'
LP_relax_file.mkdir(default_path, price_variant)
TOV_path = os.path.join(default_path, price_variant + '/')
LP_relax_file.mkdir(TOV_path, 'true_profit')
LP_relax_file.mkdir(TOV_path, 'true_stocking')


def make_next_plan(test_num, cur_NN, true_demand, pred_demand, price, cost):
    """Solve the plan for month ``cur_NN`` of instance ``test_num`` and record it.

    Month 0 uses ``LP_relax_file.get_init_plan``. Later months re-solve with
    ``LP_relax_file.get_updated_plan_for_each_day``, starting from the stocking
    recorded for the previous month. Only the first decision of each solved
    plan is applied.

    The results are written in place into the module-level ``train_curr_profit``
    and ``train_curr_stocking`` arrays at ``[test_num][cur_NN]``, rounded to 2
    decimals. Because month ``cur_NN`` reads the entry for ``cur_NN - 1``,
    months must be processed in increasing order for each instance.

    Args:
        test_num: row index into the state arrays.
        cur_NN: current month index (0 = first month).
        true_demand: realised demand; for months > 0 only element 0 is used.
        pred_demand: predicted demand; for months > 0 elements
            1 .. LP_relax_file.month_num - 1 are used.
        price: selling price, indexed by absolute month.
        cost: purchase cost, indexed by absolute month.
    """
    # The state arrays are only mutated item-wise, so no `global` is needed.
    if cur_NN == 0:
        init_x, _init_y = LP_relax_file.get_init_plan(price, cost, pred_demand, true_demand)
        # NOTE(review): month 0 books only the purchase cost (no sales revenue
        # from the initial plan) — confirm this is intended.
        profit = -cost[0] * init_x[0]
        stocking = init_x[0]
    else:
        horizon = LP_relax_file.month_num
        # Demand over the remaining horizon: the realised value for the
        # current month, the predicted values for every later month.
        demand = np.zeros(horizon)
        demand[0] = true_demand[0]
        demand[1:] = pred_demand[1:horizon]

        prev_profit = train_curr_profit[test_num][cur_NN - 1]
        prev_stocking = train_curr_stocking[test_num][cur_NN - 1]
        G_t, h_t = LP_relax_file.gen_constraints_latter_days(cur_NN, demand, prev_stocking)
        c_t = LP_relax_file.gen_obj_latter_days(cur_NN, price, cost)
        t_updated_x, t_updated_y = LP_relax_file.get_updated_plan_for_each_day(cur_NN, c_t, G_t, h_t)

        # Apply only this month's decision: buy x[0], sell y[0].
        profit = prev_profit + price[cur_NN] * t_updated_y[0] - cost[cur_NN] * t_updated_x[0]
        stocking = prev_stocking + (t_updated_x[0] - t_updated_y[0])

    # Store first, then round the stored float64 value (same order as before,
    # so the rounding semantics are unchanged).
    train_curr_profit[test_num][cur_NN] = profit
    train_curr_stocking[test_num][cur_NN] = stocking
    train_curr_profit[test_num][cur_NN] = round(train_curr_profit[test_num][cur_NN], 2)
    train_curr_stocking[test_num][cur_NN] = round(train_curr_stocking[test_num][cur_NN], 2)



print("month_num: ", total_month_num)

# Price data directory for the chosen variant; validated up front because an
# invalid value would otherwise leave `price` unbound and only fail later.
if small_or_large == 0:
    price_dir = 'small_price'
elif small_or_large == 1:
    price_dir = 'large_price'
else:
    sys.exit("small_or_large must be 0 (small price) or 1 (large price), got "
             + str(small_or_large))

# NOTE(review): cost/price are read from data/month_num=<total_month_num>/,
# while the demands come from default_path (built from month_num) — confirm
# the two are always equal.
month_data_path = os.path.join(dataset_path, 'data/month_num=' + str(total_month_num))

for testi in range(startmark, endmark):
    print(testi, end=" ")
    cost = np.loadtxt(os.path.join(month_data_path, 'cost/cost(' + str(testi) + ').txt'))
    price = np.loadtxt(os.path.join(month_data_path, price_dir + '/price(' + str(testi) + ').txt'))
    true_demand_full = np.loadtxt(os.path.join(default_path, 'train_demands/train_demands(' + str(testi) + ').txt'))

    for NN_cnt in range(total_month_num):
        # Set LP_relax_file's horizon to the months remaining from NN_cnt on;
        # month 0 presumably restores the full horizon via reset_month_num().
        if NN_cnt == 0:
            LP_relax_file.reset_month_num()
        else:
            cur_month_num = total_month_num - NN_cnt
            LP_relax_file.change_month_num(cur_month_num)

        for test_num in range(train_set_size):
            # true_demand_full holds each training sample as a consecutive run
            # of total_month_num values; take sample test_num's months NN_cnt..end.
            row_start = test_num * total_month_num
            true_demand = np.zeros(LP_relax_file.month_num)
            for k, j in enumerate(range(NN_cnt, total_month_num)):
                true_demand[k] = true_demand_full[row_start + j]
            # The realised demand is also passed as the prediction.
            make_next_plan(test_num, NN_cnt, true_demand, true_demand, price, cost)

    # Final cumulative profit per sample (a view; used before the reset below).
    train_obj = train_curr_profit[:, total_month_num - 1]

    # NOTE(review): profit is saved with "%.0f" (whole units) although
    # make_next_plan rounds it to 2 decimals; stocking keeps "%.2f".
    np.savetxt(os.path.join(TOV_path, 'true_profit/true_profit(' + str(testi) + ').txt'), train_curr_profit, fmt="%.0f")
    np.savetxt(os.path.join(TOV_path, 'true_stocking/true_stocking(' + str(testi) + ').txt'), train_curr_stocking, fmt="%.2f")
    LP_relax_file.reset_month_num()
    print("TOV: ", np.sum(train_obj)/train_set_size)

    # Reset the state tables in place for the next instance (unlike `* 0`,
    # this does not reallocate and also clears any NaN/inf entries).
    train_curr_profit.fill(0)
    train_curr_stocking.fill(0)
