
import json
import os.path
import pdb
import sys

from sacrebleu.metrics import BLEU, CHRF, TER


def _strip_trailing_markers(text, markers):
    """Strip whitespace and any trailing occurrence of *markers* from *text*.

    Markers are removed as whole substrings (never as a set of characters),
    repeatedly, until the text no longer ends with any of them.
    """
    text = text.strip()
    changed = True
    while changed:
        changed = False
        for marker in markers:
            if marker and text.endswith(marker):
                # The text ends with the marker, so its last occurrence is
                # exactly that suffix; rpartition drops it without needing
                # to compute the marker's length.
                text = text.rpartition(marker)[0].rstrip()
                changed = True
    return text


def _extract_translation(model_answer, target_marker, source_marker):
    """Return the translation that follows the last *target_marker*.

    The previous implementation used ``str.strip('German: </s>  \\n')``, which
    treats its argument as a *set of characters* and therefore also removed
    leading/trailing letters of the translation itself (for example
    ``"Es ist ein Mann"`` became ``"t ein Ma"``).  Here only surrounding
    whitespace, a trailing ``</s>`` token and a dangling trailing
    *source_marker* are removed.
    """
    candidate = model_answer.split(target_marker)[-1]
    return _strip_trailing_markers(candidate, ('</s>', source_marker))


def _score_translations(result_dir, ref_file_path, target_marker,
                        source_marker, learning_epochs, lr, apc):
    """Score one ensemble output file with BLEU and log the result.

    Reads ``mt.ensemble_lr{lr}_anchor_point_count{apc}_learning_epochs_nums{learning_epochs}.jsonl``
    from *result_dir*, takes the ``model_answer`` field of every non-blank
    line as a hypothesis, scores the hypotheses against the lines of
    *ref_file_path*, and appends one JSON record to ``BLEU_score.jsonl`` in
    *result_dir*.

    Returns the score object produced by ``BLEU().corpus_score``.
    Raises whatever ``open``, ``json.loads`` or sacreBLEU raise (for example
    ``FileNotFoundError`` when the result file does not exist).
    """
    sys_file_path = os.path.join(
        result_dir,
        f'mt.ensemble_lr{lr}_anchor_point_count{apc}_learning_epochs_nums{learning_epochs}.jsonl',
    )

    # Context managers close both input files even when parsing fails.
    with open(sys_file_path, 'r', encoding='utf-8') as sys_file:
        hypotheses = [
            _extract_translation(json.loads(line)['model_answer'],
                                 target_marker, source_marker)
            for line in sys_file
            if line.strip()  # tolerate blank lines in the JSONL file
        ]
    with open(ref_file_path, 'r', encoding='utf-8') as ref_file:
        # corpus_score expects a list of reference streams; there is one.
        references = [[line.strip() for line in ref_file]]

    bleuscore = BLEU().corpus_score(hypotheses, references)

    record = {
        'sys_file_path': sys_file_path,
        'learning_rate': lr,
        'anchor_point_count': apc,
        'learning_epochs_nums': learning_epochs,
        # Read the numeric score directly instead of parsing str(bleuscore);
        # two decimals matches the precision of the printed form.
        'bleu': round(bleuscore.score, 2),
    }
    with open(os.path.join(result_dir, 'BLEU_score.jsonl'), 'a', encoding='utf-8') as result_file:
        result_file.write(json.dumps(record, ensure_ascii=False) + '\n')
    return bleuscore


def bleu_de2en(len, lr, apc):
    """Compute BLEU for the German->English ensemble output.

    Args:
        len: value used for ``learning_epochs_nums`` in the result file name
            (the parameter name shadows the builtin but is kept for callers).
        lr: value used for ``lr`` in the result file name.
        apc: value used for ``anchor_point_count`` in the result file name.

    Returns:
        The sacreBLEU score object; a record is also appended to
        ``BLEU_score.jsonl`` in the de-en results directory.
    """
    return _score_translations(
        result_dir='/home/username/Experiments/LLM_ensemble/Eval/1219-mt/de-en-mistral_epochs5',
        ref_file_path='/home/username/Experiments/LLM_ensemble/Datasets/Flores/sampled_100/eng_Latn.sampled.devtest',
        target_marker='English:',
        source_marker='German:',
        learning_epochs=len,
        lr=lr,
        apc=apc,
    )


