import os, sys
# Put the directory one level above this script's folder on the import path so
# the project packages imported below (utils, datas) resolve when run as a script.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import argparse
import pickle
from collections import defaultdict

import torch.types
from tqdm import tqdm

from utils.utils import set_seed
from datas.get_data import get_data
from torch.utils.data import DataLoader
from utils.utils import get_model
from utils.utils import get_promt
import torch.nn.functional as F

import numpy as np
import os

# Sampling settings for a single-token `model.generate` call. Not referenced by
# any active code visible in this file.
model_custom_config = dict(
    max_new_tokens=1,
    temperature=0.1,
    top_p=0.9,
)

def _should_skip(model_path, method, n_tokens):
    """Return True when an input of ``n_tokens`` tokens falls outside the length
    window evaluated for this model/method combination.

    Model paths not matched by any rule below are never skipped.
    """
    if "vicuna" in model_path:
        if n_tokens > 15000 or n_tokens < 1000:
            return True
        # A "rerope" method on vicuna uses a tighter upper bound.
        if "rerope" in method and n_tokens > 12000:
            return True
    if "llama2-7b" in model_path or "mpt-7b" in model_path:
        if n_tokens > 16000 or n_tokens < 1000:
            return True
    return False


def _token_nll(logits, input_ids, attention_mask):
    """Compute per-token next-token negative log-likelihood.

    Element ``t`` of each returned row is the cross-entropy of predicting
    ``input_ids[:, t + 1]`` from ``logits[:, t]``. Each row is truncated to its
    attention-mask sum minus one, because the first token has no prediction.

    Args:
        logits: tensor of shape (batch, length, vocab).
        input_ids: tensor of shape (batch, length).
        attention_mask: presumably 0/1 with right padding. TODO confirm for
            batched/padded inputs.

    Returns:
        A list with one 1-D tensor per batch row.
    """
    batch_size, length, vocab = logits.shape
    # Explicit dims rather than -1: reshape(0, -1) raises on zero-element
    # tensors, which happens for single-token inputs (length == 1).
    token_nll = F.cross_entropy(
        logits[:, :-1].reshape(batch_size * (length - 1), vocab),
        input_ids[:, 1:].reshape(-1),
        reduction="none",
    ).reshape(batch_size, length - 1)
    return [row[: mask.sum() - 1] for row, mask in zip(token_nll, attention_mask)]


def _accumulate_nll(nll_seq, nll_by_length, nll_by_position):
    """Record one sequence's NLL array into the two running groupings.

    ``nll_by_length[len(nll_seq)]`` receives the sequence's mean NLL.
    ``nll_by_position[i]`` receives the NLL of the token at position ``i``.
    """
    nll_by_length[len(nll_seq)].append(nll_seq.mean())
    for position, value in enumerate(nll_seq):
        nll_by_position[position].append(value)


def _summarize(records):
    """Map each key of ``records`` to the NaN-ignoring mean and variance of its values."""
    return {
        key: {"mean": np.nanmean(np.array(values)),
              "var": np.nanvar(np.array(values))}
        for key, values in records.items()
    }


def _dump_stats(log_dir, save_file, nll_stats_sequence, nll_stats_token, all_nll):
    """Pickle the statistics to two files in ``log_dir``.

    ``full_<save_file>`` contains the summaries plus every raw NLL array.
    ``<save_file>`` contains the summaries only.
    """
    with open(os.path.join(log_dir, f"full_{save_file}"), "wb") as f:
        pickle.dump({
            "nll_stats_sequence": nll_stats_sequence,
            "nll_stats_token": nll_stats_token,
            "all_nll": all_nll
        }, f)

    with open(os.path.join(log_dir, save_file), "wb") as f:
        pickle.dump({
            "nll_stats_sequence": nll_stats_sequence,
            "nll_stats_token": nll_stats_token,
        }, f)


