import os, sys
# Put the directory one level above this file's directory on the import path.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse


import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.colors import TwoSlopeNorm
import torch

"""
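Recorded color-scale ranges: the "vmin=..., vcenter=0, vmax=..." lines printed
by visualize_matrices_all, grouped by model and method.

# llama2-7b-chat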
rerope:
vmin=-1245.1171875, vcenter=0, vmax=1249.0234375
vmin=-1245.1171875, vcenter=0, vmax=1249.0234375

vmin=-1189.453125, vcenter=0, vmax=1249.0234375

weave-v10:llama2-7b-chat:
first_pos=4000, pos_step=1000, head_from=0, head_to=640, global_setting=False
vmin=-17390.625, vcenter=0, vmax=10476.5625

# llama-3b
rerope:
vmin=-800.78125, vcenter=0, vmax=1241.2109375

# vicuna-13b
rerope:
vmin=-2990.234375, vcenter=0, vmax=2935.546875

"""



# llama2-7b-chat
global_min_value = -1189.453125
global_max_value = 1249.0234375

# # llama-3b
# global_min_value = -800.78125
# global_max_value = 1241.2109375

# # vicuna-13b
# global_min_value = -2990.234375
# global_max_value = 2935.546875

global_setting = False
plot_layers = 12

head_from = 0
head_to = 128 * 5
max_coins = 1000

# llama-3b setting
first_pos = 2000
pos_step = 1000
last_pos = 4000

# # llama2-7b-chat setting
# first_pos = 4000
# pos_step = 1000
# last_pos = 7000
#
# # vicuna-13b setting
# first_pos = 2000
# pos_step = 1000
# last_pos = 4000

class ConfigBase(object):
    """Base class whose instance attributes can be rendered as a name tag."""

    def get_params(self):
        """Return ``"name-value_"`` for each instance attribute, concatenated.

        Only attributes stored on the instance (``vars(self)``) are included,
        in insertion order; every entry keeps its trailing underscore, e.g.
        ``"first_pos-2000_pos_step-1000_"``.
        """
        return "".join("{}-{}_".format(key, val) for key, val in vars(self).items())

class ConfigPlot(ConfigBase):
    """Per-model plot settings: token-position windows and color-scale range.

    ``refresh_params`` picks the preset whose key is a substring of the given
    model name (in practice the hidden-state file path). Only attributes set on
    the instance end up in ``get_params()`` and hence in the saved file name;
    ``refresh_params`` sets them in the order ``global_setting``,
    ``first_pos``, ``pos_step``, ``last_pos``.
    """

    first_pos = 2000
    pos_step = 1000
    last_pos = 4000
    global_setting = False

    # (model-name substring, (first_pos, pos_step, last_pos), (vmin, vmax)).
    # Checked in order; the first substring contained in the name wins.
    _MODEL_PRESETS = (
        ("llama-3b", (2000, 1000, 4000), (-800.78125, 1241.2109375)),
        ("llama2-7b-chat", (4000, 1000, 7000), (-1189.453125, 1249.0234375)),
        ("vicuna", (2000, 1000, 4000), (-2990.234375, 2935.546875)),
    )

    def refresh_params(self, model_name):
        """Apply the preset matching ``model_name`` to this instance.

        The position window always comes from the preset, independently of the
        color-scale mode, and the module-level ``global_setting`` flag (set
        from ``--global_setting`` when run as a script) is honored rather than
        forced to ``True``. The preset's fixed color range is written to the
        module globals ``global_min_value``/``global_max_value``; it is only
        used for plotting when ``global_setting`` is true. The module globals
        ``first_pos``/``pos_step``/``last_pos`` are updated too, so code that
        still reads them stays in sync.

        Args:
            model_name: string expected to contain one of the preset keys;
                ``None`` is treated as an empty name.

        Raises:
            NotImplementedError: if no preset key occurs in ``model_name``.
        """
        global first_pos, pos_step, last_pos, global_max_value, global_min_value

        name = model_name or ""
        for key, positions, value_range in self._MODEL_PRESETS:
            if key in name:
                break
        else:
            raise NotImplementedError("no implement for model_name={!r}".format(model_name))

        first_pos, pos_step, last_pos = positions
        global_min_value, global_max_value = value_range

        # Assignment order fixes the key order of vars(self), which
        # get_params() turns into the output file name.
        self.global_setting = global_setting
        self.first_pos = first_pos
        self.pos_step = pos_step
        self.last_pos = last_pos

        print("config-inner: first_pos={}, pos_step={}, last_pos={}, head_from={}, head_to={}, global_setting={}".format(
            self.first_pos, self.pos_step, self.last_pos, head_from, head_to, self.global_setting))


def visualize_matrices_all(matrices, filename=None, config_plot=None):
    """Draw one heat map per matrix side by side, save it as PNG and show it.

    Colors use the diverging 'seismic' colormap through a ``TwoSlopeNorm``
    centered on 0. The range is the fixed per-model range when the config's
    ``global_setting`` is on, otherwise the min/max over all matrices.

    Args:
        matrices: non-empty sequence of 2-D arrays; subplot ``i`` is titled
            ``layer: i``.
        filename: model/file name; selects the presets (when ``config_plot``
            is not given) and prefixes the output ``"{filename}_{params}.png"``.
        config_plot: an already refreshed ``ConfigPlot``; when ``None`` a new
            one is created and refreshed from ``filename``.

    Raises:
        NotImplementedError: if ``filename`` matches no known model.
        ValueError: if ``matrices`` is empty.
    """
    if config_plot is None:
        config_plot = ConfigPlot()
        config_plot.refresh_params(filename)

    if len(matrices) == 0:
        raise ValueError("visualize_matrices_all: 'matrices' is empty")

    if config_plot.global_setting:
        # Fixed per-model range stored by refresh_params.
        min_value = global_min_value
        max_value = global_max_value
    else:
        min_value = np.min([np.min(matrix) for matrix in matrices])
        max_value = np.max([np.max(matrix) for matrix in matrices])

    # TwoSlopeNorm requires vmin < vcenter < vmax. If the values do not
    # straddle 0, fall back to a range symmetric about 0 so 0 stays centered.
    if not (min_value < 0 < max_value):
        bound = max(abs(min_value), abs(max_value)) or 1.0
        min_value, max_value = -bound, bound

    cmap = plt.get_cmap('seismic')  # diverging colormap
    # Map colors with 0 at the center of the colormap.
    norm = TwoSlopeNorm(vmin=min_value, vcenter=0, vmax=max_value)
    print("vmin={}, vcenter=0, vmax={}".format(min_value, max_value))

    # squeeze=False keeps a 2-D array of Axes even for a single matrix.
    fig, axes = plt.subplots(1, len(matrices), figsize=(30, 8), squeeze=False)
    axes = axes[0]
    plt.subplots_adjust(left=0.04, right=1, top=0.9, bottom=0.12)

    # Shared x/y axis labels for the whole figure.
    fig.text(0.5, 0.04, 'Token Position', ha='center', va='center', fontsize=16)
    fig.text(0.01, 0.5, 'Dimension', ha='center', va='center', rotation='vertical', fontsize=16)

    for i, (ax, matrix) in enumerate(zip(axes, matrices)):
        im = ax.imshow(matrix, cmap=cmap, interpolation='nearest', norm=norm)
        ax.set_title(f'layer: {i}', fontsize=16)
        ax.axis('on')
        if i > 0:
            # Only the leftmost subplot shows its y axis.
            ax.yaxis.set_visible(False)

        # Tick-label font size.
        ax.tick_params(axis='both', labelsize=14)

        # Light-grey subplot borders.
        for spine in ax.spines.values():
            spine.set_edgecolor('lightgrey')

    # One color bar shared by all subplots.
    fig.colorbar(im, ax=axes, orientation='vertical', shrink=0.3, aspect=10, fraction=0.04, pad=0.02)

    # Let each image fill its subplot instead of keeping square pixels.
    for ax in axes:
        ax.set_aspect('auto')

    params = config_plot.get_params()
    plt.savefig("{}_{}.png".format(filename, params))
    plt.show()
    plt.close(fig)


def plot_hidden_state(all_hidden_states, filename=None):
    """Plot two token-position windows for the first ``plot_layers`` layers.

    For each layer ``l``, rows ``[first_pos, first_pos + pos_step)`` and
    ``[last_pos, last_pos + pos_step)`` of ``all_hidden_states[l][0]`` are
    taken, restricted to columns ``[head_from, head_to)``, concatenated,
    rotated with ``np.rot90`` and scaled by ``max_coins``.

    The per-model presets are loaded *before* slicing, so the plotted windows
    match the ``first_pos``/``pos_step``/``last_pos`` in the output file name.

    Args:
        all_hidden_states: per-layer hidden states indexed as
            ``all_hidden_states[layer][0]`` -> 2-D tensor; presumably
            (seq_len, hidden_dim) for batch item 0 -- TODO confirm.
        filename: model/file name used to pick presets and name the output.
    """
    config_plot = ConfigPlot()
    config_plot.refresh_params(filename)
    start_a = config_plot.first_pos
    step = config_plot.pos_step
    start_b = config_plot.last_pos

    print("first_pos={}, pos_step={}, last_pos={}, head_from={}, head_to={}, global_setting={}".format(
        start_a, step, start_b, head_from, head_to, config_plot.global_setting))

    n_layers = min(plot_layers, len(all_hidden_states))
    if n_layers:
        # Slicing past the end silently shortens a window; make that visible.
        seq_len = all_hidden_states[0][0].shape[0]
        needed = max(start_a, start_b) + step
        if seq_len < needed:
            print("warning: sequence length {} < {}; position windows will be truncated".format(seq_len, needed))

    matrixs = []
    for layer_ in range(n_layers):
        states = all_hidden_states[layer_][0]
        first_window = states[start_a:start_a + step, head_from:head_to]
        last_window = states[start_b:start_b + step, head_from:head_to]
        combine = torch.concat([first_window, last_window], dim=0)
        matrixs.append(np.rot90(combine.detach().cpu().numpy()) * max_coins)

    visualize_matrices_all(matrixs, filename, config_plot=config_plot)


def parse_bool(value):
    """Convert a command-line value such as "True", "false", "1" or "no" to bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is true for
    every non-empty string -- including "False" -- so it cannot parse a
    boolean option that takes a value.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_path", type=str, default="../test/old_pile_llama2-7b-chat_saved_all_hidden_states.pth")
    parser.add_argument("--global_setting", type=parse_bool, default=False)
    args = parser.parse_args()

    # Module-level assignment: the plotting code reads this global.
    global_setting = args.global_setting

    # April 15 probing experiments:
    # rerope_pile_llama2-7b-chat_saved_all_hidden_states.pth
    # old_pile_llama2-7b-chat_saved_all_hidden_states.pth
    # rerope_llama2-7b-chat_saved_all_hidden_states.pth
    # ../test/old_llama2-7b-chat_saved_all_hidden_states-validate-no-inittoken2.pth
    # rerope_hello_llama2-7b-chat_saved_all_hidden_states.pth
    # rerope_passkey_llama2-7b-chat_saved_all_hidden_states.pth
    # old_passkey_llama2-7b-chat_saved_all_hidden_states.pth
    #
    # Other files used before:
    # ../test/old_llama2-7b-chat_saved_all_hidden_states.pth
    # ../test/leaky-rerope_llama2-7b-chat_saved_all_hidden_states.pth
    # ../test/rerope_llama2-7b-chat_saved_all_hidden_states.pth
    # ../test/weave_v10_llama2-7b-chat_saved_all_hidden_states.pth
    # ../test/old_llama-3b_saved_all_hidden_states.pth
    # ../test/rerope_llama-3b_saved_all_hidden_states.pth
    # ../test/leaky-rerope_llama-3b_saved_all_hidden_states.pth
    # ../test/weave_v10_llama-3b_saved_all_hidden_states.pth

    # Validate with parser.error rather than assert (asserts vanish under -O).
    if not args.file_path:
        parser.error("file-path is none")

    filename = args.file_path

    # torch.load unpickles the file: only load hidden-state dumps you trust.
    # map_location="cpu" lets tensors saved from a GPU load on a CPU-only host.
    all_hidden_states = torch.load(filename, map_location="cpu")
    plot_hidden_state(all_hidden_states, filename)