import pdb
import re
import json
import transformers
import tqdm
import multiprocessing
import os
import argparse

# Fallback special-token strings (see load_tokenizer, which registers
# defaults when a tokenizer has no pad token configured).
DEFAULT_PAD_TOKEN = "<pad>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "<unk>"
def load_tokenizer(model_name_or_path):
    """Load the slow (non-fast), right-padding tokenizer for a model.

    After loading, the padding token is configured:

    * if the path contains ``'phi'``, the tokenizer's unknown token is
      reused as the padding token;
    * otherwise, if the tokenizer has no padding token, the module-level
      default pad/eos/bos/unk strings are registered as special tokens.

    Args:
        model_name_or_path: Value forwarded to
            ``transformers.AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.
    """
    print(f"+ [Model] Initializing Tokenizer: {model_name_or_path}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path,
        padding_side="right",
        use_fast=False,
    )

    if 'phi' in model_name_or_path:
        tokenizer.pad_token = tokenizer.unk_token
    elif tokenizer.pad_token is None:
        # The branch is entered because the pad token is missing, so the pad
        # token itself must be registered along with the other defaults.
        tokenizer.add_special_tokens({
            "pad_token": DEFAULT_PAD_TOKEN,
            "eos_token": DEFAULT_EOS_TOKEN,
            "bos_token": DEFAULT_BOS_TOKEN,
            "unk_token": DEFAULT_UNK_TOKEN,
        })

    return tokenizer

def process_line(tokenizer, lines, wf_name):
    """Score every output of every record, write the records, print metrics.

    For each entry of ``record['outputs']`` the response is evaluated with
    ``evaluate_expression_para`` and the resulting per-token list is stored
    on the output dict under ``'process_vscores'`` (the input records are
    modified in place).  Each record is then written to ``wf_name`` as one
    JSON line.

    Args:
        tokenizer: Tokenizer forwarded to ``evaluate_expression_para``.
        lines: Iterable of dicts, each holding an ``'outputs'`` list of dicts
            with optional ``'vscores'``, ``'response'`` and ``'label'`` keys.
        wf_name: Path of the JSONL file to write (overwritten).

    Returns:
        ``(accuracy, recall)``; each is ``0`` when there is nothing to
        compute it from.
    """
    acc = []  # (label, prediction) pairs, one per evaluated expression
    recall_count = [0, 0]  # [correctly predicted positives, actual positives]
    hallucination = []  # one 0/1 flag per output

    with open(wf_name, 'w', encoding='utf-8') as wf:
        for line in tqdm.tqdm(lines):
            for output in line['outputs']:
                v_scores = output.get('vscores', [])
                response = output.get('response', "")
                is_true = output.get('label', "")

                evaluation_results, process_v_scores = evaluate_expression_para(
                    response, v_scores, tokenizer, is_true)
                output['process_vscores'] = process_v_scores
                hallucination.append(1 if evaluation_results['hallucination'] else 0)

                # Accuracy pairs and recall counters are gathered in one pass.
                for label, prediction in zip(evaluation_results['label'],
                                             evaluation_results['prediction']):
                    acc.append((label, prediction))
                    if label == 'positive':
                        recall_count[1] += 1
                        if prediction == 'positive':
                            recall_count[0] += 1
            # write(), not writelines(): the payload is a single string.
            wf.write(json.dumps(line, ensure_ascii=False) + '\n')

    accuracy = sum(1 for label, prediction in acc if label == prediction) / len(acc) if acc else 0
    recall = recall_count[0] / recall_count[1] if recall_count[1] > 0 else 0
    # Guarded like the two rates above so an input without outputs does not
    # raise ZeroDivisionError.
    hallucination_rate = sum(hallucination) / len(hallucination) if hallucination else 0

    print(f"accuracy:{accuracy} for total {len(acc)}")
    print(f"recall :{recall} for total {recall_count[1]}")
    print(f"hallucination :{hallucination_rate} for total {len(hallucination)}")
    return accuracy, recall

