import configparser
import json

import matplotlib.pyplot as plt
import os, sys
# Append the parent of this script's directory to sys.path so the
# project-local `utils.utils` module imported below can be resolved.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from utils.utils import set_seed
import numpy as np
import os

import matplotlib.patheffects as pe

# Plot styles, paired positionally with `labels` below.
linestyles = ['--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted', '-', '--', '-.', ':', 'solid', 'dashed', 'dashdot', 'dotted']
markers = ['o', '^', 's', 'D', 'v', 'x', '*', 'p', 'o', '^', 's', 'D', 'v', 'p', '*', 'x']
# One color per entry in `labels`; the last one belongs to "other". With only
# seven colors, zip() below silently dropped "other", so main() raised KeyError
# for any label not listed explicitly.
colors = ['#797979', '#A9B98B', '#98B8DD', '#54936D', '#BD9273', '#CCA29F', '#F3975F', '#8E7CC3']


# Method names with a dedicated style; "other" is the fallback for any label
# read from the config that is not in this list.
labels = [
    "Origin",
    "ReRoPE",
    "Leaky-ReRoPE",
    "Dynamic-NTK",
    "LM-Infinite",
    "Streaming-LLM",
    "Mesa-Extrapolation",
    "other"
]

# label -> {"linestyle", "marker", "color"} keyword arguments for Axes.plot.
label_setting = {
    label: {
        "linestyle": linestyle,
        "marker": marker,
        "color": color
    }
    for label, linestyle, marker, color in zip(labels, linestyles, markers, colors)
}

print(label_setting)

def read_config_file(config_path):
    """Load a plotting configuration from an ``.ini`` or ``.json`` file.

    The parser is chosen from the file extension (case-insensitive).

    Args:
        config_path: Path to the configuration file.

    Returns:
        A ``configparser.ConfigParser`` for ``.ini`` files, or the decoded
        JSON object for ``.json`` files.

    Raises:
        FileNotFoundError: If the file cannot be read.
        NotImplementedError: If the extension is neither ``.ini`` nor ``.json``.
    """
    # Dispatch on the real extension: a substring test would send paths such
    # as "conf.initial/run.json" to the INI parser.
    ext = os.path.splitext(config_path)[1].lower()
    if ext == ".ini":
        config = configparser.ConfigParser()
        # ConfigParser.read() silently skips files it cannot open and returns
        # the list it actually parsed; report a missing file instead of
        # handing back an empty config.
        if not config.read(config_path, encoding="utf-8"):
            raise FileNotFoundError(config_path)
    elif ext == ".json":
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise NotImplementedError("No implement read")
    return config



def _gpu_memory_gb(end_gpu_util, hard_cuda):
    """Convert a recorded ``end_gpu_util`` string to GB, rounded to 2 decimals.

    ``end_gpu_util`` is a newline-separated list of integers. The values are
    presumably per-GPU memory in MiB (hence the /1024); confirm against the
    benchmark that writes the pickle.

    Args:
        end_gpu_util: Raw multi-line string stored in the result pickle.
        hard_cuda: 1 to use only the first line; otherwise sum every
            non-empty line.

    Returns:
        The memory value divided by 1024 and rounded to 2 decimals.
    """
    lines = end_gpu_util.split("\n")
    if hard_cuda == 1:
        return round(int(lines[0]) / 1024, 2)
    values = [int(line.strip()) for line in lines if line.strip()]
    return round(sum(values) / 1024, 2)


def _load_series(filename, smooth_gamma):
    """Read one pickled benchmark result and return its plot series.

    Args:
        filename: Path to a pickle containing a ``"token_speed_gpu"`` dict
            that maps token length -> {"time": ..., "end_gpu_util": ...}.
        smooth_gamma: Weight of the previous smoothed value in the
            exponential moving average applied to ``time``.

    Returns:
        ``(lengths, latencies, gpus)``:
        - ``lengths``: numpy array of the dict keys in insertion order.
        - ``latencies``: list of smoothed ``time`` values.
        - ``gpus``: list of memory values from ``_gpu_memory_gb``.
    """
    # Files whose name contains "hard_cuda-0" sum every reported value; all
    # other files use only the first reported value.
    hard_cuda = 0 if "hard_cuda-0" in filename else 1
    # NOTE(review): pickle.load can execute arbitrary code. Only point the
    # config at result files you produced yourself.
    with open(filename, "rb") as f:
        data = pickle.load(f)["token_speed_gpu"]

    lengths, latencies, gpus = [], [], []
    smoothed = 1  # EMA seed: the first point is blended with 1.
    for length, value in data.items():
        lengths.append(length)
        smoothed = value["time"] * (1 - smooth_gamma) + smoothed * smooth_gamma
        latencies.append(smoothed)
        gpus.append(_gpu_memory_gb(value["end_gpu_util"], hard_cuda))
    return np.array(lengths), latencies, gpus


