from tqdm import tqdm
import time
import random
import re
import ast
import os
from dotenv import load_dotenv

from concurrent.futures import ThreadPoolExecutor
from openai import OpenAI
from openai._exceptions import RateLimitError, BadRequestError
from httpx import Timeout

from mix_eval.prompts.judge_prompts import (
    gpt_judge_for_openended_turn1,
    gpt_judge_for_openended_turn2,
    gpt_judge_for_openended_turn3,
    gpt_judge_for_openended_turn1_withref
    )

########################ChatGPT########################
class ChatGPTJudgeOpenend:
    """Score three-turn open-ended conversations with an OpenAI chat model.

    For every task the judge model is queried once per turn; a numeric rating
    written as ``[[score]]`` (or ``[score]`` as a fallback) is parsed out of
    each reply and stored on the task.

    ``args`` must expose ``judge`` (passed as the ``model`` of the chat
    completion request) and ``api_parallel_num`` (thread count used by
    :meth:`annotate_parallel`). The API key is read from the ``k_oai``
    environment variable after ``load_dotenv()`` has run.
    """

    # Upper bound, in seconds, for one rate-limit back-off sleep. Without a
    # cap the exponentially growing delay reaches days after a few dozen
    # consecutive rate-limit errors.
    MAX_RETRY_DELAY_SECOND = 60
    # Number of attempts per turn to obtain a parsable rating within [1, 10].
    MAX_RETRY_NUM_TURN = 25

    # Raw strings keep the regex escapes out of Python's own escape handling.
    _SCORE_PATTERN = re.compile(r"\[\[(\d+\.?\d*)\]\]")
    _SCORE_PATTERN_BACKUP = re.compile(r"\[(\d+\.?\d*)\]")

    def __init__(self, args):
        self.JUDGE = args.judge
        self.FIX_INTERVAL_SECOND = 0
        self.MAX_RETRY_NUM = 99
        self.MAX_NEW_TOKENS = 999

        self.args = args

        load_dotenv()
        self.client = OpenAI(
            api_key=os.getenv('k_oai'),
            timeout=Timeout(timeout=180.0, connect=5.0)
        )

    def _GPT_decode(self, inputs):
        """Send one chat completion request and return the raw completion.

        ``inputs`` is forwarded unchanged as the ``messages`` argument. After
        the request the method sleeps ``FIX_INTERVAL_SECOND`` seconds.
        """
        completion = self.client.chat.completions.create(
            model=self.JUDGE,
            response_format={"type": 'text'},
            max_tokens=self.MAX_NEW_TOKENS,
            messages=inputs,
        )
        time.sleep(self.FIX_INTERVAL_SECOND)
        return completion

    def GPT_decode(self, inputs):
        """Call :meth:`_GPT_decode`, retrying up to ``MAX_RETRY_NUM`` times.

        Returns:
            The completion object on success, the string ``'Blocked'`` after
            ten ``BadRequestError`` failures, or the string ``'Error'`` once
            all retries are exhausted.
        """
        exponential_base = 2
        delay = 1
        blocked = 0
        for i in range(self.MAX_RETRY_NUM):
            try:
                return self._GPT_decode(inputs)
            except RateLimitError as e:
                # Exponential back-off with jitter, capped so a long run of
                # rate-limit errors cannot stall the worker indefinitely.
                delay = min(
                    delay * exponential_base * (1 + random.random()),
                    self.MAX_RETRY_DELAY_SECOND,
                )
                print(f"RateLimitError, retrying after {round(delay, 2)} seconds, {i+1}-th retry...")
                print(e)
                time.sleep(delay)
            except BadRequestError as e:
                blocked += 1
                if blocked >= 10:
                    print("Blocked too many times, skipping...")
                    return 'Blocked'
                print("Input is blocked, retrying...")
                print(e)
                time.sleep(1)
            except Exception as e:
                # Best-effort boundary: any other failure is logged and retried.
                print("Error in GPT_decode, retrying...")
                print(e)
                time.sleep(1)
        print(f"Failed after {self.MAX_RETRY_NUM} retries.")
        return 'Error'

    def get_score_from_judge(self, judge_response):
        """
        Get the score from the judge response.

        The first ``[[number]]`` occurrence wins; if there is none, the first
        ``[number]`` occurrence is used. Returns the number as a float, or
        ``-1.0`` when no rating is present.
        """
        match = (
            self._SCORE_PATTERN.search(judge_response)
            or self._SCORE_PATTERN_BACKUP.search(judge_response)
        )
        if match is None:
            return -1.0
        # float() accepts everything the pattern can capture, including
        # zero-padded integers such as "07" that ast.literal_eval rejects.
        return float(match.group(1))

    def annotate_p_turn(self, inputs):
        """Judge a single turn and return the judge's reply text.

        Returns ``None`` when :meth:`GPT_decode` reports ``'Error'`` and the
        fixed reply ``'[[1.0]]'`` when it reports ``'Blocked'``.
        """
        completion = self.GPT_decode(inputs)
        if completion == 'Error':
            print(f"Error in GPT_decode, the turn {inputs} will be retried later...")
            return None
        if completion == 'Blocked':
            print(f"{inputs}: \n\nBlocked, the entry treated as bad entry.")
            return '[[1.0]]'
        return completion.choices[0].message.content

    def annotate_p(self, task):
        """Judge all three turns of ``task`` and attach the results to it.

        ``task`` must provide ``'turns'`` and ``'response'`` sequences of
        length 3. On success the same dict is returned with
        ``'judge_response'`` and ``'judge_score'`` lists added (one entry per
        turn). Returns ``None`` for a task with the wrong number of turns.

        A turn whose rating cannot be obtained within ``MAX_RETRY_NUM_TURN``
        attempts receives a random integer score in [1, 10].
        """
        turns = task['turns']
        responses = task['response']

        if not len(turns) == len(responses) == 3:
            print("Invalid task, the turns and responses should be of length 3.")
            return None

        turn_inputs = [
            gpt_judge_for_openended_turn1(turns[0], responses[0]),
            gpt_judge_for_openended_turn2(turns[0], responses[0], turns[1], responses[1]),
            gpt_judge_for_openended_turn3(turns[0], responses[0], turns[1], responses[1], turns[2], responses[2]),
        ]

        judge_response = []
        judge_score = []
        for turn_id, inputs in enumerate(turn_inputs):
            for _ in range(self.MAX_RETRY_NUM_TURN):
                annotation = self.annotate_p_turn(inputs)
                # Parse once; None marks "no reply at all".
                score = None if annotation is None else self.get_score_from_judge(annotation)
                if score is not None and 1 <= score <= 10:
                    judge_response.append(annotation)
                    judge_score.append(score)
                    break
                print(f"Invalid judgement, retrying the turn {turn_id} judgement...")
                print(f"{annotation}")
            else:
                # Every attempt failed: randomly assign a score.
                score = random.randint(1, 10)
                print(f"Max retry number {self.MAX_RETRY_NUM_TURN} reached, while some tasks are still not judged. Randomly assigned score. Turn {turn_id} Rating: [[{score}]]")
                judge_response.append(f"Randomly assigned score. Rating: [[{score}]]")
                judge_score.append(score)

        # Each of the three turns appended exactly one response and one score.
        task['judge_response'] = judge_response
        task['judge_score'] = judge_score
        return task

    def annotate_parallel(self, tasks):
        """Judge ``tasks`` concurrently and return the results in input order.

        Raises:
            ValueError: if :meth:`annotate_p` returned ``None`` for any task.
        """
        print(f"Judging in parallel, in total {self.args.api_parallel_num} threads.")
        with ThreadPoolExecutor(self.args.api_parallel_num) as executor:
            results = list(
                tqdm(executor.map(self.annotate_p, tasks), total=len(tasks))
            )
        if any(entry is None for entry in results):
            raise ValueError("Some tasks are not judged correctly due to errors in annotate_p. Please inspect and retry.")
        return results
    
    
