import os
import pdb
import sys
import time
import json
import logging
from tqdm import tqdm
import argparse

# Project root that is put on sys.path so the ``src`` package imported below
# can be resolved.  The LLM_ENSEMBLE_ROOT environment variable, when set,
# overrides the original hard-coded location; when it is unset the behaviour
# is unchanged.
root_dir = os.environ.get('LLM_ENSEMBLE_ROOT', '/data/home/username/Experiments/LLM_ensemble')
sys.path.insert(0, root_dir)

from src.StopAtSpecificTokenCriteria import StopAtSpecificTokenCriteria
from transformers import LogitsProcessorList, StoppingCriteriaList
from src.logits_processor.early_stop_logits_processor import EarlyStopLogitsProcessor
from src.post_processing.answer_extract import answer_extract
from src.instruction_generate import demon_prompt_generate, task_instruction_generate
from src.model_load import load_model


def _build_arg_parser():
    """Build the command-line parser for the generation run.

    ``--config`` is mandatory: it names the JSON file that is opened and
    parsed by :func:`main`.  Every other option keeps its original name,
    short flag, type and default.
    """
    parser = argparse.ArgumentParser(description='Process some files.')

    # Without a config path there is nothing to run, so let argparse report
    # the omission instead of failing later inside open().
    parser.add_argument('--config', required=True, help='the name of the file to process')
    parser.add_argument('--learning_rate', '-lr', default=0.0, type=float, required=False, help="learning_rate")
    parser.add_argument('--anchor_point_count', '-apc', default=32000, type=int, required=False,
                        help='anchor_point_count')
    parser.add_argument('--learning_epochs_nums', '-len', default=5, type=int, required=False,
                        help='learning_epochs_nums')
    parser.add_argument('--result_save_dir', '-rsd', default="./", type=str, required=False, help='result_save_dir')
    parser.add_argument('--run_mode', '-rm', default="dev", type=str, required=False, help='run_mode')

    # The options below (except --device0) are parsed so existing command
    # lines keep working, but main() does not read their values.
    parser.add_argument('--logits_processor_mode', '-lpm', default="based_on_probility_transfer_logits_processor",
                        type=str,
                        required=False,
                        help='logits_processor_mode')
    parser.add_argument('--device_compute', '-dp', default="cuda:1", type=str, required=False,
                        help='device_compute')
    parser.add_argument('--device0', '-d0', default="cuda:0", type=str, required=False,
                        help='device0')
    parser.add_argument('--device1', '-d1', default="cuda:0", type=str, required=False,
                        help='device1')
    parser.add_argument('--device2', '-d2', default="cuda:0", type=str, required=False,
                        help='device2')
    parser.add_argument('--device3', '-d3', default="cuda:0", type=str, required=False,
                        help='device3')

    parser.add_argument('--main_temperature', '-mt', default=100, type=float, required=False,
                        help='main_temperature')
    parser.add_argument('--assist_temperature', '-at', default=100, type=float, required=False,
                        help='assist_temperature')
    parser.add_argument('--min_prob', default=0.8, type=float, required=False,
                        help='min_prob')
    parser.add_argument('--max_prob', default=0.9, type=float, required=False,
                        help='max_prob')
    return parser


def _count_completed_records(result_file_path):
    """Return how many lines the result file already holds (0 if it is absent).

    The file is read as UTF-8 because that is the encoding main() writes it
    with.  Only a missing file is treated as "nothing done yet"; any other
    error propagates so that a resume never silently restarts from zero and
    appends duplicate records.
    """
    try:
        with open(result_file_path, 'r', encoding='utf-8') as result_file:
            return sum(1 for _ in result_file)
    except FileNotFoundError:
        return 0