def locate_sublist(lst, sublst):
    """Return the index of the first occurrence of ``sublst`` inside ``lst``.

    Args:
        lst: Sequence to search in.
        sublst: Contiguous run to look for; it is compared with ``==``
            against slices of ``lst``, so both should be the same sequence
            type.  An empty ``sublst`` matches at index 0.

    Returns:
        The starting index of the first match.

    Raises:
        ValueError: If ``sublst`` does not occur in ``lst``.  (The previous
            ``assert ('not right')`` was always true, so a miss silently
            returned ``None``.)
    """
    width = len(sublst)
    for start in range(len(lst) - width + 1):
        if lst[start:start + width] == sublst:
            return start
    raise ValueError('not right')


def split_string_list(a_list, number='\n'):
    """Split an iterable of strings into joined pieces ending at a delimiter.

    Items are accumulated in order; whenever an item equals ``number`` the
    accumulated items (delimiter included) are joined into one string and
    emitted.  Items left over after the last delimiter form a final piece.

    Args:
        a_list: Iterable of strings (iterating a ``str`` yields characters).
        number: Delimiter item; defaults to the newline character.

    Returns:
        List of joined string pieces, in order.
    """
    pieces = []
    pending = []
    for item in a_list:
        pending.append(item)
        if item != number:
            continue
        # Delimiter reached: close the current piece, delimiter included.
        pieces.append(''.join(pending))
        pending = []

    # Trailing items that were not terminated by a delimiter.
    if pending:
        pieces.append(''.join(pending))
    return pieces
def split_token_list(a_list, number=13):
    """Split an iterable into sub-lists that each end at a delimiter item.

    Items are accumulated in order; whenever an item equals ``number`` the
    accumulated items (delimiter included) are emitted as one sub-list.
    Items left over after the last delimiter form a final sub-list.

    Args:
        a_list: Iterable of items (token ids at the call sites in this file).
        number: Delimiter item; defaults to 13.
            NOTE(review): presumably the newline token id of the tokenizer
            in use -- confirm for other tokenizers.

    Returns:
        List of sub-lists, in order.
    """
    segments = []
    pending = []
    for token in a_list:
        pending.append(token)
        if token != number:
            continue
        # Delimiter reached: close the current segment, delimiter included.
        segments.append(pending)
        pending = []

    # Trailing items that were not terminated by a delimiter.
    if pending:
        segments.append(pending)
    return segments
# The evaluate_expression_* functions below return a results dict (lists of
# labels and predictions plus a hallucination flag) and a per-token score list.


def evaluate_expression_para(response_all, v_score, tokenizer, is_true):
    """Flag the leading paragraphs of a response whose scores show no error.

    The response text is split into newline-terminated paragraphs and its
    token ids into segments ending at the default delimiter of
    ``split_token_list``.  Paragraphs are scanned in order ("first error
    detection") and scanning stops at the first paragraph for which either

    * the score at its first token is (almost) zero, ``abs(s) < 1e-5``, or
    * the score changes by less than -50% relative to the first token's
      score between the paragraph's first and last token.

    Every token of the paragraphs before that point is set to ``1`` in the
    returned list; all remaining tokens stay ``0``.

    Scanning also stops, instead of raising ``IndexError``, when ``v_score``
    has no entry for a paragraph's first token or when the token segments
    run out; the last-token index is clamped to the end of ``v_score`` (as
    ``evaluate_expression_pch`` does).

    Args:
        response_all: Full response text.
        v_score: Sequence of numeric scores indexed by token position.
            NOTE(review): assumed to be aligned with
            ``tokenizer(response_all).input_ids`` -- confirm with the
            producer of ``'vscores'``.
        tokenizer: Callable returning an object with ``input_ids``.
        is_true: Unused; kept for signature compatibility with the callers.

    Returns:
        ``(results, process_v_score)``.  ``results`` has the keys ``'label'``
        and ``'prediction'`` (always empty lists here) and
        ``'hallucination'`` (always ``False`` here); ``process_v_score`` is a
        list of 0/1 values with one entry per token.
    """
    sol_tokens = tokenizer(response_all).input_ids
    process_v_score = [0] * len(sol_tokens)
    response_list = split_string_list(response_all)
    token_list = split_token_list(sol_tokens)
    last_score_idx = len(v_score) - 1

    # Running token offset of the current paragraph (replaces re-summing the
    # lengths of all earlier segments on every iteration).
    para_token_location = 0
    for _, para_token in zip(response_list, token_list):
        if para_token_location > last_score_idx:
            # No score available for this paragraph: nothing to verify.
            break

        first_score = v_score[para_token_location]
        if abs(first_score) < 1e-5:
            break

        end_idx = min(para_token_location + len(para_token) - 1, last_score_idx)
        if (v_score[end_idx] - first_score) / first_score < -0.5:
            break

        para_end = para_token_location + len(para_token)
        process_v_score[para_token_location:para_end] = [1] * len(para_token)
        para_token_location = para_end

    return {'label': [], 'prediction': [], 'hallucination': False}, process_v_score





