import numpy as np

def main(grid):
    """Return the value of the first cell whose four neighbours are all non-zero.

    Cells are scanned in row-major order (the order ``np.argwhere`` yields).
    A qualifying cell must itself be non-zero. Its up, down, left and right
    neighbours must all exist (so border cells never qualify) and all be
    non-zero.

    Args:
        grid: A 2-D array-like (ndarray or nested lists) of values.

    Returns:
        ``[[value]]`` for the first qualifying cell, where ``value`` is taken
        from the array form of ``grid``. Returns ``None`` if no cell qualifies.
    """
    # Normalise list-of-lists input. Without this, `grid != 0` on a list is a
    # single bool rather than an element-wise mask, and `grid[i, j]` indexing
    # fails. For an ndarray argument this returns the same object.
    grid = np.asarray(grid)
    n_rows, n_cols = grid.shape[0], grid.shape[1]

    # Find the indices of non-zero elements, in row-major order.
    non_zero_indices = np.argwhere(grid != 0)

    for i, j in non_zero_indices:
        # A border cell is missing at least one neighbour, so it cannot qualify.
        if not (0 < i < n_rows - 1 and 0 < j < n_cols - 1):
            continue
        if (grid[i - 1, j] != 0 and grid[i + 1, j] != 0
                and grid[i, j - 1] != 0 and grid[i, j + 1] != 0):
            return [[grid[i, j]]]

    # No surrounded element was found.
    return None