import collections
import numpy as onp
import jax.numpy as jnp


def chain_sum(length, num_actions, num_targets):
    """Build one chain of cumulative-sum predictions per target.

    Nodes ``0..num_targets-1`` are targets. For each target ``root`` a chain of
    ``length`` prediction nodes is appended. The first node of a chain is
    connected to ``root``. Every later node is connected both to the previous
    node of its chain and to ``root`` (all weights 1).

    Args:
        length: number of prediction nodes per chain.
        num_actions: number of rows of the returned mask.
        num_targets: number of target nodes.

    Returns:
        num_preds: ``num_targets * length``.
        mat: square matrix of side ``num_targets + num_preds``.
        masks: all-ones array of shape ``(num_actions, num_preds)``.
        dep: position of each prediction inside its own chain (0-based),
            i.e. ``[0, 1, ..., length - 1]`` repeated once per target.
    """
    num_preds = num_targets * length
    mat = onp.zeros((num_targets + num_preds, num_targets + num_preds))
    node_idx = num_targets
    for root in range(num_targets):
        prev = root
        for _ in range(length):
            mat[node_idx, prev] = 1
            mat[node_idx, root] = 1
            prev = node_idx
            node_idx += 1
    masks = onp.ones((num_actions, num_preds))
    # Depth restarts at 0 for every target's chain, like the per-target depth
    # returned by the other builders; a plain arange would keep counting
    # across chains when num_targets > 1.
    dep = onp.tile(onp.arange(length), num_targets)
    return num_preds, mat, masks, dep


def discounted_return(discount_factors, num_actions, num_targets):
    """Build one discounted-return prediction per (target, discount factor).

    Prediction rows are laid out target-major: the prediction for target ``t``
    and the ``k``-th discount factor sits at row
    ``num_targets + t * len(discount_factors) + k``. Each such row has its
    discount factor on the diagonal and a 1 in the column of its target.

    Returns:
        A 3-tuple ``(num_preds, mat, masks)``. ``mat`` is square with side
        ``num_targets + num_preds``; ``masks`` is an all-ones
        ``(num_actions, num_preds)`` array.

    NOTE(review): unlike the other builders in this module this returns no
    fourth (depth) element, so callers that unpack four values will fail.
    """
    num_gammas = len(discount_factors)
    num_preds = num_gammas * num_targets
    side = num_targets + num_preds
    mat = onp.zeros((side, side))
    for target in range(num_targets):
        first_row = num_targets + target * num_gammas
        for offset, gamma in enumerate(discount_factors):
            row = first_row + offset
            mat[row, row] = gamma
            mat[row, target] = 1
    masks = onp.ones((num_actions, num_preds))
    return num_preds, mat, masks


def bfs_planning(num_actions, depth, num_targets, gamma):
    """Build breadth-first, action-conditioned prediction trees (one per target).

    Node layout:
      * ``0 .. num_targets-1``: targets (no parent).
      * ``num_targets .. 2*num_targets-1``: one level-0 prediction per target.
        Each is its own parent (diagonal weight ``gamma``), has weight 1 on its
        target, and is unmasked (mask 1 for every action).
      * remaining nodes, created breadth-first: every node whose level is below
        ``depth`` gets one child per action. The child has weight ``gamma`` on
        its parent and weight 1 on the tree's target, and is masked to that
        action only.

    Returns:
        ``(num_preds, mat, masks, dep)``: number of prediction nodes, the
        square connection matrix, the ``(num_actions, num_preds)`` mask and
        the level of every prediction node.
    """
    first_level = list(range(num_targets, 2 * num_targets))
    levels = [0] * (2 * num_targets)
    parents = [None] * num_targets + first_level  # level-0 nodes point at themselves
    roots = list(range(num_targets)) * 2
    actions = [None] * (2 * num_targets)

    frontier = collections.deque(first_level)
    while frontier:
        node = frontier.popleft()
        if levels[node] == depth:
            continue
        for action in range(num_actions):
            child = len(levels)
            levels.append(levels[node] + 1)
            parents.append(node)
            roots.append(roots[node])
            actions.append(action)
            frontier.append(child)

    total_nodes = len(levels)
    num_preds = total_nodes - num_targets
    mat = onp.zeros((total_nodes, total_nodes))
    masks = onp.zeros((num_actions, num_preds))
    # Only targets lack a parent, so every non-target node is a prediction.
    for node in range(num_targets, total_nodes):
        column = node - num_targets
        mat[node, parents[node]] = gamma
        mat[node, roots[node]] = 1.
        if actions[node] is None:
            masks[:, column] = 1.
        else:
            masks[actions[node], column] = 1.
    return num_preds, mat, masks, onp.array(levels[num_targets:])


