import torch
torch.set_printoptions(profile="full")
import json
import numpy as np
from tqdm import tqdm
import sys
from copy import deepcopy
import argparse


class Node:
    """Tree node that caches the size and depth of the subtree rooted at it.

    ``num_nodes_in_tree`` and ``depth`` are computed once, in the constructor,
    from the children supplied at that moment; they are not refreshed if
    ``children`` is mutated afterwards.
    """

    def __init__(self, children=None):
        """Create a node; ``children`` defaults to a fresh empty list."""
        if children is None:
            children = []
        self.children = children

        # Subtree size: this node plus every node below each child.
        child_sizes = [child.num_nodes_in_tree for child in self.children]
        self.num_nodes_in_tree = sum(child_sizes) + 1

        # A leaf has depth 1; otherwise one more than the deepest child.
        child_depths = [child.depth for child in self.children]
        self.depth = max((0, *child_depths)) + 1

    def __str__(self, depth=1):
        """Render the subtree as nested ``[(size, depth): ...]`` brackets.

        ``depth`` is the nesting level used for indentation: each child is
        placed on its own line, indented by eight spaces per level.
        """
        label = str((self.num_nodes_in_tree, self.depth))
        if not self.children:
            return f'[{label}]'

        separator = '\n' + '        ' * depth
        rendered = [
            f"[{child.__str__(depth=depth + 1)}]" for child in self.children
        ]
        return f'[{label}: {separator}{separator.join(rendered)}]'


def sequoia_tree_search(P, N, D, B=None):
    """Tabulate tree scores by dynamic programming and rebuild the best trees.

    For every ``n`` in ``2..N``, ``d`` in ``2..D`` and ``b`` in ``1..B`` the
    table is filled with::

        T[n, d, b] = max_k  T[n - k, d, b - 1] + P[b] * T_max[k, d - 1]

    where ``k`` (``1 <= k < n``) is the number of nodes given to the b-th
    child subtree and ``T_max[n, d] = max_b T[n, d, b]``.  A single-node tree
    scores ``1.0`` for every depth limit ``d >= 1``; states that cannot be
    reached stay at ``-inf``.  (Presumably the score is the expected number
    of accepted tokens and ``P`` holds acceptance rates -- the script in this
    file passes an "acceptance_rate_vector" -- but that meaning is not
    established by this function itself.)

    A ``tqdm`` progress bar is shown over ``n``.

    Args:
        P: 1-D NumPy array of per-child weights.  ``P[b]`` multiplies the
            score of the subtree attached as the b-th child; ``P[0]`` is
            never read.
        N: Largest node count to tabulate.
        D: Largest depth limit to tabulate.
        B: Largest number of children per node.  Defaults to
            ``P.shape[0] - 1`` and must not exceed it.

    Returns:
        A tuple ``(T, best_tree, branch_map)``:

        * ``T`` -- float array of shape ``(N + 1, D + 1, B + 1)``.
        * ``best_tree`` -- dict mapping ``(n, d)`` to the root ``Node`` of
          the best tree found with ``n`` nodes and depth ``<= d``; only
          states with a positive score have an entry.
        * ``branch_map`` -- dict mapping ``(n, d, b)`` of each best tree to
          the list of its children's ``(n, d, b)`` keys, in child order.

    Raises:
        AssertionError: If ``B`` exceeds ``P.shape[0] - 1`` or a rebuilt
            tree does not use exactly ``n`` nodes.
        RuntimeError: If a child pointer needed to rebuild a tree was never
            recorded (the table is inconsistent).
    """
    if B is not None:
        assert B <= P.shape[0] - 1
    else:
        B = P.shape[0] - 1

    # Everything starts unreachable (-inf) except the single-node tree.
    T = np.full((N + 1, D + 1, B + 1), -np.inf)
    T_max = np.full((N + 1, D + 1), -np.inf)
    T[1, 1:, 0] = 1.0
    T_max[1, 1:] = 1.0
    branch_map = {(1, d, 0): [] for d in range(1, D + 1)}

    # best_new_node[n, d, b] = A pointer to the best node (tree root node) to add
    #     as the b^th child of the tree root with budget n, depth <= d, and b children.
    # best_tree[n, d] = A pointer to the best node (tree root node) with n nodes and depth <= d.
    best_new_node = {(1, d, 0): None for d in range(1, D + 1)}
    best_tree = {(1, d): Node() for d in range(1, D + 1)}

    for n in tqdm(range(2, N + 1)):
        for d in range(2, D + 1):
            # Scores of candidate child subtrees of size 1..n-1; this slice
            # does not depend on b, so take it once per (n, d).
            child_scores = T_max[1:n, d - 1]
            for b in range(1, B + 1):
                # Entry i pairs a child subtree of size i + 1 with the rest
                # of the tree (size n - 1 - i, b - 1 children).  NaN (from
                # 0 * -inf when P[b] is exactly zero) is mapped to 0.0 and
                # -inf is kept as -inf.
                x = np.nan_to_num(
                    T[n - 1:0:-1, d, b - 1] + P[b] * child_scores,
                    nan=0.0,
                    neginf=-float('inf'),
                )
                T[n, d, b] = np.max(x)
                if T[n, d, b] > 0.0:
                    argmax = int(np.argmax(x))
                    best_new_node[n, d, b] = best_tree[argmax + 1, d - 1]
            T_max[n, d] = np.max(T[n, d, :])

            if T_max[n, d] > 0:
                best_b = int(np.argmax(T[n, d, :]))
                # Walk the child pointers from the last child back to the
                # first, peeling each child's node count off the budget.
                children = []
                child_keys = []
                remaining_budget = n
                for b in range(best_b, 0, -1):
                    try:
                        next_child = best_new_node[remaining_budget, d, b]
                    except KeyError as err:
                        # Previously this dropped into pdb and then carried
                        # on with an unbound or stale child; fail loudly.
                        raise RuntimeError(
                            f'no child recorded for {remaining_budget=}, {d=}, {b=} '
                            f'while rebuilding the tree for {n=}, {best_b=}'
                        ) from err
                    children.append(next_child)
                    child_keys.append(
                        (next_child.num_nodes_in_tree, d - 1, len(next_child.children))
                    )
                    remaining_budget -= next_child.num_nodes_in_tree
                # Collected last-to-first; restore child order.
                children.reverse()
                child_keys.reverse()
                assert remaining_budget == 1, f'{n=}, {d=}, {best_b=}, {remaining_budget=}'
                branch_map[n, d, best_b] = child_keys
                best_tree[n, d] = Node(children=children)

    return T, best_tree, branch_map