def main(args):
    """Measure a model's per-token next-token NLL over a dataset and pickle statistics.

    For each batch, the first text is formatted into the model's prompt template
    and tokenized. It may then be skipped by length. Otherwise it is run through
    the model once and its per-token NLL is accumulated.

    After every processed sample the statistics are re-summarized and rewritten
    under ``<cwd>/<args.log_dir>``. Partial results therefore survive an
    interrupted run.

    Args:
        args: parsed CLI namespace. Reads ``dataset``, ``batch_size``, ``cuda``,
            ``model_path``, ``method``, ``log_dir`` and ``save_file``.
    """
    log_dir = os.path.join(os.getcwd(), args.log_dir)
    # Create the output directory up front, so a missing directory does not
    # fail only after the first (expensive) forward pass.
    os.makedirs(log_dir, exist_ok=True)

    dataset = get_data(args.dataset, args)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    device = "auto" if args.cuda == "auto" else torch.device(int(args.cuda))
    tokenizer, model = get_model(args.model_path, device, method=args.method, args=args)
    # Called once: nothing in the loop switches the model back to train mode.
    model.eval()

    prefix_prompt = get_promt(args.model_path)

    all_nll = []
    # Raw groupings, grown incrementally so each step does not rescan all_nll.
    nll_by_length = defaultdict(list)    # sequence length -> per-sequence mean NLL
    nll_by_position = defaultdict(list)  # token position -> NLL values at that position

    for step, data in enumerate(tqdm(dataloader)):
        # Stops once step exceeds 10000, so at most 10001 batches are consumed.
        # Skipped samples count toward this limit.
        if step > 10000:
            break

        with torch.no_grad():
            # NOTE(review): only the first text of each batch is used; with
            # batch_size > 1 the remaining items are ignored.
            query = prefix_prompt.format(data["text"][0])
            inputs_token = tokenizer(query, return_tensors="pt").to(model.device)
            input_ids = inputs_token.input_ids
            n_tokens = len(input_ids[0])
            print("input token length: {}".format(n_tokens))

            if _should_skip(args.model_path, args.method, n_tokens):
                continue

            if args.method == "lm-infinite":
                output = model.model(
                    input_ids=input_ids,
                    use_cache=False,
                )
            elif args.method == "streaming-llm":
                output = model.forward_pll(input_ids=input_ids)
            else:
                # NOTE(review): the cache returned by use_cache=True is never read here.
                output = model(
                    input_ids=input_ids,
                    use_cache=True,
                )

            logits = output["logits"]
            token_nll_list = _token_nll(logits, input_ids, inputs_token.attention_mask)

        new_nll = [nll.cpu().numpy() for nll in token_nll_list]
        all_nll.extend(new_nll)
        print("Shape:", logits.shape)
        print("Start: ", token_nll_list[0][:20])
        print("End:", token_nll_list[0][-20:])

        for nll_seq in new_nll:
            _accumulate_nll(nll_seq, nll_by_length, nll_by_position)

        # Rewritten after every processed sample to act as a checkpoint.
        _dump_stats(
            log_dir,
            args.save_file,
            _summarize(nll_by_length),
            _summarize(nll_by_position),
            all_nll,
        )


if __name__ == "__main__":
    # (flag, type, default) for every CLI option, in registration order.
    # Example alternatives:
    #   --model_path /data/persist/models/vicuna/vicuna-13b-v1.5
    #   --dataset    "../datas/download_data/pile-deduplicated/train-00000-of-01650-f70471ee3deb09c0.parquet"
    # --cuda takes a GPU index, or the string "auto".
    cli_options = (
        ("--model_path", str, "/data/persist/models/llama-3b"),
        ("--method", str, "old"),
        ("--dataset", str, "pile"),
        ("--batch_size", int, 1),
        ("--log_dir", str, "../logs"),
        ("--save_file", str, "stats_test_.pkl"),
        ("--cuda", str, "1"),
        ("--seed", int, 0),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default in cli_options:
        parser.add_argument(flag, type=value_type, default=default)
    args = parser.parse_args()
    set_seed(args.seed)
    main(args)








