import datetime
import math
import simpy
import random
import numpy as np
import sys
import matplotlib
matplotlib.use('TkAgg')  # Adjust based on your OS and environment
import matplotlib.pyplot as plt
import pandas as pd
from real_dataset.parse_datasets import load_and_prepare_data

# ---- Per-run simulation state (re-initialized by run_simulation / simulation_wrapper) ----
response_times = []  # response times of every completed job
response_pshort_times = []  # response times of jobs routed to the short queue
response_plong_times = []  # response times of jobs routed to the long queue
short_queue = []  # waiting short jobs as (arrival_time, size)
long_queue = []  # waiting long jobs as (arrival_time, age, size); age = service already received
response_plong_times_details = []  # one dict per finished long job: 'response_time', 'actual_size'

start_time = None  # env.now at which the long job in service started; key into job_processes
job_processes = {}  # start_time -> serve_long_job process, looked up so short arrivals can interrupt it
n_cheap_p = 0  # number of cheap predictions made (incremented once per arriving job)
n_expensive_p = 0  # number of jobs routed to the long queue
SIMULATION_TIME = 100000  # NOTE: run_simulation overwrites this global with 1000000 on every call
#SIMULATION_TIME = 1000000
# The threshold T is not defined at module level; run_simulation sets it as a global.
#T= float('inf') #to test FCFS
#T= 0 #to test SPRPT
#T = 4
current_job = None  # (arrival_time, age, size) of the long job in service, or None
LOG_EVENT_PRINT = 0  # non-zero enables the per-event trace printed by log_event
PLOT_GRAPHS = 0  # non-zero makes the test_* sweeps save matplotlib figures under graphs/

# Alternative predictor settings (used by predict_service_time for synthetic distributions)
#predictor = 'expP'
#predictor = 'perfectP'
#predictor = 'uniP'

# Alternative synthetic job-size distributions (see job_size_distribution)
#dist = 'weibull'
#dist = 'exponential'


# Real dataset: sizes and predicted sizes are both read from real_data (see job_size_distribution)
dist = 'real'
predictor = 'real'

real_data_index = 0  # next row of real_data to be consumed by job_size_distribution

cheap_alpha = 0.8  # uni_alpha passed to predict_service_time by job_generator for synthetic dists
#alpha = 2


class EndOfDataException(Exception):
    """Signal that job_size_distribution has consumed every row of the real dataset."""

def log_event(time, event, job_id="--", size="--", predicted_size="--", notes="--", queue_content="--"):
    """Print one fixed-width trace row, but only when LOG_EVENT_PRINT is truthy.

    ``time`` is shown with five decimals; ``job_id``, ``size`` and
    ``predicted_size`` with two, unless they are the placeholder ``"--"``.
    ``notes`` may additionally be the literal ``"DONE"``. ``queue_content`` is
    printed as-is in the last column.
    """
    if not LOG_EVENT_PRINT:
        return

    def two_decimals(value):
        # "--" means "no value" and is printed verbatim.
        if value == "--":
            return "--"
        return f"{value:.2f}"

    time_col = f"{time:.5f}"
    id_col = two_decimals(job_id)
    size_col = two_decimals(size)
    pred_col = two_decimals(predicted_size)
    notes_col = notes if notes in ("--", "DONE") else f"{notes:.2f}"
    print(f"{time_col:<10}| {event:<20}| {id_col:<10}| {size_col:<10}| {pred_col:<15}| {notes_col:<20}| {queue_content}")

def short_job_size_distribution():
    """Draw job sizes until one is strictly below the threshold ``T``.

    job_size_distribution() returns ``(size, predicted_size)``; the tuple is
    unpacked here so the comparison is made on the true size (comparing the
    whole tuple to ``T`` raises TypeError).

    Returns:
        The accepted true job size (a scalar); the predicted size is discarded.

    Note:
        For synthetic distributions this loops until a sample is accepted, so it
        never returns if ``T`` is not above the distribution's lower bound. With
        ``dist == 'real'`` it propagates EndOfDataException when rows run out.
    """
    while True:
        sample, _predicted = job_size_distribution()
        if sample < T:
            return sample