class ChatGPTJudgeOpenendwithRef:
    """Compare a candidate response against a reference with an OpenAI judge.

    Each task pair is judged ``NUM_ROUNDS`` times; odd-numbered rounds pass
    the two responses to the prompt builder in swapped order. The judge's
    verdict label, written as ``[[label]]`` (or ``[label]`` as a fallback)
    and made of the characters ``A``, ``B``, ``<``, ``>`` and ``=``, is parsed
    from the reply.

    ``args`` must expose ``judge`` (passed as the ``model`` of the chat
    completion request) and ``api_parallel_num`` (thread count used by
    :meth:`annotate_parallel`). The API key is read from the ``k_oai``
    environment variable after ``load_dotenv()`` has run.
    """

    # Upper bound, in seconds, for one rate-limit back-off sleep. Without a
    # cap the exponentially growing delay reaches days after a few dozen
    # consecutive rate-limit errors.
    MAX_RETRY_DELAY_SECOND = 60
    # Number of attempts per round to obtain a parsable verdict label.
    MAX_RETRY_NUM_TURN = 25

    # Raw strings keep the regex escapes out of Python's own escape handling.
    _VERDICT_PATTERN = re.compile(r"\[\[([AB<>=]+)\]\]")
    _VERDICT_PATTERN_BACKUP = re.compile(r"\[([AB<>=]+)\]")

    def __init__(self, args):
        self.JUDGE = args.judge
        self.FIX_INTERVAL_SECOND = 0
        self.MAX_RETRY_NUM = 99
        self.MAX_NEW_TOKENS = 4096
        self.NUM_ROUNDS = 2

        self.args = args

        load_dotenv()
        self.client = OpenAI(
            api_key=os.getenv('k_oai'),
            timeout=Timeout(timeout=180.0, connect=5.0)
        )

    def _GPT_decode(self, inputs):
        """Send one chat completion request and return the raw completion.

        ``inputs`` is forwarded unchanged as the ``messages`` argument. After
        the request the method sleeps ``FIX_INTERVAL_SECOND`` seconds.
        """
        completion = self.client.chat.completions.create(
            model=self.JUDGE,
            response_format={"type": 'text'},
            max_tokens=self.MAX_NEW_TOKENS,
            messages=inputs,
        )
        time.sleep(self.FIX_INTERVAL_SECOND)
        return completion

    def GPT_decode(self, inputs):
        """Call :meth:`_GPT_decode`, retrying up to ``MAX_RETRY_NUM`` times.

        Returns:
            The completion object on success, the string ``'Blocked'`` after
            ten ``BadRequestError`` failures, or the string ``'Error'`` once
            all retries are exhausted.
        """
        exponential_base = 2
        delay = 1
        blocked = 0
        for i in range(self.MAX_RETRY_NUM):
            try:
                return self._GPT_decode(inputs)
            except RateLimitError as e:
                # Exponential back-off with jitter, capped so a long run of
                # rate-limit errors cannot stall the worker indefinitely.
                delay = min(
                    delay * exponential_base * (1 + random.random()),
                    self.MAX_RETRY_DELAY_SECOND,
                )
                print(f"RateLimitError, retrying after {round(delay, 2)} seconds, {i+1}-th retry...")
                print(e)
                time.sleep(delay)
            except BadRequestError as e:
                blocked += 1
                if blocked >= 10:
                    print("Blocked too many times, skipping...")
                    return 'Blocked'
                print("Input is blocked, retrying...")
                print(e)
                time.sleep(1)
            except Exception as e:
                # Best-effort boundary: any other failure is logged and retried.
                print("Error in GPT_decode, retrying...")
                print(e)
                time.sleep(1)
        print(f"Failed after {self.MAX_RETRY_NUM} retries.")
        return 'Error'

    def get_score_from_judge(self, judge_response):
        """
        Get the score from the judge response.

        The first ``[[label]]`` occurrence wins; if there is none, the first
        ``[label]`` occurrence is used. Returns the label string, or ``None``
        when no verdict label is present.
        """
        match = (
            self._VERDICT_PATTERN.search(judge_response)
            or self._VERDICT_PATTERN_BACKUP.search(judge_response)
        )
        return match.group(1) if match else None

    def annotate_p_turn(self, inputs):
        """Query the judge once and return its reply text.

        Returns ``None`` when :meth:`GPT_decode` reports ``'Error'`` and the
        fixed reply ``'Blocked. The judgement: [[A><B]]'`` when it reports
        ``'Blocked'``.
        """
        completion = self.GPT_decode(inputs)
        if completion == 'Error':
            print(f"Error in GPT_decode, the turn {inputs} will be retried later...")
            return None
        if completion == 'Blocked':
            print(f"{inputs}: \n\nBlocked, the entry treated as bad entry.")
            return 'Blocked. The judgement: [[A><B]]'
        return completion.choices[0].message.content

    def _judge_until_verdict(self, inputs, turn_id=0):
        """Query the judge until the accumulated reply contains a verdict.

        ``inputs`` (the message list) is extended in place: every judge reply
        is appended as an assistant message and, when no verdict is found
        yet, a user message asking the judge to continue is appended.

        Returns:
            ``(judgement, verdict)`` on success, where ``judgement`` is all
            replies joined with newlines, or
            ``("Fail to judge.", "Fail to judge.")`` after
            ``MAX_RETRY_NUM_TURN`` unsuccessful attempts.
        """
        judgement = ""
        for _ in range(self.MAX_RETRY_NUM_TURN):
            annotation = self.annotate_p_turn(inputs)
            if annotation is None:
                # No reply text: nothing to add to the conversation (a None
                # message content would break the next request), just retry.
                print(f"Invalid judgement, retrying the turn {turn_id} judgement...")
                continue

            inputs.append({"role": "assistant", "content": annotation})
            judgement += '\n' + annotation
            verdict = self.get_score_from_judge(judgement)
            if verdict is not None:
                return judgement, verdict

            inputs.append({"role": "user", "content": "continue your judgment and finish by outputting a final verdict label with the above-mentioned format"})
            print(f"Invalid judgement, retrying the turn {turn_id} judgement...")

        print(f"Max retry number {self.MAX_RETRY_NUM_TURN} reached, "
              "while this task is still not judged. Marked as fail.")
        return "Fail to judge.", "Fail to judge."

    def annotate_p(self, task_pair):
        """Judge one ``(task, task_ref)`` pair and attach the rounds to ``task``.

        Both dicts must provide ``'turns'`` and ``'response'`` sequences of
        length 1, and their ``'turns'`` must be equal. On success ``task`` is
        returned with ``'judge_rounds'`` added: a list with one dict per
        round holding single-element ``'judge_responses'`` and
        ``'judge_scores'`` lists. Returns ``None`` for an invalid pair.
        """
        task, task_ref = task_pair
        turns, turns_ref = task['turns'], task_ref['turns']
        responses, responses_ref = task['response'], task_ref['response']

        if not len(turns) == len(responses) == len(turns_ref) == len(responses_ref) == 1:
            print("Invalid task, the turns and responses should be of length 1.")
            return None

        if not turns == turns_ref:
            print("Invalid task and task_ref, the questions of "
                  "the candidate and reference models should be the same.")
            return None

        rounds = []
        for r_id in range(self.NUM_ROUNDS):
            # Pick the order from the round's parity rather than swapping the
            # variables in place, so the order stays correct for any
            # NUM_ROUNDS: even rounds pass (reference, candidate), odd rounds
            # pass (candidate, reference).
            if r_id % 2 == 0:
                first, second = responses_ref[0], responses[0]
            else:
                first, second = responses[0], responses_ref[0]

            inputs = gpt_judge_for_openended_turn1_withref(turns[0], first, second)
            judgement, verdict = self._judge_until_verdict(inputs, turn_id=0)
            rounds.append({
                'judge_responses': [judgement],
                'judge_scores': [verdict],
            })

        task['judge_rounds'] = rounds

        return task

    def annotate_parallel(self, task_pairs):
        """Judge ``task_pairs`` concurrently and return results in input order.

        Raises:
            ValueError: if :meth:`annotate_p` returned ``None`` for any pair.
        """
        print(f"Judging in parallel, in total {self.args.api_parallel_num} threads.")
        with ThreadPoolExecutor(self.args.api_parallel_num) as executor:
            results = list(
                tqdm(executor.map(self.annotate_p, task_pairs), total=len(task_pairs))
            )
        if any(entry is None for entry in results):
            raise ValueError("Some tasks are not judged correctly due to "
                             "errors in annotate_p. Please inspect and retry.")
        return results
