import torch 
import numpy as np

# Torch version 1.9.0
# Numpy version 1.19.5

def feature_connectivity_batched(adj, i):
    """
    Parallel computation of feature connectivity of each pixel in i to all other pixels in their respective images.

    Runs one dense Dijkstra per batch element, all batch elements advancing in lockstep:
    each of the N iterations settles one node per image and relaxes its outgoing edges.

    adj: [B * N * N] adjacency matrix (e.g. 8-connectivity), where B is the batch size and
         N = Height * Width of image. adj[b, u, v] is the edge cost from u to v; use inf for
         "no edge". Costs are assumed non-negative (required for Dijkstra to be exact).
    i:   [B] Batch of indices of pixels (list, array or tensor); one source node per image.

    Returns a [B * N] tensor (dtype of adj if floating point, else the default float dtype,
    on adj.device) holding the shortest-path cost from each source to every node; nodes not
    reachable from the source stay inf.
    """
    b, n = adj.shape[:-1]
    device = adj.device
    # Keep state in adj's precision so torch.where never sees mixed float dtypes
    # (torch 1.9 raises on mismatched dtypes instead of promoting).
    dtype = adj.dtype if adj.is_floating_point() else torch.get_default_dtype()

    # Index tensors live on adj's device so advanced indexing never mixes devices.
    b_arange = torch.arange(b, device=device)
    source = torch.as_tensor(i, dtype=torch.long, device=device)

    visited = torch.zeros((b, n), dtype=torch.bool, device=device)
    feature_connectivity = torch.full((b, n), float('inf'), dtype=dtype, device=device)
    feature_connectivity[b_arange, source] = 0  # Init dist of index node to 0

    cur_dist = torch.zeros(b, dtype=dtype, device=device)
    cur_idx = source
    for _ in range(n):
        visited[b_arange, cur_idx] = True
        # Shortest dist to current node, broadcast across all nodes, plus edge cost out of it.
        dists_from_cur = cur_dist.unsqueeze(1) + adj[b_arange, cur_idx]
        # Parallel Dijkstra relaxation (torch.where rather than minimum keeps NaN semantics unchanged).
        feature_connectivity = torch.where(dists_from_cur < feature_connectivity, dists_from_cur, feature_connectivity)
        # Exclude settled nodes, then pick the next closest unsettled node per image.
        # Once only unreachable nodes remain, cur_dist is inf and no further updates happen.
        candidates = feature_connectivity.masked_fill(visited, float('inf'))
        cur_dist, cur_idx = candidates.min(dim=1)
    return feature_connectivity


if __name__ == '__main__':
    # Demo: build two 3x3-image adjacency matrices, pick one source pixel per image,
    # and print the resulting feature connectivity vectors.
    print('In this demo, we give a sample code implementation that we use to compute feature connectivity in parallel for batches of images. Refer to Algorithm 1 lines 7-24, without clustering. \n')
    print('First, we define 8-connectivity adjacency matrices for 2 3x3 images, where each non-infinite value denotes the inverse feature similarity between pixel neighbors.\n')

    NO_EDGE = float('inf')  # marks pixel pairs that are not 8-neighbours
    first_image = [
        [NO_EDGE, 0.191, NO_EDGE, 0.073, 0.647, NO_EDGE, NO_EDGE, NO_EDGE, NO_EDGE],
        [0.191, NO_EDGE, 0.056, 0.118, 0.456, 0.105, NO_EDGE, NO_EDGE, NO_EDGE],
        [NO_EDGE, 0.056, NO_EDGE, NO_EDGE, 0.400, 0.161, NO_EDGE, NO_EDGE, NO_EDGE],
        [0.073, 0.118, NO_EDGE, NO_EDGE, 0.574, NO_EDGE, 0.413, 0.271, NO_EDGE],
        [0.647, 0.456, 0.400, 0.574, NO_EDGE, 0.560, 0.161, 0.303, 0.709],
        [NO_EDGE, 0.105, 0.161, NO_EDGE, 0.560, NO_EDGE, NO_EDGE, 0.257, 0.149],
        [NO_EDGE, NO_EDGE, NO_EDGE, 0.413, 0.161, NO_EDGE, NO_EDGE, 0.142, NO_EDGE],
        [NO_EDGE, NO_EDGE, NO_EDGE, 0.271, 0.303, 0.257, 0.142, NO_EDGE, 0.406],
        [NO_EDGE, NO_EDGE, NO_EDGE, NO_EDGE, 0.709, 0.149, NO_EDGE, 0.406, NO_EDGE],
    ]
    second_image = [
        [NO_EDGE, 0.255, NO_EDGE, 0.531, 0.268, NO_EDGE, NO_EDGE, NO_EDGE, NO_EDGE],
        [0.255, NO_EDGE, 0.578, 0.276, 0.014, 0.633, NO_EDGE, NO_EDGE, NO_EDGE],
        [NO_EDGE, 0.578, NO_EDGE, NO_EDGE, 0.592, 0.055, NO_EDGE, NO_EDGE, NO_EDGE],
        [0.531, 0.276, NO_EDGE, NO_EDGE, 0.262, NO_EDGE, 0.966, 0.040, NO_EDGE],
        [0.268, 0.014, 0.592, 0.262, NO_EDGE, 0.647, 0.704, 0.222, 0.423],
        [NO_EDGE, 0.633, 0.055, NO_EDGE, 0.647, NO_EDGE, NO_EDGE, 0.869, 0.224],
        [NO_EDGE, NO_EDGE, NO_EDGE, 0.966, 0.704, NO_EDGE, NO_EDGE, 0.926, NO_EDGE],
        [NO_EDGE, NO_EDGE, NO_EDGE, 0.040, 0.222, 0.869, 0.926, NO_EDGE, 0.644],
        [NO_EDGE, NO_EDGE, NO_EDGE, NO_EDGE, 0.423, 0.224, NO_EDGE, 0.644, NO_EDGE],
    ]
    batched_adj = torch.tensor([first_image, second_image])

    print('8-connectivity Adjacency Matrix:')
    print(batched_adj.numpy(), '\n')

    print('Use two indices to denote the sampled pixel for each image')
    sampled_pixels = [4, 7]
    print(sampled_pixels, '\n')

    connectivity = feature_connectivity_batched(batched_adj, sampled_pixels)
    print('Feature Connectivity vectors for each sampled pixel respectively:')
    print(np.round(connectivity.numpy(), 3))

    # Expected output for sampled pixels [4, 7]:
    # [[0.647 0.456 0.4   0.574 0.    0.56  0.161 0.303 0.709]
    # [0.49  0.236 0.814 0.04  0.222 0.868 0.926 0.    0.644]]