def job_size_distribution():
    """Draw the next job's true size and predicted size according to ``dist``.

    Returns:
        ``(job_size, predicted_job_size)``. For ``'exponential'`` and
        ``'weibull'`` the predicted size is a placeholder ``0`` (job_generator
        computes a prediction separately). For ``'real'`` both values come from
        the next row of ``real_data``, and ``real_data_index`` is advanced.

    Raises:
        EndOfDataException: ``dist == 'real'`` and every row has been consumed.
        ValueError: ``dist`` is not a recognised distribution name.
    """
    global real_data_index

    if dist == 'exponential':
        return random.expovariate(1), 0

    if dist == 'weibull':
        # Inverse-transform sample: (-ln(1 - U))**2 / 2 with U ~ Uniform[0, 1).
        u = random.random()
        return (-math.log(1 - u)) ** 2 / 2, 0

    if dist == 'real':
        if real_data_index >= len(real_data):
            raise EndOfDataException("End of data reached")
        # Positional lookup, so it agrees with the row-count bound check above
        # even when real_data does not carry a 0..n-1 index.
        job_size = real_data['normalized_runtime'].iloc[real_data_index]
        predicted_job_size = real_data['normalized_predicted_runtime'].iloc[real_data_index]
        real_data_index += 1
        return job_size, predicted_job_size

    raise ValueError("Unknown distribution type specified.")


def predict_service_time(job, uni_alpha):
    """Return a (possibly noisy) prediction of the true size ``job``.

    The predictor type is selected by the module-level ``predictor``.

    Args:
        job: true job size.
        uni_alpha: relative error bound for ``'uniP'``. A value of ``0``
            yields a perfect prediction for every predictor type.

    Returns:
        ``'perfectP'`` (or ``uni_alpha == 0``): ``job`` itself.
        ``'expP'``: an exponential sample whose mean is ``job``.
        ``'uniP'``: a uniform sample in ``[(1 - uni_alpha) * job, (1 + uni_alpha) * job]``.

    Raises:
        ValueError: for any other ``predictor`` value. Without this, the caller
            would receive None and fail later when comparing it to ``T``.
    """
    if predictor == 'perfectP' or uni_alpha == 0:
        return job

    if predictor == 'expP':
        return random.expovariate(1 / job)

    if predictor == 'uniP':
        lower_bound = (1 - uni_alpha) * job
        upper_bound = (1 + uni_alpha) * job
        return random.uniform(lower_bound, upper_bound)

    raise ValueError("Unknown predictor type specified.")

############### Different types of predictors  #################

#def predict_service_time(job): #TODO: perfect predictor
#    return job

#def predict_service_time(job): #TODO uniP predictor
#    lower_bound = (1 - alpha) * job
#    upper_bound = (1 + alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_service_time(z): #TODO: exponential predictor
#     return random.expovariate(1/z)

#def predict_uniP_cheap(job): #TODO cheap uniP predictor
#    global cheap_alpha
#    lower_bound = (1 - cheap_alpha) * job
#    upper_bound = (1 + cheap_alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_uniP_expensive(job):  # TODO expensive uniP predictor
#    global expensive_alpha
#    lower_bound = (1 - expensive_alpha) * job
#    upper_bound = (1 + expensive_alpha) * job
#    return random.uniform(lower_bound, upper_bound)
#####################################################