def evaluate_expression_pch(response_all, v_score, tokenizer, is_true):
    """Compare score-based predictions with arithmetic checks of ``<<a=b>>``.

    The response is split into newline-terminated paragraphs.  For each
    paragraph:

    * Without a ``<<...>>`` expression, its tokens are set to ``1`` in the
      returned list, provided no error has been detected so far.
    * With an expression, a prediction is derived from ``v_score``:
      ``'negative'`` if the score at the expression's first token is (almost)
      zero or the relative change up to the token after the expression is
      below -90%, otherwise ``'positive'``.  A negative prediction marks an
      error as detected; a positive one sets the paragraph's tokens to ``1``
      if no error was detected before.  The expression is then evaluated: if
      the left-hand side equals the stated right-hand side (within 1e-3) the
      label is ``'positive'``, otherwise ``'negative'`` and the response is
      flagged as a hallucination.  Expressions that cannot be parsed or
      evaluated contribute no label/prediction pair.

    An expression whose first token has no entry in ``v_score`` is skipped
    entirely (previously this dropped into an interactive debugger and left
    ``prediction`` unbound or stale).

    Args:
        response_all: Full response text.
        v_score: Sequence of numeric scores indexed by token position.
            NOTE(review): assumed to be aligned with
            ``tokenizer(response_all).input_ids`` -- confirm.
        tokenizer: Callable returning an object with ``input_ids``.
        is_true: Unused; kept for signature compatibility.

    Returns:
        ``(results, process_v_score)`` where ``results`` has the keys
        ``'label'``, ``'prediction'`` (parallel lists) and
        ``'hallucination'`` (bool), and ``process_v_score`` is a list of 0/1
        values with one entry per token.
    """
    labels = []
    predictions = []
    sol_tokens = tokenizer(response_all).input_ids
    process_v_score = [0] * len(sol_tokens)
    hallucination = False
    error_detection = False
    response_list = split_string_list(response_all)
    token_list = split_token_list(sol_tokens)
    last_score_idx = len(v_score) - 1

    previous_len = 0          # characters consumed by earlier paragraphs
    para_token_location = 0   # tokens consumed by earlier paragraphs
    for string, para_token in zip(response_list, token_list):
        char_offset = previous_len
        para_start = para_token_location
        para_end = para_start + len(para_token)
        # Advance the running offsets up front so every `continue` below
        # leaves them consistent for the next paragraph.
        previous_len += len(string)
        para_token_location = para_end

        match = re.search(r'<<(.+?)>>', string)
        if not match:
            if not error_detection:
                process_v_score[para_start:para_end] = [1] * len(para_token)
            continue

        expression = match.group(1)
        # Locate the expression in token space by tokenizing the text prefix.
        start_token = tokenizer(response_all[:char_offset + match.start()]).input_ids
        if sol_tokens[:len(start_token)] != start_token:
            # The prefix tokenization differs from the full one; drop its
            # last token.
            start_token = start_token[:-1]
        seg_token_location = len(start_token)
        seq_token = tokenizer(response_all[:char_offset + match.end()]).input_ids[len(start_token):]

        if seg_token_location > last_score_idx:
            # No score for this expression: it can be neither predicted nor
            # counted.
            continue

        first_score = v_score[seg_token_location]
        end_idx = min(seg_token_location + len(seq_token), last_score_idx)
        if abs(first_score) < 1e-5:
            prediction = 'negative'
            error_detection = True
        elif (v_score[end_idx] - first_score) / first_score < -0.9:
            prediction = 'negative'
            error_detection = True
        else:
            prediction = 'positive'
            if not error_detection:
                process_v_score[para_start:para_end] = [1] * len(para_token)

        try:
            before_equal, after_equal = expression.split('=')
            # SECURITY: eval() executes text taken from the response.  Only
            # run this on trusted data, or replace it with a restricted
            # arithmetic parser.
            computed_value = float(eval(before_equal.strip()))
            actual_value = float(after_equal.strip().replace(",", ""))
        except Exception:
            # Best effort: eval of arbitrary text can raise anything;
            # unparsable expressions are left out of the statistics.
            continue

        if abs(computed_value - actual_value) <= 1e-3:
            label = 'positive'
        else:
            label = 'negative'
            hallucination = True

        labels.append(label)
        predictions.append(prediction)

    return {'label': labels, 'prediction': predictions, 'hallucination': hallucination}, process_v_score





