import torch
import torch.nn.functional as F
import time

# single_batch = torch.randn(15, 64, 64, 320).cuda()
# window_size = 5
# k = 4
# start_time = time.time()

# coord_result = torch.zeros(15, 64, 64, 15, k, 2)
# pad_size = window_size//2
# padded_batch = F.pad(single_batch, (0, 0, pad_size, pad_size, pad_size, pad_size, 1, 1), value=float('-inf'))
# saved_topk_position = []
# Fr, H, W, d = padded_batch.shape

# a=torch.randn(15*64*64, 320, dtype=torch.float16).cuda()
# b=torch.randn(320, 15*64*64, dtype=torch.float16).cuda()
# hidden_states=torch.randn(15,64,64,320)
def topk_positions(features, k, chunk_size=512):
    """Locate, for every query location, its k most similar locations in each frame.

    Similarity is the raw dot product between feature vectors (no normalization).
    Every query is compared against every location of every frame, itself included.

    Args:
        features: Tensor of shape [F, H, W, d].
        k: Number of positions kept per (query, frame) pair; must be <= H * W.
        chunk_size: Number of queries scored per matmul. Bounds the temporary
            similarity buffer to chunk_size * F * H * W elements.

    Returns:
        Long tensor of shape [F, H, W, F, k, 2]. Entry [f, h, w, f2, j] holds the
        (row, col) inside frame f2 of the j-th most similar location to query
        (f, h, w), ordered by descending similarity.
    """
    num_frames, height, width, dim = features.shape
    flat = features.reshape(-1, dim)  # [F*H*W, d]
    flat_t = flat.t()  # [d, F*H*W]
    num_queries = flat.shape[0]

    positions = torch.empty(
        num_queries, num_frames, k, 2, dtype=torch.long, device=features.device
    )
    # Score queries in chunks rather than one at a time: a single matmul per chunk
    # replaces F*H*W tiny vector-matrix products while keeping peak memory bounded.
    for start in range(0, num_queries, chunk_size):
        stop = min(start + chunk_size, num_queries)
        sim = flat[start:stop] @ flat_t  # [c, F*H*W]
        sim = sim.view(stop - start, num_frames, height * width)  # [c, F, H*W]
        _, indices = sim.topk(k, dim=-1)  # [c, F, k], flat index within one frame
        # Convert the flat in-frame index back to (row, col).
        positions[start:stop] = torch.stack((indices // width, indices % width), dim=-1)
    return positions.view(num_frames, height, width, num_frames, k, 2)


def gather_topk_features(features, positions):
    """Collect the feature vectors at the positions produced by topk_positions.

    Args:
        features: Tensor of shape [F, H, W, d].
        positions: Long tensor of shape [F, H, W, F, k, 2], as returned by
            topk_positions(features, k).

    Returns:
        Tensor of shape [F, H, W, F, k, d]. Entry [f, h, w, f2, j] is
        features[f2, row, col] where (row, col) = positions[f, h, w, f2, j].
    """
    num_frames, height, width, dim = features.shape
    k = positions.shape[-2]
    num_queries = num_frames * height * width
    flat_pos = positions.reshape(num_queries, num_frames, k, 2)  # [F*H*W, F, k, 2]
    # Each position was found inside frame f2 (axis 1 of flat_pos), so the frame
    # index must vary along that axis -- not along the query axis.
    frame_idx = torch.arange(num_frames, device=features.device).view(1, num_frames, 1)
    gathered = features[frame_idx, flat_pos[..., 0], flat_pos[..., 1]]  # [F*H*W, F, k, d]
    return gathered.view(num_frames, height, width, num_frames, k, dim)


if __name__ == "__main__":
    # Fall back to CPU when no GPU is present. Note the full-size result below is
    # ~5.9 GB in float32, so the CPU path is only practical for smaller inputs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    single_batch = torch.randn(15, 64, 64, 320, device=device)
    k = 5  # Top k positions to keep per frame
    positions = topk_positions(single_batch, k)  # [15, 64, 64, 15, 5, 2]
    result = gather_topk_features(single_batch, positions)  # [15, 64, 64, 15, 5, 320]
    print("result:", result.shape)
# print("positions:",positions.shape)

        
        
    
        

# for cur_f in range(1, Fr-1):
#     print(cur_f)
#     for cur_h in range(pad_size, H-pad_size):
#         for cur_w in range(pad_size, W-pad_size):
# #             print(cur_h)
#             neighbor = padded_batch[cur_f-1:cur_f+2, cur_h-pad_size:cur_h+pad_size+1, cur_w-pad_size:cur_w+pad_size+1] #[3, 5, 5, 320]
#             # neighbor = neighbor.reshape(-1, d)  #[75, 320]
#             cur_feature = padded_batch[cur_f, cur_h, cur_w].reshape(d, 1) #[320, 1]
            
#             similarity = (neighbor @ cur_feature).view(3, window_size, window_size)
#             nan_mask = similarity.isnan()
#             similarity[nan_mask]=float('-inf')
            
#             values, indices = torch.topk(similarity.view(3, -1), k, dim=1)
            
#             x_coords = torch.clamp(indices % window_size + cur_w - pad_size, pad_size, W-pad_size-1)# 列索引
#             y_coords = torch.clamp(indices // window_size + cur_h - pad_size, pad_size, H-pad_size-1) # 行索引
#             # 合并 x 和 y 坐标以得到最终的位置 tensor
#             coords = torch.stack((y_coords, x_coords), dim=2).unsqueeze(0) #[1,3,4,2]
            

#             for previous_f in range(cur_f-1, 1, -1):
#                 new_center_y, new_center_x = coords[0,0,0,:]
#                 #第i-2帧中的neighbor
#                 pre_neighbor = padded_batch[previous_f-1, new_center_x-pad_size:new_center_x+pad_size+1, new_center_y-pad_size:new_center_y+pad_size+1]#.reshape(-1, d)
#                 #第i-1帧中的center feature
#                 pre_cur_feature = padded_batch[previous_f, new_center_x, new_center_y].reshape(d, 1)
                
#                 similarity = (pre_neighbor @ pre_cur_feature).view(1, window_size, window_size)

#                 values, indices = torch.topk(similarity.view(1, -1), k, dim=1)
#                 x_coords = torch.clamp(indices % window_size + cur_w - pad_size, pad_size, W-pad_size-1)# 列索引
#                 y_coords = torch.clamp(indices // window_size + cur_h - pad_size, pad_size, H-pad_size-1) # 行索引
#                 coords_pre = torch.stack((y_coords, x_coords), dim=2).unsqueeze(0)                
#                 coords = torch.cat((coords_pre, coords), dim=1)


#             for future_f in range(cur_f+1, Fr-2):
#                 new_center_y, new_center_x = coords[0,-1,0,:]
#                 #第i+2帧中的neighbor
#                 fut_neighbor = padded_batch[future_f+1, new_center_x-pad_size:new_center_x+pad_size+1, new_center_y-pad_size:new_center_y+pad_size+1]#.reshape(-1, d)
#                 #第i+1帧中的center feature
#                 fut_cur_feature = padded_batch[future_f, new_center_x, new_center_y].reshape(d, 1)
                
#                 similarity = (fut_neighbor @ fut_cur_feature).view(1, window_size, window_size)

#                 values, indices = torch.topk(similarity.view(1, -1), k, dim=1)
#                 x_coords = torch.clamp(indices % window_size + cur_w - pad_size, pad_size, W-pad_size-1)# 列索引
#                 y_coords = torch.clamp(indices // window_size + cur_h - pad_size, pad_size, H-pad_size-1) # 行索引
#                 coords_fut = torch.stack((y_coords, x_coords), dim=2).unsqueeze(0)
#                 coords = torch.cat((coords, coords_fut), dim=1)
            
#             if cur_f == 1:
#                 coords = coords[:,1:,:,:]
#             if cur_f == Fr-2:
#                 coords = coords[:,:-1,:,:]
            
#             coord_result[cur_f-1, cur_w-pad_size, cur_h-pad_size] = coords
#             # print(coords.shape)
#             #1,f,k,2
            
# end_time = time.time()

# # 计算并打印运行时间
# elapsed_time = end_time - start_time
# print(f"代码运行时间：{elapsed_time}秒")