# ---------------------------------------------------------------------------
# Driver script: read a JSON config, run the tree search, pick the
# (budget, depth) pair with the lowest cost per unit of score, expand the
# chosen tree breadth-first into flat index structures and save them.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default="demo-config.json", help='config')
args = parser.parse_args()
print(args)
with open(args.config, 'r') as f:
    config = json.load(f)
# NOTE(review): torch.load unpickles the file named in the config; only point
# the config at files from a trusted source.
# The last entry of the loaded vector is dropped and a small epsilon added.
p = torch.load(config["acceptance_rate_vector"]).cpu()[:-1].numpy() + 1e-5
max_branch = p.shape[0] - 1

max_depth = config["max_depth"]

max_budget = config["max_budget"]

# This is a workaround for an issue that occurs when p has entries that are exactly equal to zero.
p[1:] += 0.00000001

T, best_tree, branch_map = sequoia_tree_search(p, N=max_budget, D=max_depth, B=max_branch)
T = torch.tensor(T)

# results[n, d] = best score over all branching factors.
results = T.max(dim=2).values
print(results)
draft_inference_time = config['draft_time']
target_verify_time = config['target_time']


valid_budget = config['valid_budget']

# Pick the (budget, depth) pair that minimises
# (d * draft_time + target_time) / score.  target_verify_time is indexed in
# parallel with valid_budget.
dec_time = torch.inf
pairs = None
for i, b in enumerate(valid_budget):
    target_time = target_verify_time[i]
    for d, ac_len in enumerate(results[b]):
        # Negative entries are states still at -inf (no tree exists).
        if ac_len < 0:
            continue
        x = ((d) * draft_inference_time + target_time) / ac_len
        if x < dec_time:
            dec_time = x
            pairs = (b,d)

if pairs is None:
    # Without this check the tuple unpacking below fails with an opaque
    # "cannot unpack non-iterable NoneType" error.
    raise ValueError(
        f"no (budget, depth) pair with a feasible tree for valid_budget={valid_budget}"
    )

print(dec_time, target_verify_time[0] / dec_time, pairs)

(m, l) = pairs
# Branching factor of the root of the chosen tree.
b = T[m][l].argmax(dim=0).item()

# Flat, index-based description of the tree; node 0 is the root and new
# nodes receive consecutive indices in the order they are created.
positions = [0]
states = [(m,l,b)]          # branch_map key (n, d, b) of each node's subtree
active = [True]             # nodes created but not yet expanded
depth = [0]
Successors = [[]]           # child indices of each node
attention_mask = torch.zeros(m,m).long()
parents = [-1]              # -1 marks the root
expand_lists = []
expand_branches = []
num_nodes = 1
while True:

    # One pass expands every node that is currently active; the children
    # created here are only marked active after the pass.
    expand = []
    expand_branch = []
    for i, act in enumerate(active):
        if act:
            # Row i of the mask = the parent's row plus the node itself, so
            # it ends up marking node i and all of its ancestors.
            if parents[i] != -1:
                attention_mask[i] = attention_mask[parents[i]]
            attention_mask[i][i] = 1
            expand.append(i)
            active[i] = False
            (x,y,z) = states[i]
            expand_branch.append(z)
            new_ids = range(num_nodes, num_nodes + z)
            positions.extend(new_ids)
            Successors[i].extend(new_ids)
            # Each new node needs its own (distinct) successor list.
            Successors.extend([[] for _ in range(z)])
            parents.extend([i] * z)
            depth.extend([depth[i] + 1] * z)
            states.extend(branch_map[(x,y,z)])
            assert len(branch_map[(x,y,z)]) == z
            num_nodes = num_nodes + z
    if not expand:
        break
    expand_lists.append(expand)
    expand_branches.append(expand_branch)
    active.extend([True] * sum(expand_branch))


# The expansion must have produced exactly the m nodes of the chosen tree.
assert num_nodes == m
assert len(positions) == m
assert len(depth) == m
grow_map = {
    "roots": expand_lists,
    "branches": expand_branches,
    "Successors":Successors,
    "mask": attention_mask,
    "depth": torch.LongTensor(depth),
    "size": num_nodes
}

path = config['dst']

torch.save(grow_map, path)