def random_planning(seed, depth, repeat, gamma, num_targets, num_actions):
    """Build a random, action-conditioned planning graph for a single target.

    Node 0 is the target. Node 1 is a level-0 prediction with weight ``gamma``
    on itself and weight 1 on the target, unmasked for every action. Then, for
    each of ``depth`` levels and each action ``a``, up to ``repeat`` distinct
    parents are sampled without replacement from the previous level. Each
    sampled parent gets a new child with weight ``gamma`` on that parent and
    weight 1 on the target, masked to action ``a`` only.

    Args:
        seed: seed for ``numpy.random.RandomState``.
        depth: number of action-conditioned levels.
        repeat: maximum number of children created per (level, action).
        gamma: weight of the self / parent connection.
        num_targets: must be 1.
        num_actions: number of actions (rows of ``masks``).

    Returns:
        ``(num_preds, mat, masks, dep)``: the number of prediction nodes, the
        ``(num_targets + num_preds)`` square connection matrix, the
        ``(num_actions, num_preds)`` mask and the int32 level of each
        prediction (0 for node 1).
    """
    assert num_targets == 1
    # Upper bound: assumes every (level, action) pair yields `repeat` children.
    max_nodes = num_targets * (1 + 1 + depth * repeat * num_actions)
    mat = onp.zeros((max_nodes, max_nodes))
    masks = onp.zeros((num_actions, max_nodes))
    dep = onp.zeros((max_nodes,), dtype=onp.int32)

    rng = onp.random.RandomState(seed=seed)
    idx = num_targets
    for f in range(num_targets):
        nodes_d = []
        mat[idx, idx] = gamma
        mat[idx, f] = 1.
        masks[:, idx] = 1
        dep[idx] = 0
        nodes_d.append(idx)
        idx += 1
        for d in range(depth):
            nodes_dp1 = []
            for a in range(num_actions):
                # A level may hold fewer than `repeat` nodes (always true for
                # the first level when repeat > 1).
                k = min(len(nodes_d), repeat)
                parents = rng.choice(nodes_d, size=k, replace=False)
                for p in parents:
                    mat[idx, p] = gamma
                    mat[idx, f] = 1
                    masks[a, idx] = 1
                    dep[idx] = d + 1
                    nodes_dp1.append(idx)
                    idx += 1
            nodes_d = nodes_dp1
    # Keep only the nodes actually created. Otherwise the unused tail of the
    # upper-bound allocation would surface as all-zero placeholder predictions
    # (zero row, zero mask, level 0).
    total_nodes = idx
    num_preds = total_nodes - num_targets
    mat = mat[:total_nodes, :total_nodes]
    masks = masks[:, num_targets:total_nodes]
    dep = dep[num_targets:total_nodes]
    return num_preds, mat, masks, dep