def process_chunk(tokenizer, chunk, wf_path):
    """Score one chunk of records, write them out, and return raw counts.

    Each output of each record is evaluated with
    ``evaluate_expression_para``; the per-token result is stored on the
    output dict under ``'process_vscores'`` and the record is written to
    ``wf_path`` as one JSON line.  Counts rather than rates are returned so
    that results of several chunks can be summed.

    Args:
        tokenizer: Tokenizer forwarded to ``evaluate_expression_para``.
        chunk: Iterable of dicts, each holding an ``'outputs'`` list.
        wf_path: Path of the JSONL file to write (overwritten).

    Returns:
        Dict with the keys ``accuracy_sum``, ``total``, ``recall_correct``,
        ``recall_total``, ``hallucination_sum`` and ``hallucination_total``.
    """
    matched = 0          # pairs where label == prediction
    pair_total = 0       # all (label, prediction) pairs
    positives_hit = 0    # positive labels predicted positive
    positives_seen = 0   # positive labels
    flagged = 0          # outputs flagged as hallucination
    outputs_seen = 0     # all outputs

    with open(wf_path, 'w', encoding='utf-8') as sink:
        for record in tqdm.tqdm(chunk):
            for output in record['outputs']:
                outcome, token_flags = evaluate_expression_para(
                    output.get('response', ""),
                    output.get('vscores', []),
                    tokenizer,
                    output.get('label', ""),
                )
                output['process_vscores'] = token_flags

                outputs_seen += 1
                if outcome['hallucination']:
                    flagged += 1

                for label, prediction in zip(outcome['label'], outcome['prediction']):
                    pair_total += 1
                    if label == prediction:
                        matched += 1
                    if label == 'positive':
                        positives_seen += 1
                        if prediction == 'positive':
                            positives_hit += 1
            sink.write(json.dumps(record, ensure_ascii=False) + '\n')

    return {
        "accuracy_sum": matched,
        "total": pair_total,
        "recall_correct": positives_hit,
        "recall_total": positives_seen,
        "hallucination_sum": flagged,
        "hallucination_total": outputs_seen,
    }



