import json

import numpy as np
import torch

from .make_tree import Parse_Tree
from tqdm import tqdm


class TreeBuilder:
    """Builds one Parse_Tree per category and queries the trees with model features.

    Workflow visible in this class: ``extract_feature`` runs the model over
    ``query_loader`` while the activation hook records features,
    ``construct_tree`` builds the trees from the prompt dictionary and attaches
    the recorded features, and ``query_tree_save`` / ``query_tree_label`` match
    features of new samples against the trees.
    """

    def __init__(self, model, query_loader, hook, prompt_file, categories=None):
        """
        Initializes a new instance of the TreeBuilder class.

        Args:
            model (torch.nn.Module): The pre-trained PyTorch model to use for predictions.
            query_loader (torch.utils.data.DataLoader): The data loader for the query set.
                Its dataset must expose ``category_to_index`` (and ``categories``
                when ``categories`` is not given).
            hook (ActivationHook): The activation hook object to capture the intermediate activations.
            prompt_file (str): Path to the JSON file holding the prompt dictionary
                (one nested dict per category name).
            categories (list of str, optional): The list of categories. Defaults to None,
                in which case ``query_loader.dataset.categories`` is used.
        """
        self.query_loader = query_loader
        self.prompt_file = prompt_file
        self.attributes = []
        self.labels = []
        self.hook = hook
        self.prompt_dict = self.load_prompt()

        if categories is None:
            self.categories = self.query_loader.dataset.categories
        else:
            self.categories = categories
        self.model = model
        # One slot per category; filled in by construct_tree().
        self.trees = [None for _ in self.categories]

    def extract_feature(self, device):
        """
        Extracts the intermediate activations from the pre-trained model.

        Runs the model over ``query_loader`` with the hook enabled and records
        the labels and attributes of every batch alongside the hooked features.

        Args:
            device (str): The device on which to run the model (e.g. "cpu" or "cuda").
        """
        print("="*10 + "Extract Features from support set" + "="*10)
        self.model.eval()
        self.hook.reset()
        # The hook buffer is cleared above, so the per-sample bookkeeping must be
        # cleared too; otherwise a second call leaves labels/attributes longer
        # than the feature tensor and the masks in construct_tree() misalign.
        self.attributes = []
        self.labels = []
        self.hook.enable()
        try:
            with torch.no_grad():
                for batch in tqdm(self.query_loader):
                    images, labels, attribute = batch
                    self.attributes += attribute
                    self.labels.append(labels)
                    images = images.to(device)
                    # Forward pass is run only for its side effect on the hook.
                    self.model(images)
        finally:
            # Never leave the hook recording if the forward pass fails.
            self.hook.disable()

    def load_prompt(self):
        """
        Loads the prompt dictionary from a JSON file.

        Returns:
            dict: The prompt dictionary.
        """
        # JSON files are UTF-8 by specification; do not rely on the platform
        # default encoding.
        with open(self.prompt_file, encoding="utf-8") as file:
            prompt_dict = json.load(file)
        return prompt_dict

    def construct_tree(self):
        """
        Constructs a Parse_Tree for each category based on the
        prompt dictionary and the intermediate activations.

        Must be called after ``extract_feature``: it consumes
        ``self.hook.features``, ``self.attributes`` and ``self.labels``.
        """

        print("="*10 + "Constructs a Parse_Tree for each category" + "="*10)
        # Get the mapping between category names and their corresponding indices in the dataset
        category_to_index = self.query_loader.dataset.category_to_index

        # Iterate over the prompt dictionary and construct a Parse_Tree for each category
        for k, v in self.prompt_dict.items():
            assert isinstance(v, dict), "prompt_dict must be dict"
            c_tree = Parse_Tree.from_dict(v)
            self.trees[category_to_index[k]] = c_tree

        # Concatenate the feature tensors of all batches
        all_features = torch.cat(self.hook.features)

        # Per-sample attributes and labels collected by extract_feature()
        attribute = np.array(self.attributes)
        labels = torch.cat(self.labels).detach().cpu().numpy()

        # Iterate over each category and assign the feature values to the corresponding Parse_Tree object
        for i, c in enumerate(self.categories):
            # NOTE(review): the mask uses the enumeration position ``i`` while the
            # tree is looked up through category_to_index[c]; this assumes both
            # agree for every category - confirm for custom ``categories``.
            c_mask = labels == i

            # Retrieve the attributes and feature tensor for the current category
            c_attr = attribute[c_mask]
            c_feat = all_features[c_mask]

            # Group the category's features by attribute value
            values = dict()
            for attr in set(c_attr):
                values[attr] = c_feat[c_attr == attr]

            # Assign the feature values to the corresponding Parse_Tree object
            self.trees[category_to_index[c]].set_values(values)

    def query_tree_save(self, test_loader, device, output_file):
        """
        Uses the pre-trained model to make predictions on the test set and
        saves the corresponding path dicts to a JSON file.

        Each saved entry is ``{'path_dict': ..., 'pred': predicted_class_index}``,
        where the path comes from the tree of the *predicted* class.

        Args:
            test_loader (torch.utils.data.DataLoader): The data loader for the test set.
                Batches are unpacked as ``(inputs, targets)``; targets are not used.
            device (str): The device on which to run the model (e.g. "cpu" or "cuda").
            output_file (str): The path to the output JSON file.
        """
        # Set the model to evaluation mode
        self.model.eval()

        # Reset the activation hook
        self.hook.reset()
        self.hook.enable()

        # Initialize a list to store all path_dicts
        path_dicts = []

        try:
            with torch.no_grad():
                # Iterate over each batch in the test loader
                for inputs, targets in tqdm(test_loader):
                    inputs = inputs.to(device)

                    # Reset per batch so features[0] is always the current batch
                    self.hook.reset()

                    # Make predictions using the model
                    outputs = self.model(inputs)
                    _, predicted = torch.max(outputs.data, 1)

                    # Iterate over each prediction and corresponding feature vector
                    for p, feat in zip(predicted, self.hook.features[0]):
                        pred = p.item()
                        # Find the path in the predicted class's Parse_Tree
                        path = self.trees[pred].top_matches(feat)
                        path_dicts.append({'path_dict': path.to_dict(), 'pred': pred})
        finally:
            # Never leave the hook recording if inference fails.
            self.hook.disable()

        # Save the path_dicts list to a json file
        with open(output_file, "w") as f:
            json.dump(path_dicts, f)

    def query_tree_label(self, test_loader, device, k=1):
        """
        Creates hierarchical labels for the test set by matching each sample's
        hooked feature against the Parse_Tree of its *ground-truth* class.

        Args:
            test_loader (torch.utils.data.DataLoader): The data loader for the test set.
                Batches are unpacked as ``(inputs, targets)``.
            device (str): The device on which to run the model (e.g. "cpu" or "cuda").
            k (int, optional): Number of matches requested from
                ``top_matches_nomerge``. Defaults to 1.

        Returns:
            torch.Tensor: Stacked per-sample path tensors. Each sample contributes a
                tensor with exactly one row (enforced below), padded with zeros up to
                the maximum tree depth, i.e. shape (num_samples, 1, max_depth) when
                no path exceeds max_depth.

        Raises:
            RuntimeError: If a sample does not yield exactly one path.
        """

        print("="*10 + "Creating Hierarchical Labels for the Dataset" + "="*10)

        # Set the model to evaluation mode
        self.model.eval()

        # Reset the activation hook
        self.hook.reset()
        self.hook.enable()

        # Per-sample path tensors
        all_paths = []

        try:
            # Find the maximum depth of all trees
            max_depth = max(t.max_depth() for t in self.trees)

            with torch.no_grad():
                # Iterate over each batch in the test loader
                for inputs, targets in tqdm(test_loader):
                    inputs = inputs.to(device)

                    # Reset per batch so features[0] is always the current batch
                    self.hook.reset()

                    # Forward pass is run only for its side effect on the hook;
                    # the class prediction itself is not needed here.
                    self.model(inputs)

                    for feat, target in zip(self.hook.features[0], targets):
                        label = target.item()
                        paths = self.trees[label].top_matches_nomerge(feat, k=k)

                        # Overwrite the first path entry with the class label and
                        # pad with zeros up to max_depth. A new list is built so the
                        # lists returned by the tree are not modified.
                        new_paths = [
                            [label, *path[1:]] + [0] * int(max_depth - len(path))
                            for path in paths
                        ]
                        new_paths = torch.tensor(new_paths)[:k]

                        # NOTE(review): exactly one path per sample is required even
                        # though ``k`` is configurable - confirm whether k > 1 is
                        # meant to be supported. Previously this dumped debug prints
                        # and called exit(), terminating the whole interpreter.
                        if new_paths.shape[0] != 1:
                            raise RuntimeError(
                                "Expected exactly one path for target "
                                f"{label}, got tensor of shape "
                                f"{tuple(new_paths.shape)} (k={k}); "
                                f"tree: {self.trees[label]}"
                            )
                        all_paths.append(new_paths)
        finally:
            # Never leave the hook recording if inference fails.
            self.hook.disable()

        # Stack the per-sample path tensors into a single tensor
        all_paths = torch.stack(all_paths)

        return all_paths
    
    
    
                    
                    
                    
                    
                    