def cond_tree_sum(depth, num_actions, num_targets, balance_by_depth):
    """Build full action-conditioned trees rooted at each target.

    Nodes ``0..num_targets-1`` are the targets (tree roots, level 0). They are
    expanded breadth-first: every node whose level is below ``depth`` gets one
    child per action. A child has weight 1 on its parent and weight 1 on its
    tree's root (a single entry when the parent is the root), and its mask is
    1 only for the action that created it.

    Args:
        depth: number of levels below the targets.
        num_actions: branching factor / number of mask rows.
        num_targets: number of targets (trees).
        balance_by_depth: if true, rescale ``masks`` as described inline.

    Returns:
        ``(num_preds, mat, masks, dep)`` where ``dep`` is the 0-based level of
        each prediction (children of the targets have level 0).
    """
    levels = [0] * num_targets
    parents = [None] * num_targets
    roots = list(range(num_targets))
    actions = [None] * num_targets

    pending = collections.deque(range(num_targets))
    while pending:
        node = pending.popleft()
        if levels[node] == depth:
            continue
        for action in range(num_actions):
            pending.append(len(levels))  # id the child is about to receive
            levels.append(levels[node] + 1)
            parents.append(node)
            roots.append(roots[node])
            actions.append(action)

    total_nodes = len(levels)
    num_preds = total_nodes - num_targets
    mat = onp.zeros((total_nodes, total_nodes))
    masks = onp.zeros((num_actions, num_preds))
    # Every non-target node has a parent and an action.
    for node in range(num_targets, total_nodes):
        mat[node, parents[node]] = 1.
        mat[node, roots[node]] = 1.
        masks[actions[node], node - num_targets] = 1.

    if balance_by_depth:
        pred_levels = levels[num_targets:]
        # 1. Divide each prediction's coefficients by the number of
        #    predictions that share its level.
        per_level = collections.Counter(pred_levels)
        for column, level in enumerate(pred_levels):
            masks[:, column] /= per_level[level]
        # 2. Divide the global coefficient by depth.
        #    NOTE(review): step 3 is scale-invariant, so this division has no
        #    effect on the result beyond floating-point rounding.
        masks /= depth
        # 3. Rescale to unit L2 norm (divide by sqrt(sum(w_i^2))).
        masks /= onp.sqrt(onp.sum(onp.square(masks)))
    return num_preds, mat, masks, onp.array(levels[num_targets:]) - 1


def cond_deep_tree_sum(seed, depth, num_actions, num_targets, balance_by_depth):
    """Build action-conditioned trees where only one random child is expanded.

    Same node/edge/mask conventions as a full conditional tree: each child has
    weight 1 on its parent and on its tree's root, masked to the action that
    created it. The difference is that after a node is expanded into
    ``num_actions`` children, only one child (picked with
    ``RandomState(seed).choice``) is queued for further expansion. Each level
    therefore holds ``num_targets * num_actions`` predictions.

    Args:
        seed: seed for ``numpy.random.RandomState``.
        depth: number of levels below the targets.
        num_actions: children created per expansion / number of mask rows.
        num_targets: number of targets (trees).
        balance_by_depth: if true, rescale ``masks`` as described inline.

    Returns:
        ``(num_preds, mat, masks, dep)`` where ``dep`` is the 0-based level of
        each prediction.
    """
    rng = onp.random.RandomState(seed)
    levels = [0] * num_targets
    parents = [None] * num_targets
    roots = list(range(num_targets))
    actions = [None] * num_targets

    pending = collections.deque(range(num_targets))
    while pending:
        node = pending.popleft()
        if levels[node] == depth:
            continue
        first_child = len(levels)
        for action in range(num_actions):
            levels.append(levels[node] + 1)
            parents.append(node)
            roots.append(roots[node])
            actions.append(action)
        children = list(range(first_child, len(levels)))
        # Continue the tree through a single randomly chosen child.
        pending.append(rng.choice(children))

    total_nodes = len(levels)
    num_preds = total_nodes - num_targets
    mat = onp.zeros((total_nodes, total_nodes))
    masks = onp.zeros((num_actions, num_preds))
    for node in range(num_targets, total_nodes):
        mat[node, parents[node]] = 1.
        mat[node, roots[node]] = 1.
        masks[actions[node], node - num_targets] = 1.

    if balance_by_depth:
        pred_levels = levels[num_targets:]
        # 1. Divide each prediction's coefficients by the number of
        #    predictions that share its level.
        per_level = collections.Counter(pred_levels)
        for column, level in enumerate(pred_levels):
            masks[:, column] /= per_level[level]
        # 2. Divide the global coefficient by depth.
        #    NOTE(review): step 3 is scale-invariant, so this division has no
        #    effect on the result beyond floating-point rounding.
        masks /= depth
        # 3. Rescale to unit L2 norm (divide by sqrt(sum(w_i^2))).
        masks /= onp.sqrt(onp.sum(onp.square(masks)))
    return num_preds, mat, masks, onp.array(levels[num_targets:]) - 1


