import torch 
import torch.nn as nn 
import numpy as np 

from var_red_gfn.utils import Environment 

@torch.no_grad()
def np_pruning(Q, pi, trees, sites):
    """Log-likelihood of ``sites`` under each tree via Felsenstein's pruning algorithm.

    Args:
        Q: Substitution rate matrix, shape ``(vocab_size, vocab_size)``.
        pi: Distribution over states at the root, shape ``(vocab_size,)``.
        trees: Batch of trees; reads ``batch_size``, ``batch_ids``, ``num_nodes``,
            ``num_leaves``, ``children``, ``branch_length`` and ``root``.
        sites: Observed leaf states indexed ``[site, leaf]`` (i.e. shape
            ``(num_sites, num_leaves)``) with values in ``range(vocab_size)``.

    Returns:
        Tensor of shape ``(batch_size,)``: the log-likelihood summed over sites.

    Every edge of a tree uses that tree's ``trees.branch_length``; the per-edge
    lengths stored in ``trees.children[..., 2:]`` are not read here.
    """
    vocab_size = len(pi)

    likelihoods = torch.zeros((trees.batch_size, trees.num_nodes, sites.shape[0], vocab_size), device=pi.device)

    # Leaves: one-hot indicator of the observed state at every site
    for idx in range(vocab_size):
        likelihoods[trees.batch_ids, :trees.num_leaves, :, idx] = (sites == idx).t().to(dtype=likelihoods.dtype)

    # Transition probabilities P = exp(t * Q), one (vocab, vocab) matrix per tree.
    # Both children of every node share the same length, and it does not change
    # inside the loop, so the (expensive) matrix exponential is computed once.
    transition = torch.matrix_exp(trees.branch_length.view(-1, 1, 1) * Q[None, ...])

    # Compute the likelihoods for the nodes
    # Since the nodes are appended sequentially to the tree,
    # this ensures that the computation of a node's likelihood will be preceded
    # by the computation of its children's likelihoods
    for idx in range(trees.num_leaves, trees.num_nodes):
        # `idx`'s children
        idx_children = trees.children[trees.batch_ids, idx]
        left_children, right_children = idx_children[:, 0].long(), idx_children[:, 1].long()
        left_likelihoods = likelihoods[trees.batch_ids, left_children]
        right_likelihoods = likelihoods[trees.batch_ids, right_children]

        # (batch, vocab, vocab) @ (batch, vocab, sites): sum over each child's state
        marginal_left = torch.bmm(
            transition, torch.transpose(left_likelihoods, 1, 2)
        )
        marginal_right = torch.bmm(
            transition, torch.transpose(right_likelihoods, 1, 2)
        )

        likelihoods[trees.batch_ids, idx] = torch.transpose(marginal_left * marginal_right, 1, 2)

    # Weight the root's conditional likelihoods by pi, then sum log-likelihoods over sites
    marginal_likelihoods = likelihoods[trees.batch_ids, trees.root] @ pi
    return torch.log(marginal_likelihoods).sum(dim=-1)

@torch.no_grad()
def compute_site_likelihood(Q, vocab_size, trees, site):
    """Per-node conditional likelihoods of a single site (pruning, no root marginalization).

    Args:
        Q: Substitution rate matrix, shape ``(vocab_size, vocab_size)``.
        vocab_size: Number of character states.
        trees: Batch of trees; reads ``batch_size``, ``batch_ids``, ``num_nodes``,
            ``num_leaves``, ``children``, ``branch_length`` and ``device``.
        site: Observed leaf states for one site; ``(site == s).t()`` must
            broadcast to ``(batch_size, num_leaves)`` (e.g. shape ``(num_leaves,)``).

    Returns:
        Tensor of shape ``(batch_size, num_nodes, vocab_size)`` whose entry
        ``[b, v, s]`` is the likelihood of the leaves below ``v`` given state ``s``
        at ``v``.

    Every edge uses the per-tree ``trees.branch_length``.
    """
    likelihoods = torch.zeros((trees.batch_size, trees.num_nodes, vocab_size), device=trees.device)

    # Leaves: one-hot indicator of the observed state
    for idx in range(vocab_size):
        likelihoods[trees.batch_ids, :trees.num_leaves, idx] = (site == idx).t().to(dtype=likelihoods.dtype)

    # Loop-invariant transition probabilities P = exp(t * Q); both children use the same length
    transition = torch.matrix_exp(trees.branch_length.view(-1, 1, 1) * Q[None, ...])

    # Nodes are appended in order, so children are always processed before their parent
    for idx in range(trees.num_leaves, trees.num_nodes):
        # `idx`'s children
        idx_children = trees.children[trees.batch_ids, idx]
        left_children, right_children = idx_children[:, 0].long(), idx_children[:, 1].long()
        left_likelihoods = likelihoods[trees.batch_ids, left_children]
        right_likelihoods = likelihoods[trees.batch_ids, right_children]

        marginal_left = torch.bmm(
            transition, left_likelihoods.unsqueeze(-1)
        ).squeeze(dim=-1)
        marginal_right = torch.bmm(
            transition, right_likelihoods.unsqueeze(-1)
        ).squeeze(dim=-1)

        likelihoods[trees.batch_ids, idx] = marginal_left * marginal_right

    return likelihoods

