import os, sys
# Make the parent of this script's directory importable so the sibling
# packages imported below (``utils``, ``datas``, ...) can be resolved.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from collections import defaultdict

import torch.types
from tqdm import tqdm

from utils.utils import set_seed, get_promt_longeval_topics, get_promt_longeval_lines
from datas.get_data import get_data
from torch.utils.data import DataLoader
from utils.utils import get_model
from utils.utils import get_promt
import torch.nn.functional as F
from utils.utils import compare_retrieval_acc

import numpy as np
import os

# Keyword arguments forwarded unchanged to ``model.generate`` for every sample.
model_custom_config = dict(
    max_new_tokens=40,
    temperature=0.0,
    top_p=0.9,
)

def _configure_weave_method(args):
    """Apply command-line overrides to the module implementing ``args.method``.

    Only the ``weave-mpt1/2/3/6/7`` methods have module-level settings that
    are patched here; any other method name leaves every module untouched.

    Args:
        args: Parsed command-line namespace. ``push_mpt``, ``push_width`` and
            ``chunk_width`` are read only for the methods that use them.
    """
    import importlib

    # method name -> {attribute set on the method module: CLI argument read}
    overrides_by_method = {
        "weave-mpt1": {"push_pos": "push_mpt"},
        "weave-mpt2": {"push_pos": "push_mpt"},
        "weave-mpt3": {"push_pos": "push_mpt", "push_width": "push_width"},
        "weave-mpt6": {
            "push_pos": "push_mpt",
            "push_width": "push_width",
            "chunk_width": "chunk_width",
        },
        "weave-mpt7": {
            "push_pos": "push_mpt",
            "push_width": "push_width",
            "chunk_width": "chunk_width",
        },
    }
    overrides = overrides_by_method.get(args.method)
    if overrides is None:
        return

    # e.g. "weave-mpt3" is implemented by the module ``methods.weave_mpt3``.
    method_module = importlib.import_module("methods." + args.method.replace("-", "_"))
    for attr_name, arg_name in overrides.items():
        setattr(method_module, attr_name, getattr(args, arg_name))

    if args.method == "weave-mpt1":
        # weave-mpt1 also sets the attention module's chunk width, to a fixed
        # value rather than to a command-line argument.
        weave_attention = importlib.import_module("models.mpt_7b.weave_attention")
        weave_attention.chunk_width = 2047


def _build_query(args, data, prefix_prompt):
    """Return the full prompt text for the first sample of ``data``.

    Args:
        args: Parsed command-line namespace; only ``args.dataset`` is read.
        data: One batch from the data loader. ``data["text"][0]`` is the raw
            context; for "lines" datasets ``data["random_idx"]`` is read too.
        prefix_prompt: Format string the context (and, for "lines", the
            question) is substituted into.
    """
    query = data["text"][0]

    if "lines" in args.dataset:
        # NOTE(review): assumes random_idx[0][0] is the line identifier and
        # random_idx[1] the record number -- confirm against the dataset.
        line_ref = str(data["random_idx"][0][0]) + " on record " + str(int(data["random_idx"][1]))
        question = f'\n\nWhat is the <REGISTER_CONTENT> on Line {line_ref} ?'
        query = prefix_prompt.format(question, query)

    if "topics" in args.dataset:
        query = prefix_prompt.format(query)

    return query


def main(args):
    """Generate an answer for every sample of a LongEval dataset and log it.

    For each batch the prompt is built, tokenized and passed to
    ``model.generate``; a one-line summary (label, prediction, prompt length)
    is printed and appended to ``<cwd>/<log_dir>/<save_file>``. For "lines"
    datasets a per-prompt-length list of 0/1 hits is additionally pickled to
    ``<cwd>/<log_dir>/record_<save_file>`` after every sample, so partial
    results survive an interrupted run.

    NOTE: only element ``[0]`` of each batch is evaluated, so a
    ``batch_size`` greater than 1 silently skips the remaining samples.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section).

    Raises:
        NotImplementedError: If ``args.dataset`` contains neither "topics"
            nor "lines".
    """
    _configure_weave_method(args)

    dataset = get_data(args.dataset, args)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    device = "auto" if args.cuda == "auto" else torch.device(int(args.cuda))
    tokenizer, model = get_model(args.model_path, device, method=args.method, args=args)

    if "topics" in args.dataset:
        prefix_prompt = get_promt_longeval_topics(args.model_path)
    elif "lines" in args.dataset:
        prefix_prompt = get_promt_longeval_lines(args.model_path)
    else:
        raise NotImplementedError("No Implement")

    # Create the log directory up front so the first write cannot fail on a
    # fresh checkout.
    log_dir = os.path.join(os.getcwd(), args.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    output_file = f"{log_dir}/{args.save_file}"
    record_file = f"{log_dir}/{'record_' + args.save_file}"

    all_length_acc = defaultdict(list)

    # Inference only: switch to eval mode once instead of on every iteration.
    model.eval()
    for sample_idx, data in enumerate(tqdm(dataloader)):
        with torch.no_grad():
            query = _build_query(args, data, prefix_prompt)

            inputs_token = tokenizer(query, return_tensors="pt").to(model.device)
            input_ids = inputs_token.input_ids
            prompt_length = len(input_ids[0])
            print("input token length: {}".format(prompt_length))

            outputs = model.generate(input_ids, **model_custom_config)

        # Drop the echoed prompt by token count. Slicing the decoded string by
        # len(query) is unreliable: decoding adds special tokens (e.g. BOS) and
        # need not reproduce the prompt text character-for-character.
        # Assumes ``generate`` returns prompt + continuation, as the original
        # character-based slicing did.
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        target = data["target"][0]
        label = target[0] if "topics" in args.dataset else target
        summary = f"Label: {label}, Predict: {response}, prompt length: {prompt_length}".replace('\n', ' ')
        print(summary)

        if "lines" in args.dataset:
            # A hit is counted when the expected number appears anywhere in
            # the generated text.
            acc = 1 if str(int(target)) in response else 0
            all_length_acc[prompt_length].append(acc)

            with open(record_file, "wb") as f:
                pickle.dump({"all_length_acc": all_length_acc}, f)

        # The first sample truncates any previous log; later samples append.
        with open(output_file, "w" if sample_idx == 0 else "a", encoding="utf-8") as f:
            f.write(summary)
            f.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # (flag, type, default) for every option, registered in this order.
    option_specs = [
        ("--model_path", str, "/data/persist/models/llama-3b"),  # mosaicml-mpt-7b llama2-7b-chat
        ("--method", str, "weave-v10"),
        ("--dataset", str, "longeval-lines"),
        ("--save_file", str, "longeval-lines_old_longeval_test.pkl"),
        ("--batch_size", int, 1),
        ("--log_dir", str, "../logs"),
        ("--cuda", str, "0"),
        ("--hard_cuda", int, 0),
        ("--seed", int, 0),
        # for mpt-alibi
        ("--push_mpt", int, 512),
        ("--push_width", int, 50),
        ("--chunk_width", int, 512),
    ]
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    # Seed before any data or model work so the run is repeatable.
    set_seed(args.seed)
    main(args)

