def open_loop_planning(seed, depth, repeat, discount_factors, num_targets, num_actions):
    """Build a random action-conditioned graph independently for each target.

    For each target ``t``, the initial layer is ``t`` itself plus one level-0
    prediction per discount factor. Each of those predictions has weight
    ``gamma`` on itself, weight 1 on ``t``, and mask 1 for all actions. Then, for
    each of ``depth`` levels and each action, exactly ``repeat`` distinct
    parents are drawn without replacement from the previous layer of the
    same target. Each drawn parent gets a child with weight 1 on that parent and
    weight 1 on ``t``, masked to that action. ``RandomState.choice`` raises if
    ``repeat`` exceeds the layer size.

    Returns:
        ``(num_preds, mat, masks, dep)``: number of prediction nodes, square
        connection matrix, ``(num_actions, num_preds)`` mask and int32 level
        of each prediction.
    """
    per_target = 1 + len(discount_factors) + depth * repeat * num_actions
    total_nodes = num_targets * per_target
    mat = onp.zeros((total_nodes, total_nodes))
    masks = onp.zeros((num_actions, total_nodes))
    dep = onp.zeros((total_nodes,), dtype=onp.int32)

    rng = onp.random.RandomState(seed=seed)
    cursor = num_targets
    for target in range(num_targets):
        layer = [target]
        for gamma in discount_factors:
            mat[cursor, cursor] = gamma
            mat[cursor, target] = 1.
            masks[:, cursor] = 1
            layer.append(cursor)  # level stays 0 (dep is zero-initialised)
            cursor += 1
        for level in range(1, depth + 1):
            next_layer = []
            for action in range(num_actions):
                for parent in rng.choice(layer, size=repeat, replace=False):
                    mat[cursor, parent] = 1
                    mat[cursor, target] = 1
                    masks[action, cursor] = 1
                    dep[cursor] = level
                    next_layer.append(cursor)
                    cursor += 1
            layer = next_layer

    return total_nodes - num_targets, mat, masks[:, num_targets:], dep[num_targets:]


def mixed_open_loop_planning(seed, depth, repeat, discount_factors, num_targets, num_actions):
    """Build a random action-conditioned graph whose layers mix all targets.

    The initial layer pools every target together with its level-0
    predictions. Each of those predictions has weight ``gamma`` on itself,
    weight 1 on its target, and mask 1 for all actions. For each of ``depth``
    levels and each action, up to ``repeat`` distinct parents are drawn without
    replacement from the previous (shared) layer. Each drawn parent gets a child
    with weight 1 on that parent, masked to that action. When the parent is a
    prediction rather than a target, the child also gets weight 1 on a
    uniformly chosen target.

    Args:
        seed: seed for ``numpy.random.RandomState``.
        depth: number of action-conditioned levels.
        repeat: maximum number of children per (level, action).
        discount_factors: per-target level-0 self-connection weights.
        num_targets: number of targets.
        num_actions: number of actions (rows of ``masks``).

    Returns:
        ``(num_preds, mat, masks, dep)``: number of prediction nodes, square
        connection matrix, ``(num_actions, num_preds)`` mask and int32 level
        of each prediction.
    """
    # Upper bound: assumes every (level, action) pair yields `repeat` children.
    max_nodes = num_targets * (1 + len(discount_factors)) + depth * repeat * num_actions
    mat = onp.zeros((max_nodes, max_nodes))
    masks = onp.zeros((num_actions, max_nodes))
    dep = onp.zeros((max_nodes,), dtype=onp.int32)

    rng = onp.random.RandomState(seed=seed)
    idx = num_targets
    nodes_d = []
    for f in range(num_targets):
        nodes_d.append(f)
        for gamma in discount_factors:
            mat[idx, idx] = gamma
            mat[idx, f] = 1.
            masks[:, idx] = 1
            dep[idx] = 0
            nodes_d.append(idx)
            idx += 1
    for d in range(depth):
        nodes_dp1 = []
        for a in range(num_actions):
            # A layer may hold fewer than `repeat` nodes; sample what exists.
            _repeat = min(repeat, len(nodes_d))
            parents = rng.choice(nodes_d, size=_repeat, replace=False)
            for p in parents:
                mat[idx, p] = 1
                if p >= num_targets:
                    # Parent is a prediction: also tie the child to a random target.
                    f = rng.choice(num_targets)
                    mat[idx, f] = 1
                masks[a, idx] = 1
                dep[idx] = d + 1
                nodes_dp1.append(idx)
                idx += 1
        nodes_d = nodes_dp1
    # Keep only the nodes actually created, so a short layer does not leave
    # all-zero placeholder predictions (zero row, zero mask, level 0) behind.
    total_nodes = idx
    num_preds = total_nodes - num_targets
    mat = mat[:total_nodes, :total_nodes]
    masks = masks[:, num_targets:total_nodes]
    dep = dep[num_targets:total_nodes]
    return num_preds, mat, masks, dep