def parallel_process_line(tokenizer, lines, wf_path, num_processes=32):
    """Score ``lines`` in parallel worker processes and print overall metrics.

    The records are cut into contiguous chunks, each chunk is handled by
    ``process_chunk`` in a worker process that writes a temporary file
    ``multirun/{wf_path}_temp_{i}.json``, and the temporary files are then
    concatenated in order into ``multirun2/{wf_path}.json`` and deleted.

    Args:
        tokenizer: Tokenizer passed to every worker.
        lines: List of record dicts (must support ``len`` and slicing).
        wf_path: Name fragment used to build the temporary and final paths.
        num_processes: Number of worker processes; ``None`` means
            ``multiprocessing.cpu_count()``.

    Returns:
        None.  Accuracy, recall and hallucination rate are printed.
    """
    if num_processes is None:
        num_processes = multiprocessing.cpu_count()

    # Ceiling division, never below 1.  Flooring gave a chunk size of 0 (and
    # a ValueError from range()) whenever there were fewer records than
    # processes.
    chunk_size = max(1, -(-len(lines) // num_processes))
    chunks = [lines[i:i + chunk_size] for i in range(0, len(lines), chunk_size)]

    # Generate a unique temporary file path for each chunk
    temp_files = [f"multirun/{wf_path}_temp_{i}.json" for i in range(len(chunks))]
    final_path = f"multirun2/{wf_path}.json"

    # The output directories are not guaranteed to exist yet.
    for directory in {os.path.dirname(path) for path in temp_files + [final_path]}:
        os.makedirs(directory, exist_ok=True)

    # Never start more workers than there are chunks (but at least one).
    worker_count = max(1, min(num_processes, len(chunks)))
    with multiprocessing.Pool(processes=worker_count) as pool:
        results = pool.starmap(
            process_chunk,
            [(tokenizer, chunk, temp_file) for chunk, temp_file in zip(chunks, temp_files)],
        )

    # Combine results from temporary files into the final output file
    with open(final_path, 'w', encoding='utf-8') as wf:
        for temp_file in temp_files:
            with open(temp_file, 'r', encoding='utf-8') as tf:
                wf.write(tf.read())
            os.remove(temp_file)  # Clean up temporary file

    # Aggregate metrics from all chunks
    total_acc = sum(result['accuracy_sum'] for result in results)
    total = sum(result['total'] for result in results)
    total_recall_correct = sum(result['recall_correct'] for result in results)
    total_recall = sum(result['recall_total'] for result in results)
    total_hallucination = sum(result['hallucination_sum'] for result in results)
    total_hallucination_counts = sum(result['hallucination_total'] for result in results)

    # Calculate overall metrics
    overall_accuracy = total_acc / total if total else 0
    overall_recall = total_recall_correct / total_recall if total_recall else 0
    overall_hallucination = total_hallucination / total_hallucination_counts if total_hallucination_counts else 0

    print(f"Overall accuracy: {overall_accuracy}")
    print(f"Overall recall: {overall_recall}")
    print(f"Overall hallucination: {overall_hallucination}")




def main():
    """Command-line entry point.

    Reads a JSONL file, drops every output whose ``'tokens'`` list has more
    than 2048 entries, loads the tokenizer for the given model and scores
    the remaining outputs with ``parallel_process_line``.
    """
    parser = argparse.ArgumentParser(description="Process JSONL file and perform deduplication.")
    parser.add_argument("file_path", type=str, help="Path to the JSONL file")
    parser.add_argument("model_path", type=str, help="Name or path of the model whose tokenizer is loaded")
    args = parser.parse_args()

    # Context manager closes the file; blank lines are skipped instead of
    # crashing json.loads.
    with open(args.file_path, 'r', encoding='utf-8') as rf:
        examples = [json.loads(row) for row in rf if row.strip()]

    for ex in examples:
        ex['outputs'] = [output for output in ex['outputs'] if len(output['tokens']) <= 2048]

    tokenizer = load_tokenizer(args.model_path)
    parallel_process_line(tokenizer, examples, "test.json")

    
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