def _line_style(label):
    """Return ``Axes.plot`` style kwargs for ``label``.

    Uses the label's own entry in ``label_setting``, else the ``"other"``
    entry. If neither exists, returns an empty dict so matplotlib's defaults
    apply instead of raising KeyError.
    """
    if label in label_setting:
        return label_setting[label]
    return label_setting.get("other", {})


def main(args=None):
    """Plot smoothed latency (lines) and GPU memory (bars) against token length.

    Reads ``files``, ``labels``, ``model_name`` and ``smooth_gamma`` from the
    ``General`` section of ``args.config_file``. For each (file, label) pair it
    draws a latency line on the right axis and a memory bar series on the left
    axis. It then saves ``speed_memory_<model_name>.png`` in the current
    directory and shows the figure.

    Args:
        args: Namespace with a ``config_file`` attribute (see the CLI below).

    Raises:
        FileNotFoundError: If no config file is given.
        ValueError: If the config yields no (file, label) pairs to plot.
    """
    config_file = getattr(args, "config_file", None)
    if not config_file:
        raise FileNotFoundError("no config file")
    config = read_config_file(config_file)
    files = config['General']['files']
    labels = config['General']['labels']
    model_name = config['General']['model_name']
    smooth_gamma = config['General']['smooth_gamma']

    # zip() stops at the shorter of the two lists; extra files or labels are ignored.
    series = list(zip(files, labels))
    if not series:
        raise ValueError("config lists no (file, label) pairs to plot")

    fig, ax1 = plt.subplots(figsize=(15, 5))
    # Keep x tick labels horizontal and centered.
    plt.xticks(rotation=0, ha='center')
    ax2 = ax1.twinx()  # latency axis, sharing x with the memory axis

    # Each successive file's bars are shifted by 100 along x so the series sit
    # side by side. The -2 start and -50 offset were presumably tuned for the
    # number of files in the original configs.
    count = -2
    x = None
    for _filename, _label in series:
        x, y, gpus = _load_series(_filename, smooth_gamma)
        plotline, = ax2.plot(x, y, label=_label, linewidth=2, markersize=6.5,
                             **_line_style(_label))
        ax1.bar(x + count * 100 - 50, gpus, width=100, label=_label,
                color=plotline.get_color())
        count += 1

    # Axis labels and tick sizes.
    ax1.set_xlabel("Token Length", fontsize=16)
    ax2.set_ylabel("Latency (s)", fontsize=16)
    ax1.tick_params(axis='y', labelsize=13)
    ax2.tick_params(axis='y', labelsize=13)
    ax1.set_ylabel("Memory Usage (GB)", fontsize=16)

    # Light horizontal grid lines.
    ax1.yaxis.grid(True, linewidth=0.5, alpha=0.5)

    # Tick positions come from the last file's lengths. Only lengths where
    # int(length / 1024) % 3 == 2 get a "<n>k" label, to avoid crowding.
    ax1.set_xticks(
        x,
        [str(int(length / 1024)) + "k" if int(length / 1024) % 3 == 2 else "" for length in x],
        fontsize=14,
    )

    plt.subplots_adjust(left=0.06, right=0.94, bottom=0.14, top=0.92)

    ax2.legend(loc='upper right', fontsize=13)
    ax1.legend(loc='upper left', fontsize=13)

    # Light-grey frame on both axes, plus a stroke effect on ax1's right and
    # bottom spines.
    for ax in (ax1, ax2):
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_color('lightgrey')
    for side in ('right', 'bottom'):
        ax1.spines[side].set_path_effects([pe.withStroke(linewidth=2, foreground='lightgrey')])

    fig.savefig("speed_memory_{}.png".format(model_name))

    plt.show()


if __name__ == "__main__":
    # Command-line entry point: a single optional path to the plotting config.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config_file",
        type=str,
        # Alternative config: "../conf/speed-memory-result6.json"
        default="../conf/speed-memory-result4.json",
    )
    main(cli.parse_args())