def mixed_on_policy(seed, depth, repeat, discount_factors, num_targets, num_actions):
    """Build a random graph over all targets whose predictions use every action.

    The initial layer holds one level-0 prediction per (target, discount
    factor). Each has weight ``gamma`` on itself and weight 1 on its target.
    For each of ``depth`` levels and each of ``num_actions`` repetitions,
    exactly ``repeat`` distinct parents are drawn without replacement from the
    previous layer (asserted to be large enough). Each drawn parent gets a child
    with weight 1 on that parent and weight 1 on a uniformly chosen target.
    Every prediction's mask is 1 for all actions.

    Returns:
        ``(num_preds, mat, masks, dep)``: number of prediction nodes, square
        connection matrix, all-ones ``(num_actions, num_preds)`` mask and the
        int32 level of each prediction.
    """
    total_nodes = num_targets * (1 + len(discount_factors)) + depth * repeat * num_actions
    mat = onp.zeros((total_nodes, total_nodes))
    masks = onp.zeros((num_actions, total_nodes))
    dep = onp.zeros((total_nodes,), dtype=onp.int32)
    rng = onp.random.RandomState(seed=seed)

    cursor = num_targets
    layer = []
    for target in range(num_targets):
        for gamma in discount_factors:
            mat[cursor, cursor] = gamma
            mat[cursor, target] = 1.
            masks[:, cursor] = 1
            layer.append(cursor)  # level stays 0 (dep is zero-initialised)
            cursor += 1

    for level in range(1, depth + 1):
        next_layer = []
        for _ in range(num_actions):
            assert repeat <= len(layer)
            for parent in rng.choice(layer, size=repeat, replace=False):
                mat[cursor, parent] = 1
                # `layer` never contains target ids, so every parent is a
                # prediction node and the child is always tied to a random target.
                mat[cursor, rng.choice(num_targets)] = 1
                masks[:, cursor] = 1
                dep[cursor] = level
                next_layer.append(cursor)
                cursor += 1
        layer = next_layer

    return total_nodes - num_targets, mat, masks[:, num_targets:], dep[num_targets:]


# Registry of graph builders, keyed by each builder's function name.
FACTORY = {
    builder.__name__: builder
    for builder in (
        chain_sum,
        discounted_return,
        cond_tree_sum,
        cond_deep_tree_sum,
        open_loop_planning,
        mixed_open_loop_planning,
        mixed_on_policy,
        bfs_planning,
        random_planning,
    )
}


if __name__ == '__main__':
    # Small demo: build a mixed on-policy graph and print each returned part
    # (num_preds, matrix, masks, depths) in order.
    demo_kwargs = dict(
        seed=1,
        depth=2,
        repeat=2,
        discount_factors=(0.95,),
        num_targets=2,
        num_actions=2,
    )
    for part in mixed_on_policy(**demo_kwargs):
        print(part)