class LogReward(nn.Module):
    """Tempered log-likelihood reward for a batch of trees.

    ``forward`` returns ``(log p(sites | tree) - shift) / temperature``, with the
    log-likelihood computed by ``np_pruning``.
    """

    def __init__(self, pi=None, sites=None, sub_rate_matrix=None, temperature=1.):
        super().__init__()
        # Distribution over states at the root; its length fixes the alphabet size
        self.pi = pi
        self.data = sites
        self.vocab_size = len(pi)
        # Substitution rate matrix handed to np_pruning
        self.Q = sub_rate_matrix
        self.temperature = temperature
        # Constant subtracted before tempering, for numerical stability
        self.shift = 0.

    @torch.no_grad()
    def forward(self, trees):
        log_lik = np_pruning(self.Q, self.pi, trees, self.data)
        centered = log_lik - self.shift
        return centered / self.temperature
    
class Trees(Environment):
    """Batched environment that grows rooted binary trees by merging nodes.

    Nodes ``0 .. num_leaves - 1`` are the leaves. Each forward step (``apply``)
    takes a pair of fatherless nodes from ``self.actions`` and makes them the
    children of a new internal node with id ``self.next_node``, so internal
    nodes are created in increasing id order and the last one,
    ``num_nodes - 1``, is the root. ``backward`` deletes a fatherless internal
    node and turns its two children back into fatherless nodes.

    Args:
        num_leaves: Number of leaves per tree.
        batch_size: Number of trees handled in parallel.
        log_reward: Forwarded to ``Environment``.
        default_branch_length: Length given to every edge created by ``apply``.
        device: Torch device holding the state tensors.
    """

    def __init__(self, num_leaves, batch_size, log_reward=None, default_branch_length=1, device='cpu'):
        super(Trees, self).__init__(batch_size, 2*num_leaves-1, log_reward, device=device)
        self.num_leaves = num_leaves
        self.num_nodes = 2 * self.num_leaves - 1
        self.num_internal_nodes = self.num_leaves - 1
        self.max_trajectory_length = self.num_leaves - 1

        # Children: per node, [left child, right child, left branch length, right branch length]
        self.children = torch.zeros((self.batch_size, self.num_nodes, 4), device=self.device)
        self.parents = torch.zeros((self.batch_size, self.num_nodes), device=self.device)

        # The index of the root
        self.root = self.num_nodes - 1

        # Attributes for the generative process

        # Dynamically changing actions: one (node, node) pair per row, initially
        # every pair of distinct leaves. `clone` materializes one copy per tree:
        # `backward` writes into this tensor in place, which is not allowed on
        # the memory-sharing view returned by `expand`.
        self.actions = torch.triu_indices(self.num_leaves, self.num_leaves, offset=1).t().expand(self.batch_size, -1, 2).clone()
        self.actions = self.actions.to(self.device)
        self.forward_mask = torch.ones((self.actions.shape[0], self.actions.shape[1]), device=self.device)

        # The nodes' features: one-hot ids for the leaves, a shared indicator
        # (column `num_leaves`) for the non-root internal nodes and another
        # (column `num_leaves + 1`) for the root
        self.X = torch.zeros((self.num_nodes, self.num_leaves + 2), device=self.device)
        self.X[
            :self.num_leaves, :self.num_leaves
        ] = torch.eye(self.num_leaves, device=self.device)

        self.X[
            self.num_leaves:-1, self.num_leaves
        ] = 1.

        self.X[
            -1, self.num_leaves + 1
        ] = 1.

        # The ID of the next node to be created
        self.next_node = self.num_leaves

        # A default value for the branches
        self.default_branch_length = default_branch_length
        self.branch_length = torch.ones((self.batch_size,), device=self.device) * self.default_branch_length

        # An edge list for using GNNs: rows of (batch id, parent, child)
        self.edge_list = torch.zeros((0, 3), dtype=torch.long, device=self.device)

        # Number of fatherless internal nodes per tree
        self.num_parents = torch.zeros((self.batch_size,), device=self.device)

        # A backward action corresponds to removing a non-leaf node.
        # At each action, one must choose a fatherless node and remove it
        # and the corresponding edges within the graph.
        # `backward_mask[b, v] == 1` marks node `v` of tree `b` as fatherless.
        self.backward_mask = torch.zeros((self.batch_size, self.num_nodes), device=self.device)
        self.backward_mask[:, torch.arange(self.num_leaves, device=self.device)] = 1.

        # Cache for `unique_input`; cleared whenever the trees change
        self.distances = None

    @torch.no_grad()
    def apply(self, indices):
        """Merge the pair ``self.actions[b, indices[b]]`` in every tree ``b``.

        The pair becomes the (left, right) children of the new node
        ``self.next_node``. Returns ``self.stopped < 2.`` per tree.
        """
        actions = self.actions[self.batch_ids, indices]

        left, right = actions[:, 0], actions[:, 1]

        # Update the actions: pairs that referenced `left` now reference the new
        # node, and every pair that still involves `right` becomes infeasible
        self.actions = torch.where(self.actions == left.view(-1, 1, 1), self.next_node, self.actions)
        self.forward_mask = torch.where((self.actions == right.view(-1, 1, 1)).any(dim=-1), 0., self.forward_mask)

        # Update the states
        self.children[self.batch_ids, self.next_node] = torch.vstack([left, right, self.branch_length, self.branch_length]).t()
        self.parents[self.batch_ids, left] = self.next_node
        self.parents[self.batch_ids, right] = self.next_node

        # Update the list of edges
        next_node_vec = torch.ones((self.batch_size,), dtype=torch.long, device=self.device) * self.next_node
        edges_to_append = torch.vstack([
            torch.vstack([self.batch_ids, next_node_vec, left]).t(),
            torch.vstack([self.batch_ids, next_node_vec, right]).t()
        ])
        self.edge_list = torch.vstack([self.edge_list, edges_to_append.long()])

        self.stopped += (self.next_node == self.root) # The last action is deterministic

        # The merged nodes gain a father; the new node is fatherless
        up_backward_mask = self.backward_mask.clone()
        up_backward_mask[self.batch_ids, left] = 0.
        up_backward_mask[self.batch_ids, right] = 0.
        up_backward_mask[:, self.next_node] = 1.
        self.backward_mask = up_backward_mask.clone()

        # Update the label of the next node
        self.next_node = self.next_node + 1

        # Merging two leaves adds a fatherless internal node; merging two internal nodes removes one
        self.num_parents += ((left < self.num_leaves) & (right < self.num_leaves)).long()
        self.num_parents -= ((left >= self.num_leaves) & (right >= self.num_leaves)).long()
        self.is_initial = (self.backward_mask[:, :self.num_leaves] == 1.).all(dim=1).long()

        # The topology changed: the cached distances are stale
        self.distances = None

        return (self.stopped < 2.)

    @torch.no_grad()
    def backward(self, indices):
        """Remove the fatherless internal node ``indices[b] + num_leaves`` of every tree ``b``.

        Its two children become fatherless again. Returns, per tree, the index
        into ``self.actions`` of the pair that re-merges those children (the
        forward action being undone).

        NOTE(review): ``next_node`` and ``stopped`` are not rolled back here.
        Confirm that callers never apply forward steps after backward ones.
        """
        # What must be removed: the edges in the edge's list
        # What must be updated: the feasible actions, the number of parents, the children, the parents, the forward mask

        nodes_to_remove = indices + self.num_leaves

        left_children = self.children[self.batch_ids, nodes_to_remove, 0]
        right_children = self.children[self.batch_ids, nodes_to_remove, 1]

        # Update the (forward) mask, update the list of actions, update the (backward) mask
        # A pair containing the right child becomes feasible iff its other node is fatherless
        batch_ids, indices, left_or_right = torch.argwhere(self.actions == right_children.view(-1, 1, 1)).t()
        self.forward_mask[batch_ids, indices] = \
            (self.backward_mask[batch_ids, self.actions[batch_ids, indices, 1 - left_or_right]] == 1.).to(dtype=self.forward_mask.dtype)

        # Pairs that referenced the removed node refer to its left child again
        batch_ids, indices, left_or_right = torch.argwhere(self.actions == nodes_to_remove.view(-1, 1, 1)).t()
        self.actions[batch_ids, indices, left_or_right] = left_children[batch_ids].long()

        self.backward_mask[self.batch_ids, nodes_to_remove] = 0.
        self.backward_mask[self.batch_ids, left_children.long()] = 1.
        self.backward_mask[self.batch_ids, right_children.long()] = 1.

        # Update the edge list: drop every (batch id, removed node, *) row
        condition_values1 = self.batch_ids.unsqueeze(1)
        condition_values2 = nodes_to_remove.unsqueeze(1)
        condition_mask = torch.all(
            torch.eq(self.edge_list[:, None, :2],
                    torch.hstack([condition_values1, condition_values2])),
        dim=-1).any(dim=-1)
        self.edge_list = self.edge_list[~condition_mask]
        assert condition_mask.sum() == 2 * self.batch_size, (self.edge_list, condition_mask.sum())

        # Update children
        self.children[self.batch_ids, nodes_to_remove] = 0

        # Update the parents
        self.parents[self.batch_ids, left_children.long()] = 0.
        self.parents[self.batch_ids, right_children.long()] = 0.

        self.num_parents = self.backward_mask[:, self.num_leaves:].sum(dim=1)
        self.is_initial = ((self.backward_mask[:, :self.num_leaves] == 1.).all(dim=1) & \
                        (self.backward_mask[:, self.num_leaves:] == 0.).all(dim=1)).long()

        # The topology changed: the cached distances are stale
        self.distances = None

        batch_ids, forward_actions = torch.argwhere(
            (self.actions == left_children.view(-1, 1, 1)).any(dim=-1) & \
                (self.actions == right_children.view(-1, 1, 1)).any(dim=-1)
        ).t()
        assert len(batch_ids) == self.batch_size
        return forward_actions

    @property
    @torch.no_grad()
    def internal_nodes(self):
        """Ids of the internal nodes, ``num_leaves .. num_nodes - 1``."""
        return torch.arange(self.num_internal_nodes, device=self.device) + self.num_leaves

    @property
    @torch.no_grad()
    def leaves(self):
        """Ids of the leaves, ``0 .. num_leaves - 1``."""
        return torch.arange(self.num_leaves, device=self.device)

    @torch.no_grad()
    def is_leaf(self, nodes):
        """Boolean mask: which of ``nodes`` are leaves."""
        return (nodes < self.num_leaves)

    @torch.no_grad()
    def edge_list_t(self):
        """Edge list as a (2, num_edges) tensor of batch-flattened node ids (``b * num_nodes + v``)."""
        return ((self.edge_list[:, 0] * self.num_nodes).view(-1, 1).expand(-1, 2) + self.edge_list[:, 1:]).t().long()

    @property
    @torch.no_grad()
    def expanded_data(self):
        """Node features ``X`` repeated for every tree, shape (batch_size * num_nodes, num_leaves + 2)."""
        return self.X.reshape(1, *self.X.shape).expand(self.batch_size, self.X.shape[0], self.X.shape[1]).reshape(-1, self.X.shape[1])

    @property
    def adjacency_matrix(self):
        """Symmetric 0/1 adjacency matrices, shape (batch_size, num_nodes, num_nodes)."""
        adj = torch.zeros((self.batch_size, self.num_nodes, self.num_nodes), device=self.device)
        batch_ids, left_node, right_node = self.edge_list.t()
        adj[batch_ids, left_node, right_node] = 1.
        adj[batch_ids, right_node, left_node] = 1.
        return adj

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the trees of ``batch_state`` to this batch.

        NOTE(review): ``children``, ``parents`` and ``backward_mask`` are not
        concatenated here. Presumably ``Environment.merge`` handles them;
        verify before relying on them after a merge.
        """
        # batch size, batch ids, actions, branch length, edge list, stopped and num parents, mask
        self.actions = torch.vstack([self.actions, batch_state.actions])
        self.branch_length = torch.hstack([self.branch_length, batch_state.branch_length])
        # Update the non-incumbent tree's edge list (offset by the pre-merge batch size)
        e = batch_state.edge_list.clone()
        e[:, 0] += self.batch_size
        self.edge_list = torch.vstack([self.edge_list, e])
        self.num_parents = torch.hstack([self.num_parents, batch_state.num_parents])
        self.forward_mask = torch.vstack([self.forward_mask, batch_state.forward_mask])
        # The batch changed: the cached distances are stale
        self.distances = None
        super().merge(batch_state)

    @torch.no_grad()
    def floyd_warshall_batch(self):
        """Hop-count distances between the leaves and the root (Floyd-Warshall, batched).

        Every edge has weight 1. Returns a (batch_size, m * m) tensor with
        m = num_leaves + 1, whose entry ``[b, i * m + j]`` is the distance between
        ``leaves_root[j]`` and ``leaves_root[i]`` (leaves first, then the root).
        """
        dist = self.adjacency_matrix
        dist = torch.where(dist == 0., torch.inf, dist)
        ids = torch.arange(self.num_nodes, device=self.device)
        dist[:, ids, ids] = 0

        for k in range(self.num_nodes):
            # Relax every path i -> j through the intermediate node k
            dist = torch.minimum(dist, dist[:, :, k, None] + dist[:, None, k, :])

        leaves_root = torch.hstack([self.leaves.long(), torch.tensor(self.root, dtype=torch.long, device=self.device)])
        # Concatenate the columns of the leaves/root sub-matrix one after another
        sub = dist[:, leaves_root][:, :, leaves_root]
        return sub.transpose(1, 2).reshape(dist.shape[0], -1)

    @property
    def unique_input(self):
        """Cached output of ``floyd_warshall_batch``; recomputed after ``apply``/``backward``/``merge``."""
        if self.distances is not None:
            return self.distances
        else:
            self.distances = self.floyd_warshall_batch()
            return self.distances

    @staticmethod
    def sample_from_phylogeny(tree, sub_rate_matrix, num_sites, background_freq, device='cpu'):
        """Simulate character data down the first tree (batch index 0) of ``tree``.

        The root's states are drawn from ``background_freq``. Each other node's
        states are drawn from ``matrix_exp(sub_rate_matrix * t)`` indexed by its
        parent's states, where ``t`` is the edge length stored in
        ``tree.children``.

        Returns:
            Long tensor of shape (num_sites, num_nodes) with the sampled state
            of every node (the root is the last column).
        """
        # Sample something from the background frequencies for the tree's root
        sites = torch.zeros((num_sites, tree.num_nodes), dtype=torch.long, device=device)
        sites[:, -1] = torch.multinomial(background_freq,
                            num_samples=num_sites, replacement=True).flatten()

        # Evolve the states along each edge with the continuous-time Markov chain
        # defined by `sub_rate_matrix` (e.g. Jukes & Cantor (1969)'s model)
        def mutate(node):
            parent = tree.parents[0, node].long()
            parent_sites = sites[:, parent]

            child_idx = torch.argwhere(tree.children[0, parent, :2] == node).item()
            branch_length = tree.children[0, parent, child_idx + 2]
            transition_matrix = torch.matrix_exp(sub_rate_matrix * branch_length)
            sites[:, node] = torch.multinomial(transition_matrix[parent_sites],
                            num_samples=1, replacement=True).squeeze()

        # Pre-order traversal with an explicit stack (deep trees can exceed
        # Python's recursion limit). Pushing the right child before the left one
        # visits nodes in the same order as a recursive left-then-right
        # traversal, so the random draws happen in the same sequence.
        stack = [tree.children[0, tree.root, 1].long(), tree.children[0, tree.root, 0].long()]
        while stack:
            node = stack.pop()
            mutate(node)
            if node >= tree.num_leaves:
                stack.append(tree.children[0, node, 1].long())
                stack.append(tree.children[0, node, 0].long())
        return sites
