import argparse
import os
import pdb
import sys
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

# Project root directory. It is put at the front of sys.path so that the
# project-local `src.*` imports that follow can be resolved.
root_dir = '/data/home/username/Experiments/LLM_ensemble'
sys.path.insert(0, root_dir)

from src.nllb.Demon_prompt_generator import demon_prompt_generator
from src.nllb.Model_generator import mistral_translate

# Fixed locations and settings of this experiment script.
FLORES_DIR = "/data/home/username/Experiments/LLM_ensemble/Datasets/Flores"
EVAL_DIR = "/data/home/username/Experiments/LLM_ensemble/Eval"
LLM_MODEL_PATH = "/data3/username/ModelsHub/Llama-2-13b-hf"
DEVICE = "cuda:0"
# Fourth positional argument handed to mistral_translate.
# NOTE(review): presumably a generation length limit -- confirm in Model_generator.
GENERATION_LIMIT = 200


def parse_args(argv=None):
    """Parse the command-line options of the translation script.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description='Translate a Flores split with a few-shot prompted LLM.')

    parser.add_argument('--src_lang', default="eng_Latn", type=str,
                        help='source language code used in dataset file names (e.g. eng_Latn)')
    parser.add_argument('--tgt_lang', default="ron_Latn", type=str,
                        help='target language code used in output paths (e.g. ron_Latn)')
    parser.add_argument('--src_lang_full', default="English", type=str,
                        help='source language name as written in the prompt (e.g. English)')
    parser.add_argument('--tgt_lang_full', default="Romanian", type=str,
                        help='target language name as written in the prompt (e.g. Romanian)')
    parser.add_argument('--mode', default="dev", type=str,
                        help='dataset split: "dev" reads the .dev file, any other value '
                             'reads dev<mode>/<src_lang>.devtest')
    # The default is deliberately left as the int 0: argparse does not run
    # `type` on non-string defaults, so the default output file is named
    # "<tgt_lang>_0.txt". Changing it to 0.0 would rename that file.
    parser.add_argument('--learning_rate', default=0, type=float,
                        help='value that is only used to tag the output file name')
    return parser.parse_args(argv)


def build_input_path(mode: str, src_lang: str) -> str:
    """Return the path of the source-side Flores file for ``mode``.

    ``"dev"`` selects ``dev/<src_lang>.dev``. Any other value is spliced into
    ``dev<mode>/<src_lang>.devtest`` (so ``"test"`` selects the devtest split).
    """
    if mode == "dev":
        return f"{FLORES_DIR}/{mode}/{src_lang}.dev"
    return f"{FLORES_DIR}/dev{mode}/{src_lang}.devtest"


def build_output_path(src_lang: str, tgt_lang: str, mode: str, learning_rate) -> str:
    """Return the path of the file that receives one translation per line."""
    return (
        f"{EVAL_DIR}/Flores-{src_lang}-{tgt_lang}"
        f"/v4-LLaMA-2-13b-{src_lang}-{tgt_lang}-{mode}"
        f"/{tgt_lang}_{learning_rate}.txt"
    )


def build_prompt(task_instruction: str, demon_prompt: str, src_lang_full: str,
                 tgt_lang_full: str, sentence: str) -> str:
    """Assemble the full LLM input: instruction, demonstrations, then the query."""
    return (
        task_instruction
        + demon_prompt
        + f"\n{src_lang_full}:"
        + sentence
        + f"\n{tgt_lang_full}:"
    )


def main(argv=None):
    """Translate every line of the selected Flores file and append the results.

    Translations are appended to the output file, so re-running with the same
    arguments adds to an existing file rather than replacing it.
    """
    args = parse_args(argv)
    print(args)

    translate_direction = args.src_lang + "-" + args.tgt_lang + "-4shot"
    input_file_path = build_input_path(args.mode, args.src_lang)
    output_file_path = build_output_path(
        args.src_lang, args.tgt_lang, args.mode, args.learning_rate)

    # Read the input before loading the model so that a wrong path fails
    # immediately, and release the handle before the long translation loop.
    with open(input_file_path, 'r', encoding="utf-8") as src_file:
        src_contents = src_file.readlines()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(output_file_path), exist_ok=True)

    LLM_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
    LLM_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_PATH, torch_dtype="auto").to(DEVICE)
    LLM_model.eval()

    task_instruction = f"Translate the sentence from {args.src_lang_full} to {args.tgt_lang_full}:"
    demon_prompt = demon_prompt_generator(translate_direction)

    # Open the output once in append mode; flushing after each sentence keeps
    # partial results on disk if the run is interrupted.
    with open(output_file_path, "a", encoding="utf-8") as f_result:
        for line in tqdm(src_contents):
            llm_input_text = build_prompt(
                task_instruction, demon_prompt,
                args.src_lang_full, args.tgt_lang_full, line.strip())
            result = mistral_translate(
                LLM_model, LLM_tokenizer, llm_input_text, GENERATION_LIMIT, DEVICE)
            f_result.write(result + "\n")
            f_result.flush()


if __name__ == '__main__':
    main()
