import numpy as np

# Colour palette: each name is bound to its integer cell value (0 through 9).
(
    black,
    blue,
    red,
    green,
    yellow,
    grey,
    pink,
    orange,
    teal,
    maroon,
) = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Return a copy of *input_grid* with isolated non-black cells set to black.

    A cell is isolated when it is not black and none of its neighbours in
    the surrounding 3x3 window (horizontal, vertical and diagonal) that lie
    inside the grid is non-black. Positions outside the grid are ignored,
    so a non-black cell with no in-bounds neighbours at all is also cleared.

    The input array is not modified. Unpacking the shape into two values
    means a grid that is not two-dimensional raises ``ValueError``.
    """
    rows, cols = input_grid.shape
    coloured = input_grid != black

    # Surround the mask with a one-cell border of False so that every shifted
    # window below stays in bounds and off-grid positions never count as a
    # coloured neighbour.
    framed = np.zeros((rows + 2, cols + 2), dtype=bool)
    framed[1:rows + 1, 1:cols + 1] = coloured

    # OR together the eight shifted views of the mask; the centre offset is
    # skipped because a cell is not its own neighbour.
    has_coloured_neighbour = np.zeros((rows, cols), dtype=bool)
    for row_shift in (-1, 0, 1):
        for col_shift in (-1, 0, 1):
            if row_shift == 0 and col_shift == 0:
                continue
            top = 1 + row_shift
            left = 1 + col_shift
            has_coloured_neighbour |= framed[top:top + rows, left:left + cols]

    # The decision is taken entirely from the input mask and applied to a
    # separate copy, so clearing one cell never affects another cell's test.
    result = np.copy(input_grid)
    result[coloured & ~has_coloured_neighbour] = black
    return result


