import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
import gym
import numpy as np
from sklearn.metrics import f1_score

# Custom gym environment
class PruningEnv(gym.Env):
    """Gym environment in which each step prunes one JSON context before an LLM answers.

    An episode walks once over the dataset. At step ``t`` the agent receives the
    encoded ``(json_texts[t], questions[t])`` pair and emits one keep(1)/drop(0)
    flag per key of ``json_texts[t]``. The pruned dict is turned into a prompt
    and sent to ``model.predict``. The reward is 1.0 when the answer equals
    ``labels[t]``, else 0.0.
    """

    def __init__(self, json_texts, questions, labels, evidences, prompt_loader, model, dataset_name):
        super().__init__()
        self.json_texts = json_texts  # sequence of dicts whose items are pruned
        self.questions = questions
        self.labels = labels
        self.evidences = evidences
        # Must expose prompt_construction(question, evidence, pruned_text, dataset_name).
        self.prompt_loader = prompt_loader
        self.model = model  # must expose predict(messages)
        self.dataset_name = dataset_name
        self.current_index = 0
        # NOTE(review): both spaces are sized from the first sample only; later
        # samples may have a different number of keys.
        self.action_space = gym.spaces.MultiBinary(len(self.json_texts[0]))
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(len(self.json_texts[0]),), dtype=np.float32)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def reset(self):
        """Restart at the first sample and return its encoded state."""
        self.current_index = 0
        return self._get_state()

    def step(self, action):
        """Apply ``action`` to the current sample, query the model, and advance.

        Returns:
            ``(next_state, reward, done, info)``. ``next_state`` is ``None`` once
            the last sample has been consumed.
        """
        idx = self.current_index
        pruned_text = self.prune_text(self.json_texts[idx], action)
        full_prompts = self.prompt_loader.prompt_construction(
            self.questions[idx], self.evidences[idx], pruned_text, self.dataset_name
        )
        messages = [{"role": "user", "content": full_prompts}]
        predicted_answer = self.model.predict(messages)
        reward = self._reward(self.labels[idx], predicted_answer)
        self.current_index += 1
        done = self.current_index >= len(self.json_texts)
        next_state = None if done else self._get_state()
        return next_state, reward, done, {}

    @staticmethod
    def _reward(label, predicted_answer):
        """Score one prediction: 1.0 on an exact match with ``label``, else 0.0.

        ``sklearn.metrics.f1_score`` on a single sample with its default binary
        averaging only scores the class ``pos_label=1``. A correct answer whose
        label is anything else earns 0, and two distinct string labels raise
        ``ValueError``. For one sample, micro-averaged F1 is exactly this
        exact-match indicator.
        """
        return float(predicted_answer == label)

    def prune_text(self, text, action):
        """Keep the entries of ``text`` whose action flag equals 1.

        Flags are matched to keys in dict iteration order. A key with no
        matching flag is dropped rather than raising ``IndexError``. This
        happens when truncation left fewer markers than keys.
        """
        return {key: value for (key, value), flag in zip(text.items(), action) if flag == 1}

    def _get_state(self):
        """Encode the current sample for the policy model.

        Returns:
            ``(input_ids, attention_mask, num_keys, cls_positions)``.
            ``cls_positions`` are *token* indices of the ``[CLS]`` markers, one
            per key followed by one for the question. Markers cut off by
            truncation are absent.
        """
        json_text = self.json_texts[self.current_index]
        question = self.questions[self.current_index]
        state_text, _ = self._convert_to_input_format(json_text, question)
        # The text already carries its own [CLS]/[SEP] markers, so stop the
        # tokenizer from prepending/appending another pair.
        inputs = self.tokenizer.encode_plus(
            state_text, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"]
        # Word indices do not survive WordPiece tokenization, so locate the
        # markers in the token ids instead.
        cls_positions = (input_ids[0] == self.tokenizer.cls_token_id).nonzero(as_tuple=True)[0].tolist()
        return input_ids, inputs["attention_mask"], len(json_text), cls_positions

    def _convert_to_input_format(self, json_text, question):
        """Render ``json_text`` and ``question`` as one marker-delimited string.

        Layout: ``[CLS] k1 : v1 [SEP] ... [CLS] question [SEP]``.

        Returns:
            ``(state_text, cls_word_positions)``. The positions index
            ``state_text.split()`` and are *word* indices, not token indices.
        """
        kv_pairs = " ".join(f"[CLS] {key} : {value} [SEP]" for key, value in json_text.items())
        state_text = f"{kv_pairs} [CLS] {question} [SEP]"
        cls_positions = [i for i, word in enumerate(state_text.split()) if word == "[CLS]"]
        return state_text, cls_positions
    
# Define the policy model
class BertPruningModel(nn.Module):
    """Scores selected token positions of a BERT input with one keep/drop logit each."""

    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, cls_positions):
        """Return one logit per requested token position.

        Args:
            input_ids: BERT token ids, passed straight to ``self.bert``.
            attention_mask: BERT attention mask, passed straight to ``self.bert``.
            cls_positions: token indices (not word indices) whose final hidden
                states are scored.
                * 1-D list/tensor of length ``k``: the same indices for every row.
                * 2-D tensor / nested list of shape ``(batch, k)``: per-row
                  indices, for batches whose markers sit at different offsets.

        Returns:
            Logits of shape ``(batch, k, 1)``.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)
        positions = torch.as_tensor(cls_positions, dtype=torch.long, device=hidden.device)
        if positions.dim() == 2:
            # Pair row r with its own index list positions[r].
            rows = torch.arange(hidden.size(0), device=hidden.device).unsqueeze(1)
            cls_embeddings = hidden[rows, positions]
        else:
            cls_embeddings = hidden[:, positions, :]
        return self.classifier(cls_embeddings)
    
# Define the PPO training process
class PPOAgent:
    """PPO trainer for a per-key keep/drop policy over environment states.

    A state is an ``(input_ids, attention_mask, num_keys, cls_positions)``
    tuple. ``policy_model(input_ids, attention_mask, cls_positions)`` yields one
    logit per marker. The first ``num_keys`` logits parameterize independent
    Bernoulli keep decisions, one per key.
    """

    def __init__(self, env, policy_model, optimizer, clip_param=0.2, ppo_epochs=10, mini_batch_size=32, gamma=0.99):
        self.env = env
        self.policy_model = policy_model
        self.optimizer = optimizer
        self.clip_param = clip_param
        self.ppo_epochs = ppo_epochs
        self.mini_batch_size = mini_batch_size
        self.gamma = gamma

    def _action_logits(self, state):
        """Return the 1-D per-key logits for ``state`` (gradients tracked)."""
        input_ids, attention_mask, num_keywords, cls_positions = state
        logits = self.policy_model(input_ids, attention_mask, cls_positions)
        # On a single encoded sample the model returns (1, k, 1) with a trailing
        # marker for the question; flatten and keep only the per-key entries.
        return logits.reshape(-1)[:num_keywords]

    @staticmethod
    def _log_prob(logits, action):
        """Joint log-probability of ``action`` under independent Bernoulli(sigmoid(logits))."""
        action = torch.as_tensor(action, dtype=logits.dtype, device=logits.device)
        return -nn.functional.binary_cross_entropy_with_logits(logits, action, reduction="sum")

    def _act(self, state, deterministic):
        """Pick an action for ``state`` without tracking gradients.

        Returns:
            ``(action, log_prob)``: a float 1-D tensor of 0/1 flags and the
            0-d log-probability of that action under the current policy.
        """
        self.policy_model.eval()
        with torch.no_grad():
            logits = self._action_logits(state)
            probs = torch.sigmoid(logits)
            action = (probs > 0.5).float() if deterministic else torch.bernoulli(probs)
            log_prob = self._log_prob(logits, action)
        return action, log_prob

    def select_action(self, state, deterministic=True):
        """Return the keep(1)/drop(0) decision for each key of ``state``.

        Args:
            state: environment state tuple.
            deterministic: if True (default), threshold the keep probability at
                0.5. If False, sample from the policy, which PPO needs for
                exploration.

        Returns:
            float32 NumPy array of shape ``(num_keys,)``.
        """
        action, _ = self._act(state, deterministic)
        return action.cpu().numpy()

    def compute_returns(self, rewards, dones, next_value, gamma):
        """Discounted returns, resetting the bootstrap at every ``done`` step."""
        returns = []
        R = next_value
        for step in reversed(range(len(rewards))):
            R = rewards[step] + gamma * R * (1 - dones[step])
            returns.insert(0, R)
        return returns

    def ppo_update(self, states, actions, log_probs, returns, advantages):
        """Run ``ppo_epochs`` passes of clipped-surrogate updates over a rollout.

        Args:
            states: list of state tuples. Samples may differ in length, so they
                are evaluated one by one rather than stacked.
            actions: per-step 1-D action tensors/arrays, indexable by step.
            log_probs: 1-D tensor of log-probs recorded at collection time.
            returns: discounted returns. Not used by the loss; kept for
                interface compatibility.
            advantages: 1-D tensor of advantages, one per step.
        """
        # Old log-probs are constants of the objective; never backprop into them.
        old_log_probs = torch.as_tensor(log_probs, dtype=torch.float32).detach()
        advantages = torch.as_tensor(advantages, dtype=torch.float32).detach()
        # Keep dropout off, as during collection, so the probability ratio
        # compares the same function (it is exactly 1 before the first step).
        self.policy_model.eval()
        n = len(states)
        for _ in range(self.ppo_epochs):
            for start in range(0, n, self.mini_batch_size):
                stop = min(start + self.mini_batch_size, n)
                new_log_probs = torch.stack(
                    [self._log_prob(self._action_logits(states[j]), actions[j]) for j in range(start, stop)]
                )
                device = new_log_probs.device
                ratio = (new_log_probs - old_log_probs[start:stop].to(device)).exp()
                batch_advantages = advantages[start:stop].to(device)

                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch_advantages
                loss = -torch.min(surr1, surr2).mean()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def train(self, num_episodes):
        """Run one pass over the environment per episode, then update the policy."""
        for episode in range(num_episodes):
            state = self.env.reset()
            episode_rewards = []
            episode_log_probs = []
            episode_states = []
            episode_actions = []
            episode_dones = []

            for _step in range(len(self.env.json_texts)):
                # Sample rather than threshold so PPO explores.
                action, log_prob = self._act(state, deterministic=False)
                next_state, reward, done, _info = self.env.step(action.cpu().numpy())

                episode_rewards.append(reward)
                episode_log_probs.append(log_prob)
                episode_states.append(state)
                episode_actions.append(action)
                episode_dones.append(done)

                state = next_state
                if done:
                    break

            if not episode_rewards:
                continue

            next_value = 0  # no critic: treat the state after the last step as terminal
            returns = torch.tensor(
                self.compute_returns(episode_rewards, episode_dones, next_value, self.gamma),
                dtype=torch.float32,
            )
            log_probs = torch.stack(episode_log_probs)
            # Mean-return baseline (there is no value network).
            # NOTE(review): each step is an independent sample, yet gamma chains
            # their rewards, so an action is credited with later samples'
            # rewards; gamma=0 may be what is intended.
            advantages = returns - returns.mean()

            # States/actions stay as lists: samples have different key counts.
            self.ppo_update(episode_states, episode_actions, log_probs, returns, advantages)
            print(f"Episode {episode+1} completed.")


class InferenceModel:
    """Prunes a JSON context with a pruning model ahead of LLM inference."""

    def __init__(self, pruning_model, llm):
        self.pruning_model = pruning_model
        self.llm = llm
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def _convert_to_input_format(self, json_text, question):
        """Render ``json_text`` and ``question`` as one marker-delimited string.

        Layout: ``[CLS] k1 : v1 [SEP] ... [CLS] question [SEP]``.

        Returns:
            ``(state_text, cls_word_positions)``. The positions index
            ``state_text.split()`` and are *word* indices, not token indices.
        """
        kv_pairs = " ".join(f"[CLS] {key} : {value} [SEP]" for key, value in json_text.items())
        state_text = f"{kv_pairs} [CLS] {question} [SEP]"
        cls_positions = [i for i, word in enumerate(state_text.split()) if word == "[CLS]"]
        return state_text, cls_positions

    def prune_text(self, json_text, question):
        """Return the subset of ``json_text`` that the pruning model keeps.

        Each key is scored at its ``[CLS]`` marker. Keys whose keep probability
        exceeds 0.5 are retained. Keys whose marker was cut off by truncation
        get no score and are dropped.
        """
        state_text, _ = self._convert_to_input_format(json_text, question)
        # The text already carries its own [CLS]/[SEP] markers; don't add another pair.
        inputs = self.tokenizer.encode_plus(
            state_text, add_special_tokens=False, return_tensors="pt", padding=True, truncation=True
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        # Word indices do not match WordPiece token indices; locate the markers
        # in the token ids instead.
        cls_positions = (input_ids[0] == self.tokenizer.cls_token_id).nonzero(as_tuple=True)[0].tolist()

        # Predict without dropout, then restore whatever mode the caller had set.
        was_training = self.pruning_model.training
        self.pruning_model.eval()
        try:
            with torch.no_grad():
                logits = self.pruning_model(input_ids, attention_mask, cls_positions)
        finally:
            self.pruning_model.train(was_training)

        keep = (torch.sigmoid(logits).reshape(-1) > 0.5).tolist()
        # zip stops at the shorter side: the trailing question marker and any
        # truncated keys are both ignored.
        return {k: v for (k, v), flag in zip(json_text.items(), keep) if flag}

    # TODO: a predict(json_text, question) helper (prune, then call
    # self.llm.predict) was sketched here but is not implemented.




def _main():
    """Script entry point; currently just prints 123.

    NOTE(review): looks like a placeholder for a real driver.
    """
    print(123)


if __name__ == '__main__':
    _main()
    