def bleu_en2de(len, lr, apc):
    """Compute BLEU for the English->German ensemble output.

    Args:
        len: value used for ``learning_epochs_nums`` in the result file name
            (the parameter name shadows the builtin but is kept for callers).
        lr: value used for ``lr`` in the result file name.
        apc: value used for ``anchor_point_count`` in the result file name.

    Returns:
        The sacreBLEU score object; a record is also appended to
        ``BLEU_score.jsonl`` in the en-de results directory.
    """
    return _score_translations(
        result_dir='/home/username/Experiments/LLM_ensemble/Eval/1219-mt/en-de-mistral_epochs5',
        ref_file_path='/home/username/Experiments/LLM_ensemble/Datasets/Flores/sampled_100/deu_Latn.sampled.devtest',
        target_marker='German:',
        source_marker='English:',
        learning_epochs=len,
        lr=lr,
        apc=apc,
    )


# Epoch count that appears in the names of the result files scored below.
# (Deliberately not called `len`: rebinding that name at module level would
# break every later use of the builtin in this module.)
learning_epochs = 5
apclist = [500, 1000, 2000, 3000, 4000, 5000, 8000, 16000]
# apclist = [4000]
learning_rates = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# Score every (anchor point count, learning rate) combination, first for
# German->English and then for English->German.
for score_fn in (bleu_de2en, bleu_en2de):
    for apc in apclist:
        for lr in learning_rates:
            try:
                bleu_score = score_fn(learning_epochs, lr, apc)
            except Exception as exc:
                # Best-effort sweep: a configuration that cannot be scored is
                # skipped, but the reason is reported instead of being
                # silently discarded.  Catching Exception (not a bare except)
                # lets Ctrl-C still abort the run.
                print(
                    f'skipped {score_fn.__name__}(lr={lr}, apc={apc}): {exc!r}',
                    file=sys.stderr,
                )
                continue
            print(bleu_score)

# import json
# import os.path
# from sacrebleu.metrics import BLEU, CHRF, TER
#
# ref_file_path = '/home/username/Experiments/LLM_ensemble/Datasets/Flores/sampled_100/eng_Latn.sampled.devtest'
# # ref_file_path = '/home/username/Experiments/LLM_ensemble/Datasets/Flores/sampled_100/deu_Latn.sampled.devtest'
# # sys_file_path = '/home/username/Experiments/LLM_ensemble/Eval/1217-mt/en-de-llama-base/mt.ensemble_lr0.0_anchor_point_count1000_learning_epochs_nums1.jsonl'
# # sys_file_path = '/home/username/Experiments/LLM_ensemble/mt.ensemble_lr0.0_anchor_point_count1000_learning_epochs_nums1.jsonl'
# sys_file_path = '/home/username/Experiments/LLM_ensemble/Eval/1218-mt/de-en-llama_epochs5/mt.ensemble_lr1.0_anchor_point_count4000_learning_epochs_nums5.jsonl'
# sys_file = open(sys_file_path, 'r', encoding='utf-8')
# ref_file = open(ref_file_path, 'r', encoding='utf-8')
# sys_lines = sys_file.readlines()
# ref_lines = ref_file.readlines()
# sys_list = [json.loads(line)['model_answer'].split(f'English:')[-1].strip(f'German: </s>  \n') for line in sys_lines]
# # sys_list = [json.loads(line)['mt_result'] for line in sys_lines]
# # sys_list = [json.loads(line)['model_answer'].split(f'German:')[-1].strip(f'English: </s>  \n') for line in sys_lines]
# ref_list = [[line.strip() for line in ref_lines]]
# bleu = BLEU()
# bleuscore = bleu.corpus_score(sys_list, ref_list)
# bleu.get_signature()
# print(bleuscore)