def check_schedule(env, server):
    """Start a server process for the next queued job, preferring short jobs.

    If ``short_queue`` is non-empty, a non-preemptive serve_job (priority 0) is
    started. That call first interrupts any long job in service. Otherwise, if
    ``long_queue`` is non-empty, a preemptible serve_job (priority 1) is
    started. Nothing happens when both queues are empty.
    """
    if LOG_EVENT_PRINT:
        # Formatting both queues is O(queue length) and this runs on every
        # arrival and completion, so only pay for it when tracing is enabled.
        short_queue_str = [f"({x:.2f}, {y:.2f})" for x, y in short_queue]
        long_queue_str = [f"({x:.2f}, {y:.2f} , {z:.2f})" for x, y, z in long_queue]
        queue_str = f"Short: {short_queue_str}, Long: {long_queue_str}"
        log_event(env.now, "Checking Policy", "--", "--", "--", "--", queue_str)

    if short_queue:
        log_event(env.now, "Schedule Short Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, server, preemptive=False))
    elif long_queue:
        log_event(env.now, "Schedule long Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, server, preemptive=True))

def serve_job(env, server, preemptive):
    """SimPy process: take the next job from the short or long queue and serve it.

    Args:
        env: the SimPy environment.
        server: the single-capacity ``simpy.PriorityResource``.
        preemptive: ``True`` serves the head of ``long_queue`` at priority 1
            inside an interruptible serve_long_job process. ``False`` serves the
            head of ``short_queue`` at priority 0 and runs it to completion.

    Before requesting the server, a short job interrupts the long job in
    service, if any. The interrupted long job gets its attained service added
    to its age. It is then put back at the *head* of ``long_queue``, so it
    resumes before long jobs that arrived after it (FCFS among long jobs).

    After serving, check_schedule is invoked to start the next job. If the
    chosen queue is already empty when the server is granted (another process
    took the job), the process returns without rescheduling.
    """
    global current_job, start_time

    if preemptive:  # long job
        with server.request(priority=1) as req:
            yield req
            if not long_queue:
                return
            next_job = long_queue.pop(0)
            log_event(env.now, "Serve Long", next_job[0], next_job[2], "--", next_job[0], f"Age: {next_job[1]}")
            started_at = env.now
            start_time = started_at
            current_job = (next_job[0], next_job[1], next_job[2])
            current_process = env.process(serve_long_job(env, next_job))
            job_processes[started_at] = current_process
            try:
                yield current_process
            except simpy.Interrupt as interrupt:
                if str(interrupt.cause) != "Short Job Arrival":
                    # Unknown interrupt: re-raise instead of silently losing the job.
                    raise
                log_event(env.now, "Short Preempted Long", job_id=next_job[0])
                elapsed_time = env.now - started_at
                # The interrupted process is dead; drop it so job_processes
                # does not grow with every preemption.
                job_processes.pop(started_at, None)
                start_time = None
                # Head of the queue: this job arrived before every long job still waiting.
                long_queue.insert(0, (next_job[0], next_job[1] + elapsed_time, next_job[2]))
    else:  # short job
        # Preempt the long job in service, if any; its serve_job catches the Interrupt.
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Short Job Arrival")

        with server.request(priority=0) as req:
            yield req
            if not short_queue:
                return
            next_job = short_queue.pop(0)
            log_event(env.now, "Serve Short", next_job[0], next_job[1], "--", next_job[0], "--")
            yield env.timeout(next_job[1])
            log_event(env.now, "Short Job Done", next_job[0], "--", "--", "DONE", "--")
            response_times.append(env.now - next_job[0])
            response_pshort_times.append(env.now - next_job[0])

    current_job = None
    check_schedule(env, server)


def serve_long_job(env, job):
    """SimPy process: run a long job for its remaining service and record the result.

    Args:
        env: the SimPy environment.
        job: ``(arrival_time, age, size)``, where ``age`` is the service already
            received before earlier preemptions.

    Waits ``size - age`` time units. If a short arrival interrupts it, the
    Interrupt propagates to the serve_job waiting on this process. On
    completion it:
      - appends the response time to ``response_times`` and ``response_plong_times``;
      - adds a detail row with the job's true size;
      - clears ``current_job``, the ``job_processes`` entry and ``start_time``.
    """
    global current_job, start_time

    arrival_time, age, size = job
    yield env.timeout(size - age)
    log_event(env.now, "Long Job Done", job_id=arrival_time)
    response_time = env.now - arrival_time
    response_plong_times_details.append({
        'response_time': response_time,
        # job[2] is the size; job[1] is only the attained age.
        'actual_size': size,
    })
    response_times.append(response_time)
    response_plong_times.append(response_time)
    current_job = None
    job_processes.pop(start_time, None)
    start_time = None


def job_generator(env, server, arriving_rate):
    """SimPy process producing Poisson arrivals with rate ``arriving_rate``.

    Each job is identified by its arrival time. Its cheap prediction comes from
    the dataset when ``dist == 'real'``. For the synthetic distributions it comes
    from predict_service_time with ``cheap_alpha``.

    Jobs whose prediction is strictly below ``T`` join the short queue. All
    others join the long queue and are also counted in ``n_expensive_p``.
    The scheduler is invoked after every arrival. The process ends once the
    real dataset is exhausted.
    """
    global n_cheap_p, n_expensive_p

    while True:
        yield env.timeout(random.expovariate(arriving_rate))
        arrival = env.now
        try:
            true_size, predicted = job_size_distribution()
        except EndOfDataException:
            print("End of data reached. Terminating simulation.")
            return

        if dist in ('exponential', 'weibull'):
            predicted = predict_service_time(true_size, cheap_alpha)

        n_cheap_p += 1
        log_event(env.now, "Job Arrival", arrival, true_size, predicted, "--", f"Threshold: {T:.3f}")

        goes_short = predicted < T
        if goes_short:
            short_queue.append((arrival, true_size))
            log_event(env.now, "Append to Short", arrival, "--", "--", "--", "--")
        else:
            n_expensive_p += 1
            long_queue.append((arrival, 0, true_size))
            log_event(env.now, "Append to Long", arrival, "--", "--", "--", "--")

        check_schedule(env, server)


def dummy_job_generator(env, server):
    """Hand-crafted three-job scenario for tracing short-over-long preemption.

    Arrivals, relative to when the process starts:
      - a short job after 1 time unit;
      - a long job 2 time units later;
      - another short job ``long_size - 0.5`` time units after that (clamped at 0).

    Short/long is decided by the *true* size relative to ``T``. Predictions are
    only logged, and are obtained the same way job_generator obtains them.
    """
    global n_cheap_p, n_expensive_p

    def _draw(accept):
        # job_size_distribution returns (size, predicted_size): unpack it rather
        # than comparing the tuple, and redraw until the true size is accepted.
        while True:
            size, predicted = job_size_distribution()
            if accept(size):
                break
        if dist in ('exponential', 'weibull'):
            # Same prediction path as job_generator (uni_alpha is required).
            predicted = predict_service_time(size, cheap_alpha)
        return size, predicted

    # First job - Short
    yield env.timeout(1)
    job = env.now
    job_size, predicted_service_time = _draw(lambda s: s < T)
    n_cheap_p += 1
    log_event(env.now, "Job Arrival", job, job_size, predicted_service_time, "--", f"Threshold: {T:.3f}")
    short_queue.append((job, job_size))
    log_event(env.now, "Append to Short", job, "--", "--", "--", "--")
    check_schedule(env, server)

    # Second job - Long
    yield env.timeout(2)
    job = env.now
    job_size, predicted_service_time = _draw(lambda s: s > T)
    n_cheap_p += 1
    n_expensive_p += 1
    log_event(env.now, "Job Arrival", job, job_size, predicted_service_time, "--", f"Threshold: {T:.3f}")
    long_queue.append((job, 0, job_size))
    log_event(env.now, "Append to Long", job, "--", predicted_service_time, "--", "--")
    check_schedule(env, server)

    # Third job - Short. Clamp the delay: SimPy rejects negative timeouts and
    # the long job's size can be below 0.5 when T is small.
    yield env.timeout(max(job_size - 0.5, 0))
    job = env.now
    job_size, predicted_service_time = _draw(lambda s: s < T)
    n_cheap_p += 1
    log_event(env.now, "Job Arrival", job, job_size, predicted_service_time, "--", f"Threshold: {T:.3f}")
    short_queue.append((job, job_size))
    log_event(env.now, "Append to Short", job, "--", "--", "--", "--")
    check_schedule(env, server)



def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is 0 (no completed jobs)."""
    return total / count if count else float('nan')


def run_simulation(arrival_rate, threshold, c1):
    """Run one replication of the 1-bit-advice scheduling simulation.

    All module-level simulation state is reset first, including
    ``real_data_index``, so every replication replays the real dataset from its
    first row. The SimPy environment then runs until ``SIMULATION_TIME`` and
    summary statistics are printed.

    Args:
        arrival_rate: Poisson arrival rate passed to job_generator.
        threshold: short/long threshold ``T`` applied to predicted sizes.
        c1: cost added per counted prediction in the cost-inclusive means.

    Returns:
        ``(mean_response_time, mean_response_time_pshort,
        mean_response_time_plong, response_plong_times_details)``:
          - the first two include prediction costs;
          - the third is the cost-free predicted-long mean with the first 10% of
            long completions dropped (warm-up removal).
        Any mean whose underlying list is empty is reported as NaN instead of
        raising ZeroDivisionError.
    """
    global response_times, response_pshort_times, response_plong_times, response_plong_times_details
    global start_time, job_processes, SIMULATION_TIME
    global T
    global n_cheap_p, n_expensive_p
    global short_queue, long_queue
    global current_job
    global real_data_index

    # Re-initialize all shared state at the start of each run.
    response_times = []
    response_pshort_times = []
    response_plong_times = []
    response_plong_times_details = []
    short_queue = []
    long_queue = []

    T = threshold
    n_cheap_p = 0
    n_expensive_p = 0

    start_time = None
    job_processes = {}
    # Replay the dataset from the start in every replication. Otherwise later
    # replications resume where the previous one stopped, soon receive no jobs,
    # and the means below divide by zero.
    real_data_index = 0
    # NOTE: overrides the module-level default on every call.
    SIMULATION_TIME = 1000000

    print(f'Simulating M/M/1 queue with small advice, FCFS of short/long jobs, external cost model, lambda: {arrival_rate}, T:{T}')
    env = simpy.Environment()
    server = simpy.PriorityResource(env, capacity=1)
    current_job = None
    print(f"{'TIME':<10} | {'EVENT':<18} | {'JOB ID':<10} | {'SIZE':<8} | {'PREDICTED SIZE':<15} | {'SERVER':<18} | {'NOTES'}")
    print("-" * 145)
    env.process(job_generator(env, server, arrival_rate))

    env.run(until=SIMULATION_TIME)

    print(f'n_cheap_p: {n_cheap_p}')
    print(f'n_expensive_p: {n_expensive_p}')

    total_all = sum(response_times)
    total_short = sum(response_pshort_times)
    total_long = sum(response_plong_times)

    # NOTE(review): prediction counts cover every arrival, but the denominators
    # count only completed jobs of each class (n_cheap_p covers all arrivals yet
    # is charged to the short-job mean). Confirm this is the intended cost model.
    print(f"\nMean Response Time without costs: {_mean_or_nan(total_all, len(response_times)):.2f} time units")
    mean_response_time = _mean_or_nan(total_all + (n_cheap_p + n_expensive_p) * c1, len(response_times))
    print(f"\nMean Response Time with costs: {mean_response_time:.2f} time units lambda: {arrival_rate}, T:{T}")

    print(f"\nPredicted short: Mean Response Time without costs: {_mean_or_nan(total_short, len(response_pshort_times)):.2f} time units")
    mean_response_time_pshort = _mean_or_nan(total_short + n_cheap_p * c1, len(response_pshort_times))
    print(f"\nPredicted short: Mean Response Time with costs: {mean_response_time_pshort:.2f} time units lambda: {arrival_rate}, T:{T}")

    print(f"\nPredicted long: Mean Response Time without costs: {_mean_or_nan(total_long, len(response_plong_times)):.2f} time units")
    mean_response_time_plong = _mean_or_nan(total_long + n_expensive_p * c1, len(response_plong_times))
    print(f"\nPredicted long: Mean Response Time with costs: {mean_response_time_plong:.2f} time units lambda: {arrival_rate}, T:{T}")

    ##########################

    # Mean over the entire predicted-long list.
    mean_response_time_full_list = _mean_or_nan(total_long, len(response_plong_times))
    print(f"\nPredicted long: Mean Response Time for full list: {mean_response_time_full_list:.2f} time units")

    mean_response_time_with_costs_full_list = _mean_or_nan(total_long + n_expensive_p * c1, len(response_plong_times))
    print(f"\nPredicted long: Mean Response Time with costs for full list: {mean_response_time_with_costs_full_list:.2f} time units lambda: {arrival_rate}, T:{T}")

    # Drop the first 10% of long completions as warm-up.
    start_index = int(len(response_plong_times) * 0.1)
    sliced_response_plong_times = response_plong_times[start_index:]
    total_sliced = sum(sliced_response_plong_times)

    mean_response_time_without_first_10 = _mean_or_nan(total_sliced, len(sliced_response_plong_times))
    print(f"\nPredicted long: Mean Response Time without first 10%: {mean_response_time_without_first_10:.2f} time units")

    mean_response_time_with_costs_excluding_first_10 = _mean_or_nan(total_sliced + n_expensive_p * c1, len(sliced_response_plong_times))
    print(f"\nPredicted long: Mean Response Time with costs (excluding first 10%): {mean_response_time_with_costs_excluding_first_10:.2f} time units lambda: {arrival_rate}, T:{T}")

    # The returned long-job mean is the warm-up-trimmed, cost-free value.
    mean_response_time_plong = mean_response_time_without_first_10

    return mean_response_time, mean_response_time_pshort, mean_response_time_plong, response_plong_times_details


def simulation_wrapper(arrival_rate, threshold, c1):
    """Average 100 replications of run_simulation with identical parameters.

    The shared simulation state is reset, run_simulation is called 100 times,
    and the three averaged means are printed. The predicted-long job details of
    the *last* replication are written to a timestamped CSV under ``long_res/``.

    Returns:
        ``(average_mean_response_time, average_mean_response_time_pshort,
        average_mean_response_time_plong)``.
    """
    global response_times, response_pshort_times, response_plong_times
    global short_queue, long_queue
    global start_time, job_processes
    global n_cheap_p, n_expensive_p

    start_time, job_processes = None, {}
    n_cheap_p = n_expensive_p = 0
    response_times, response_pshort_times, response_plong_times = [], [], []
    short_queue, long_queue = [], []

    overall_means, pshort_means, plong_means = [], [], []
    last_details = None
    for _run in range(100):  # TODO: number of replications
        overall, pshort, plong, last_details = run_simulation(arrival_rate, threshold, c1)
        overall_means.append(overall)
        pshort_means.append(pshort)
        plong_means.append(plong)

    average_mean_response_time = sum(overall_means) / len(overall_means)
    average_mean_response_time_pshort = sum(pshort_means) / len(pshort_means)
    average_mean_response_time_plong = sum(plong_means) / len(plong_means)

    print(f"Average Mean Response Time: {average_mean_response_time:.2f} time units")
    print(f"Average Predicted Short Mean Response Time: {average_mean_response_time_pshort:.2f} time units")
    print(f"Average Predicted Long Mean Response Time: {average_mean_response_time_plong:.2f} time units")

    current_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
    details_path = f'long_res/1bit_long_job_response_{predictor}_times_{arrival_rate}_T_{threshold}_{current_date}.csv'
    with open(details_path, 'w') as out:
        out.write("Response Time,Actual Size\n")
        out.writelines(
            f"{detail['response_time']},{detail['actual_size']}\n" for detail in last_details
        )

    return average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong

#####
# Plot styles cycled per threshold series in the test_* plots (extend as needed).
markers = list('ox^sDp')  # circle, x, triangle-up, square, diamond, pentagon
colors = list('bgrcmy')  # matplotlib single-letter colors: blue, green, red, cyan, magenta, yellow


def test_cost_vs_T():
    """Sweep the short/long threshold at a fixed arrival rate and save the costs.

    For every threshold in the sweep, the dataset index is reset and
    simulation_wrapper is run. One result row per threshold is collected and
    saved as a CSV under ``res/``; a cost-vs-T plot is saved when PLOT_GRAPHS
    is set.

    The output file name uses the module global ``dataset``, which is only
    assigned in the ``__main__`` block.
    """
    global real_data_index

    threshold_sweep = [0.1, 0.5, 1, 1.5, 2, 4, 5, 8]
    arrival_rate = 0.9
    c1 = 0.5  # TODO

    if dist == 'real':
        series_T, labels = [4], {4: '1bit'}
    else:
        series_T, labels = [1], {1: '1bit'}

    rows = []
    for t in series_T:
        for threshold in threshold_sweep:
            real_data_index = 0
            avg_all, avg_short, avg_long = simulation_wrapper(arrival_rate, threshold, c1)
            rows.append({
                'Alg': labels[t],
                'T Value': threshold,
                'Average Mean Response Time': avg_all,
                'Average PShort Response Time': avg_short,
                'Average PLong Response Time': avg_long,
                'Default arrival rate': arrival_rate,
                'Default c1': c1,
                'Default calpha': cheap_alpha,
                'Default expalpha': 0,
            })

    results_df = pd.DataFrame(rows)
    results_df.to_csv(f'res/cost_vs_T_results_{predictor}_alpha_{cheap_alpha}_dist_{dist}_dataset_{dataset}_1bit_0.8.csv', index=False)

    if not PLOT_GRAPHS:
        return

    for t, marker, color in zip(series_T, markers, colors):
        series_label = labels[t]
        subset = results_df[results_df['Alg'] == series_label]
        plt.plot(subset['T Value'], subset['Average Mean Response Time'],
                 marker=marker, color=color, label=series_label)

    plt.xlabel(r'$T$')
    plt.ylabel('Cost')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'graphs/cost_vs_T_{predictor}_alpha_{cheap_alpha}_1bit_{dist}.png')
    plt.clf()


def test_cost_vs_arrivalrate():
    """Sweep the arrival rate at a fixed threshold and save the resulting costs.

    For every arrival rate in the sweep, the dataset index is reset and
    simulation_wrapper is run with the fixed threshold. One result row per rate
    is collected and saved as a CSV under ``res/``; a cost-vs-rate plot is saved
    when PLOT_GRAPHS is set.

    The output file name uses the module global ``dataset``, which is only
    assigned in the ``__main__`` block.
    """
    global real_data_index

    rate_sweep = [0.5, 0.6, 0.7, 0.9, 0.95]
    c1 = 0.5

    if dist == 'real':
        series_T, labels = [4], {4: '1bit'}
    else:
        series_T, labels = [1], {1: '1bit'}

    rows = []
    for t in series_T:
        for rate in rate_sweep:
            real_data_index = 0
            avg_all, avg_short, avg_long = simulation_wrapper(rate, t, c1)
            rows.append({
                'Alg': labels[t],
                'Arrival Rate': rate,
                'Average Mean Response Time': avg_all,
                'Average PShort Response Time': avg_short,
                'Average PLong Response Time': avg_long,
                'Default T': t,
                'Default c1': c1,
                'Default c2': 0,
                'Default calpha': cheap_alpha,
                'Default expalpha': 0,
            })

    results_df = pd.DataFrame(rows)
    results_df.to_csv(f'res/cost_vs_arrivalrate_results_{predictor}_alpha_{cheap_alpha}_dist_{dist}_dataset_{dataset}_1bit.csv', index=False)

    if not PLOT_GRAPHS:
        return

    for t, marker, color in zip(series_T, markers, colors):
        series_label = labels[t]
        subset = results_df[results_df['Alg'] == series_label]
        plt.plot(subset['Arrival Rate'], subset['Average Mean Response Time'],
                 marker=marker, color=color, label=series_label)

    plt.xlabel('Arrival Rate')
    plt.ylabel('Cost')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'graphs/cost_vs_arrivalrate_{predictor}_alpha_{cheap_alpha}_1bit_{dist}.png')
    plt.clf()

if __name__ == "__main__":
    # (dataset name, trace file) pairs, processed in this order. `dataset` and
    # `real_data` are module globals read by the test_* sweeps and job_size_distribution.
    dataset_files = (
        ('twosigma', 'real_dataset/jvupredict_twosigma.csv.gz'),
        ('google', 'real_dataset/jvupredict_google_all_features.csv.gz'),
        ('trinity', 'real_dataset/jvupredict_trinity.csv.gz'),
    )

    for dataset, file_path in dataset_files:
        real_data = load_and_prepare_data(file_path)

        # TESTS
        test_cost_vs_T()
        test_cost_vs_arrivalrate()

