## Grade the diversity of generated outputs using self-BLEU, and report n-gram repetition rates.
import argparse
import pickle
import os
from pathlib import Path

from llama.metrics import *
from llama.tokenizer import Tokenizer as LlamaTokenizer

if __name__ == "__main__":
    # read args
    parser = argparse.ArgumentParser()
    parser.add_argument("file_path", type=str)
    parser.add_argument("prompt_len")
    parser.add_argument("eval_type")
    args = parser.parse_args()
    file_path = args.file_path
    prompt_len = int(args.prompt_len)
    eval_type = args.eval_type
    file_name, file_type = os.path.splitext(file_path)
    
    # load file
    with open(file_path, "rb") as f:
        r = pickle.load(f)
    
    # evaluate or calculate diversity per ablation
    if eval_type == "eval":
        n_prompts, n_drafts, seq_len = r.shape
        if n_drafts > 1:
            gens = r[:, :, prompt_len:] # cut prompts
            diversity = calculate_diversity(gens.tolist())
            print(f"Diversity: {diversity}")
        for i in range(n_drafts):    
            u, b, t = calculate_ngram_repetition(r[:, i, :].reshape(n_prompts, -1)[:, prompt_len:].tolist())
            print(f"Unigram Repeat: {u}  Bigram Repeat: {b}  Trigram Repeat: {t}  Avg: {(u + b + t) / 3}")
    elif eval_type == "ablation":
        results = {}
        for param in r:
            print(f"Parameter: {param}")
            tr = r[param]
            n_prompts, n_drafts, seq_len = tr.shape
            tr = tr[:, :, prompt_len:] # cut prompts
            diversity = calculate_diversity(tr.tolist())
            print(f"Diversity: {diversity}")
            results[param] = diversity
        print("Saving...")
        with open(file_name + "_div." + file_type, "wb") as f:
            pickle.dump(results, f) 
        print("Saved")
    else:
        print("Invalid evaluation type")
