import argparse
import os
import gzip
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import glob
import sys
warnings.filterwarnings('ignore')

def get_args(argv=None):
    """Parse the command-line options of the mixing-time analysis script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            original behaviour of this function.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser()
    # Mixing time specific arguments
    parser.add_argument(
        "--seed", help="Random generator seed", type=int, default=0
    )
    parser.add_argument(
        "--tau", help="Task switch time", type=int, default=0
    )
    parser.add_argument(
        "--x", help="Task switch exponent", type=int, default=3
    )
    parser.add_argument(
        "--example", help="Which example?", type=str, default="example_3"
    )
    parser.add_argument(
        "--n-tasks", help="How many tasks", type=int, default=7
    )
    parser.add_argument(
        "--env-transition-type", help="Env transition type", type=str, default="random"
    )
    parser.add_argument(
        "--max-start-states", help="Number of start states", type=int, default=int(1e3)
    )
    parser.add_argument(
        "--save-path", help="Path to the results log dir", type=str, default="/resultsv2/"
    )
    parser.add_argument(
        "--algo", help="Algorithm", type=str, default="ppo"
    )
    parser.add_argument(
        "--asymptotic-steps",
        help="Number of asymptotic states",
        type=int,
        default=int(1e6),
    )
    # NOTE(review): this help text duplicates --asymptotic-steps and is
    # probably a copy-paste leftover; the option's real meaning is not visible
    # in this file, so the text is left unchanged -- confirm and reword.
    parser.add_argument(
        "--reporting",
        help="Number of asymptotic states",
        type=int,
        default=int(1e4),
    )
    parser.add_argument(
        "--frequency",
        help="Frequency of updates",
        type=int,
        default=int(1e2),
    )
    # The script multiplies the asymptotic reward by this value to obtain the
    # epsilon tolerance handed to mixing_time().
    parser.add_argument(
        "--percent",
        help="Epsilon tolerance as a fraction of the asymptotic reward",
        type=float,
        default=0.1,
    )
    parser.add_argument(
        "--use-uniform-policy",
        action="store_true",
        default=False,
        help="Use uniform policies",
    )
    parser.add_argument(
        "--only-accumulate-returns",
        action="store_true",
        default=False,
        help="Only run accumulate_returns",
    )
    parser.add_argument(
        "--only-asymptotic-returns",
        action="store_true",
        default=False,
        help="Only run asymptotic_returns",
    )
    parser.add_argument(
        "--load-csv",
        action="store_true",
        default=False,
        help="Load prev CSV",
    )
    return parser.parse_args(argv)

def mixing_time(results, seed, epsilon, asymptotic_reward_rate):
    """Average the step at which each reward sequence settles near a target rate.

    For every sequence in ``results`` the running mean reward
    (cumulative reward divided by the number of steps so far) is tracked.
    The sequence's mixing time is the 1-based step that starts the final
    uninterrupted run in which that running mean stays within the tolerance
    of ``asymptotic_reward_rate`` through the last step. Sequences whose
    running mean is outside the tolerance at their last step are ignored.

    Args:
        results: Mapping whose values are sequences of per-step rewards.
        seed: Unused; kept so existing callers keep working.
        epsilon: Tolerance around ``asymptotic_reward_rate``. Its magnitude
            is used, so a tolerance derived from a negative reward
            (``reward * percent``) still defines a valid band.
        asymptotic_reward_rate: Target reward rate to compare against.

    Returns:
        The mean mixing time over all sequences that settle, or ``-1`` when
        none of them does (including when ``results`` is empty).
    """
    # Debug output; sys.getsizeof is shallow and does not count the contained lists.
    print(sys.getsizeof(results))
    tolerance = abs(epsilon)
    settled_steps = []
    for reward_list in results.values():
        rolling_return = 0
        # First step of the current in-band run, or None while out of band.
        # Only the start is needed, so no per-step list is accumulated.
        run_start = None
        for step, reward in enumerate(reward_list, start=1):
            rolling_return += reward
            reward_rate = rolling_return / step
            if abs(reward_rate - asymptotic_reward_rate) < tolerance:
                if run_start is None:
                    run_start = step
            else:
                # Leaving the band invalidates every earlier candidate.
                run_start = None
        if run_start is not None:
            settled_steps.append(run_start)
    if not settled_steps:
        return -1
    return sum(settled_steps) / len(settled_steps)

def latest_checkpoint(root_path):
    """Return the largest checkpoint number among ``accumulate_return_*`` files.

    The number is read from the part of the file name between the last
    underscore and the first dot, e.g. ``accumulate_return_5000.pkl.gz``
    yields ``5000``.

    Args:
        root_path: Directory that holds the ``accumulate_return_*`` files.

    Returns:
        The highest checkpoint number found, as an ``int``.

    Raises:
        ValueError: If no file with a numeric checkpoint suffix is found.
    """
    checkpoints = []
    for file_name in glob.glob(root_path + "/accumulate_return_*"):
        # Parse the base name only: a '.' or '_' in a directory component
        # (e.g. "./results" or "run.v2") must not influence the result.
        stem = os.path.basename(file_name).split('.')[0]
        try:
            checkpoints.append(int(stem.split('_')[-1]))
        except ValueError:
            # Not a numbered checkpoint (e.g. "accumulate_return_final"); skip it.
            continue
    if not checkpoints:
        raise ValueError(
            f"No numbered accumulate_return_* checkpoint found in: {root_path}"
        )
    return max(checkpoints)

args = get_args()
# Base and exponent base used to derive tau = c * (d ** x) in the epsilon sweep.
c = 1
d = 10
algos = ["a2c", "dqn", "ppo"]
RESULT_COLUMNS = [
    "Tau",
    "Algo",
    "Tasks",
    "Seed",
    "Asymptotic Reward",
    "Mixing Times",
    "Epsilon",
    "Checkpoint",
]
# Previous results are looked up where the aggregate CSV is written
# (<save_path>/<example>/<example>.csv). The directory used to be hard-coded
# to "example_3/", so --load-csv never found earlier runs of other examples.
path = args.save_path + f"{args.example}/{args.example}.csv"
if args.load_csv and os.path.isfile(path):
    print("Loading CSV")
    df = pd.read_csv(path, encoding='UTF-8', sep="\t")
else:
    # Fresh table: either loading was not requested or nothing exists yet.
    df = pd.DataFrame(columns=RESULT_COLUMNS)

seeds = list(range(5, 20, 5))

def _append_row(frame, row):
    """Return ``frame`` with the mapping ``row`` added as one new row.

    ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` gives the
    same result on old and new pandas releases.
    """
    return pd.concat([frame, pd.DataFrame([row])], ignore_index=True)


def _load_pickle(file_path):
    """Load and return the object stored in a gzip-compressed pickle file."""
    # NOTE: pickle can execute arbitrary code; only load trusted result files.
    with gzip.open(file_path, 'rb') as f:
        return pickle.load(f)


def _read_asymptotic_reward(seed_dir, show_pickle_path=False):
    """Return the asymptotic reward saved under ``seed_dir`` or ``None``.

    ``asymptotic_reward.txt`` is preferred (last whitespace-separated token
    of its last line); otherwise ``asymptotic_reward.pkl.gz`` is read and its
    "Asymptotic Return" entry is used. ``None`` means the pickle fallback
    could not be read, in which case a message has already been printed.
    """
    text_path = seed_dir + "asymptotic_reward.txt"
    if os.path.isfile(text_path):
        with open(text_path, 'r') as f:
            return float(f.readlines()[-1].split()[-1])

    pickle_path = seed_dir + "asymptotic_reward.pkl.gz"
    if show_pickle_path:
        print(pickle_path)
    try:
        return _load_pickle(pickle_path)["Asymptotic Return"]
    except Exception as err:
        # Deliberately best-effort: the caller skips this seed. Unlike a bare
        # ``except`` this still lets Ctrl-C / SystemExit stop the script.
        print(f"File Not Found \n File Path is: {pickle_path}")
        print(f"Reason: {err!r}")
        return None


if args.example == "example_2":
    # Sweep over the number of tasks (2..7) for every algorithm and seed.
    print("=" * 50)
    print(f"Gathering the results for {args.example} and Tau = {args.tau}")
    tau = args.tau
    for task in range(2, 8):
        print("=" * 50)
        print(f"Current Task: {task}")
        for algo in algos:
            path = f"{args.example}/{algo}/n_{task}"
            for seed in seeds:
                root_path = args.save_path + path + f"/seed_{seed}/"
                print(f"Gathering the Asymptotic Returns - Seed: {seed} ")
                rew = _read_asymptotic_reward(root_path, show_pickle_path=True)
                if rew is None:
                    continue

                print(f"Gathering the Mixing Times - Seed: {seed} ")
                last_checkpoint = latest_checkpoint(root_path)
                file_path = root_path + f"accumulate_return_{last_checkpoint}.pkl.gz"
                if not os.path.isfile(file_path):
                    continue
                results = _load_pickle(file_path)

                mixing = mixing_time(
                    results=results,
                    seed=seed,
                    epsilon=rew * args.percent,
                    asymptotic_reward_rate=rew,
                )
                if mixing < 0:
                    # mixing_time() returns -1 when no sequence settles.
                    print(f"Skipped : Task: {task} | Algo {algo} | Seed: {seed}")
                    continue
                df = _append_row(
                    df,
                    {
                        'Tau': tau,
                        'Algo': algo,
                        'Tasks': task,
                        "Seed": seed,
                        "Asymptotic Reward": rew,
                        "Mixing Times": mixing,
                        "Epsilon": args.percent,
                        "Checkpoint": last_checkpoint,
                    },
                )
                print(df.head(n=60))

                # Rewritten after every seed so partial progress survives a crash.
                df.to_csv(args.save_path + f"{args.example}/{args.example}_{task}.csv", sep='\t', index=False)
                print("-" * 40)
            df.to_csv(args.save_path + f"{args.example}/{args.example}.csv", sep='\t', index=False)
elif args.example == "example_3":
    # Single tau (from --tau) for every algorithm and seed.
    tau = args.tau
    print("=" * 50)
    print(f"Gathering the results for {args.example} and Tau = {tau}")
    for algo in algos:
        print("=" * 50)
        print(f"Current Algorithm: {algo}")
        path = f"example_3/{algo}/tau_{tau}"
        for seed in seeds:
            root_path = args.save_path + path + f"/seed_{seed}/"
            print(f"Gathering the Asymptotic Returns - Seed: {seed} ")
            rew = _read_asymptotic_reward(root_path)
            if rew is None:
                continue

            print(f"Gathering the Mixing Times - Seed: {seed} ")
            last_checkpoint = latest_checkpoint(root_path)
            results = _load_pickle(root_path + f"accumulate_return_{last_checkpoint}.pkl.gz")
            print(sys.getsizeof(results))
            mixing = mixing_time(
                results=results,
                seed=seed,
                epsilon=rew * args.percent,
                asymptotic_reward_rate=rew,
            )
            # NOTE(review): unlike the example_2 branch, a mixing time of -1
            # (nothing settled) is still recorded here -- confirm this is intended.
            df = _append_row(
                df,
                {
                    'Tau': tau,
                    'Algo': algo,
                    'Tasks': 7,
                    "Seed": seed,
                    "Asymptotic Reward": rew,
                    "Mixing Times": mixing,
                    "Epsilon": args.percent,
                    "Checkpoint": last_checkpoint,
                },
            )
            print(df.head(n=20))
            df.to_csv(args.save_path + f"{args.example}/{args.example}_{args.tau}.csv", sep='\t', index=False)
            print("-" * 40)
        df.to_csv(args.save_path + f"{args.example}/{args.example}.csv", sep='\t', index=False)

else:
    # Epsilon sweep: recompute the mixing time for several tolerance fractions
    # on the example_3 results at tau = c * d ** x.
    percents = [0.5, 0.3, 0.2, 0.15, 0.1, 0.05]
    xs = [4]
    for percent in percents:
        for x in xs:
            tau = c * (d ** x)
            last_checkpoint = int(tau * 1000)
            for algo in algos:
                path = f"example_3/{algo}/tau_{tau}"
                asymptotic_reward_rate = {}
                for seed in seeds:
                    results = _load_pickle(args.save_path + path + f"/seed_{seed}/asymptotic_reward.pkl.gz")
                    asymptotic_reward_rate[seed] = {tau: results["Asymptotic Return"]}
                print(asymptotic_reward_rate, algo)
                for seed in seeds:
                    results = _load_pickle(
                        args.save_path + path + f"/seed_{seed}/accumulate_return_{last_checkpoint}.pkl.gz"
                    )
                    rew = asymptotic_reward_rate[seed][tau]
                    mixing = mixing_time(
                        results=results,
                        seed=seed,
                        epsilon=rew * percent,
                        asymptotic_reward_rate=rew,
                    )
                    # No "Checkpoint" entry here, so that column stays empty for these rows.
                    df = _append_row(
                        df,
                        {
                            'Tau': tau,
                            'Algo': algo,
                            'Tasks': 7,
                            "Seed": seed,
                            "Asymptotic Reward": rew,
                            "Mixing Times": mixing,
                            "Epsilon": percent,
                        },
                    )

                    print(df.head(n=20))
                    df.to_csv(args.save_path + f"{args.example}/epsilons.csv", sep='\t')


    # epsilons = [0.03141533333333333 * 0.1, 0.03141533333333333 * 0.06]
    #
    # epsilon_labels = ["10 Percent", "5 Percent", "1 Percent", "0.1 Percent", "0.05 Percent"]
# mean_mx = []
#
# for ep_idx, ep in enumerate(epsilons):
#     for seed_idx, path in enumerate(paths):
#
#     df = df.append({ 'Epsilon' : epsilon_labels[ep_idx], 'Metric' : metrics[0], 'Mixing Time' : mixing},
#                 ignore_index = True)
#
# df.pivot("Epsilon", "Metric", "Mixing Time").plot(kind='bar')
# plt.ylabel("Mean Mixing Time")
# plt.title("Epsilon Return Mixing Time x = 3")
# plt.tight_layout()
# plt.savefig("mixing_time_3.png")
#
#