def main():
    """Run greedy generation with the main model over a JSONL input file.

    Steps:
      1. Parse the command line and load the JSON config named by ``--config``.
      2. Configure logging to a ``.process.log`` file in ``--result_save_dir``.
      3. Load the model, tokenizer and streamer through ``load_model``.
      4. Resume after the records already present in the ``.jsonl`` result
         file, and stop at ``run_parameter.end_index``.
      5. For each pending record: build the prompt, generate, extract the
         answer with ``answer_extract`` and append one JSON line to the
         result file.
      6. Print the elapsed wall-clock time.

    Raises:
        SystemExit: from argparse when the command line is invalid.
        KeyError: when a required config entry or record field is missing.
        OSError: when the config, input or result file cannot be opened.
    """
    start_time = time.time()

    # Parse the command-line arguments.
    args = _build_arg_parser().parse_args()

    # Load the run configuration from the file named on the command line.
    with open(args.config, 'r', encoding='utf-8') as f:
        config_json = json.load(f)

    file_paths = config_json["file_path"]
    prompt_template = config_json["prompt_template"]
    run_parameter = config_json["run_parameter"]
    result_process_parameter = config_json["result_process_parameter"]

    main_model_path = config_json["model_path"]["main_model_path"]
    dev_file_path = file_paths["dev_file_path"]
    test_file_path = file_paths["test_file_path"]
    demon_file_path = file_paths["demon_file_path"]

    instruction = prompt_template["instruction"]
    instruction_parameter = prompt_template["instruction_parameter"]
    main_model_system_template = prompt_template["main_model_system_template"]
    demon_parameter = prompt_template["demon_parameter"]

    max_new_tokens = run_parameter["max_new_tokens"]
    end_index = run_parameter["end_index"]

    # Optional entry: absent means "no early-stop strings".
    early_stop_string_list = result_process_parameter.get("early_stop_string_list")
    split_key_before_list = result_process_parameter["split_key_before"]
    split_key_behind_list = result_process_parameter["split_key_behind"]

    result_save_dir = args.result_save_dir
    os.makedirs(result_save_dir, exist_ok=True)

    anchor_point_count = args.anchor_point_count
    learning_rate = args.learning_rate
    learning_epochs_nums = args.learning_epochs_nums
    run_mode = args.run_mode
    device0 = args.device0

    input_file_path = dev_file_path if run_mode == "dev" else test_file_path

    # Shared stem of the log file and the result file for this run.
    run_tag = (f'ensemble_lr{learning_rate}_anchor_point_count{anchor_point_count}'
               f'_learning_epochs_nums{learning_epochs_nums}')

    logging.basicConfig(filename=os.path.join(result_save_dir, f'{run_tag}.process.log'),
                        level=logging.DEBUG)
    logging.info(f'\n【config_json:】{config_json}')
    logging.info(f'\n【result_save_dir:】{result_save_dir}')
    logging.info(f'\n【anchor_point_count:】{anchor_point_count}')
    logging.info(f'\n【learning_rate:】{learning_rate}')
    logging.info(f'\n【learning_epochs_nums:】{learning_epochs_nums}')

    main_model, main_model_tokenizer, main_model_streamer = load_model(main_model_path, "auto")

    # Tokenise the early-stop strings once; they do not depend on the record.
    # The leading token id is dropped ([1:]) -- presumably a prefix token the
    # tokenizer emits in front of the text; confirm for the tokenizer in use.
    early_stop_ids_list = []
    if early_stop_string_list is not None:
        for early_stop_string in early_stop_string_list:
            early_stop_ids = main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                  add_special_tokens=False).input_ids.tolist()[0][1:]
            early_stop_ids_list.append(early_stop_ids)

    # =============================================================================================================
    result_file_path = os.path.join(result_save_dir, f'{run_tag}.jsonl')

    # Resume support: one result line is written per processed input line, so
    # the number of existing lines is the index of the first pending record.
    start_index = _count_completed_records(result_file_path)

    # Few-shot demonstrations are best-effort: fall back to none on failure,
    # but record why instead of hiding the error.
    try:
        demon_instruction, demon_count = demon_prompt_generate(demon_file_path, demon_parameter)
    except Exception:
        logging.warning('demon_prompt_generate failed; continuing without demonstrations', exc_info=True)
        demon_instruction = ""
        demon_count = 0

    with open(input_file_path, 'r', encoding='utf-8') as input_file:
        contents = input_file.readlines()

    # Apply the configured upper bound.  A negative end_index is treated as
    # "no limit" rather than as a Python negative slice index, so it keeps
    # processing the whole file.
    end = len(contents) if end_index < 0 else min(end_index, len(contents))

    for line in tqdm(contents[start_index:end]):
        record = json.loads(line)

        task_instruction = task_instruction_generate(record, instruction_parameter)
        final_input_prompt = instruction + demon_instruction + task_instruction
        main_model_input = main_model_system_template.format(final_input_prompt)

        # Copy the fields listed in demon_parameter['key']; 'answer' and
        # 'question' are read from this copy below, so both must be listed.
        information_dict = {key: record[key] for key in demon_parameter['key']}

        # Fresh criteria objects per record, in case they keep per-generation
        # state (not visible from this file).
        stopping_criteria = StoppingCriteriaList(
            [StopAtSpecificTokenCriteria(token_ids_list=list(ids)) for ids in early_stop_ids_list]
        )

        # NOTE(review): the model is loaded with "auto" but the inputs go to
        # --device0; confirm that this matches the model's input device.
        main_model_input_ids = main_model_tokenizer(main_model_input, return_tensors="pt",
                                                    add_special_tokens=False).input_ids.to(device0)
        generation_kwargs = {
            "input_ids": main_model_input_ids,
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "num_beams": 1,
            "eos_token_id": main_model_tokenizer.eos_token_id,
            "bos_token_id": main_model_tokenizer.bos_token_id
        }

        generate_ids = main_model.generate(**generation_kwargs, stopping_criteria=stopping_criteria,
                                           streamer=main_model_streamer)

        # Special tokens are kept in the decoded text handed to answer_extract.
        text = main_model_tokenizer.decode(generate_ids[0], skip_special_tokens=False)

        model_answer, prediction = answer_extract(text, demon_count, split_key_before_list,
                                                  split_key_behind_list)

        model_answer_dict = {'answer': information_dict['answer'],
                             'prediction': prediction.strip(), 'main_model_input': main_model_input, 'all': text,
                             'model_answer': model_answer,
                             'question': information_dict['question']}

        # Append and close after every record so an interrupted run keeps all
        # finished results and can resume from the line count.
        with open(result_file_path, 'a+', encoding='utf-8') as result_file:
            result_file.write(json.dumps(model_answer_dict, ensure_ascii=False) + '\n')

    # Report the elapsed wall-clock time.
    time_elapsed = time.time() - start_time
    minutes, seconds = divmod(int(time_elapsed), 60)
    print('Time taken: {} min {} sec'.format(minutes, seconds))


# Run the generation pipeline only when this file is executed as a script,
# not when it is imported.
if "__main__" == __name__:
    main()
