import numpy as np
# Named colour indices (0-9) used as cell values in the grids handled below.
black, blue, red, green, yellow, grey, pink, orange, teal, maroon = range(10)


def main(input_grid: np.ndarray) -> np.ndarray:
    """Paint black cells inside the bounding box of all non-black cells red.

    The bounding box is the smallest axis-aligned rectangle containing every
    cell whose value is not ``black``. Within that rectangle every ``black``
    cell becomes ``red``; all other cells (inside or outside the box) are
    left as they are.

    Args:
        input_grid: Grid of colour indices. Indexed as (row, column), so a
            2-D array is assumed. The input is not modified.

    Returns:
        A new array of the same shape. If the grid has no non-black cells,
        there is no bounding box and an unchanged copy is returned.
    """
    output_grid = input_grid.copy()

    rows, cols = np.nonzero(input_grid != black)
    if rows.size == 0:
        # All-black grid: no bounding box exists, and calling min/max on the
        # empty index arrays would raise ValueError.
        return output_grid

    up, down = rows.min(), rows.max()
    left, right = cols.min(), cols.max()

    # Basic slicing returns a view, so the masked assignment on `box`
    # writes straight through to `output_grid`.
    box = output_grid[up:down + 1, left:right + 1]
    box[box == black] = red

    return output_grid