# Modified from the original BEVFormer code
import torch
from torch import nn

from mmdet3d.registry import MODELS
from typing import Dict, List, Optional, Tuple

@MODELS.register_module()
class BEVFormer(nn.Module):
    """BEV feature extractor built around a registry-configured transformer.

    The module owns a learned grid of ``bev_h * bev_w`` query embeddings,
    optionally builds a transformer and a positional encoding through
    ``MODELS.build``, and in :meth:`forward` returns the transformer's BEV
    features reshaped into a ``(bs, C, bev_h, bev_w)`` map.

    Args:
        *args: Accepted and ignored.
        pc_range: Indexable range description. Only indices 0, 1, 3 and 4
            are read: ``real_w = pc_range[3] - pc_range[0]`` and
            ``real_h = pc_range[4] - pc_range[1]``.
            NOTE(review): presumably
            ``(x_min, y_min, z_min, x_max, y_max, z_max)`` -- confirm.
        embed_dims: Width of each BEV query embedding.
        bev_h: Number of rows of the BEV query grid.
        bev_w: Number of columns of the BEV query grid.
        transformer: Config dict passed to ``MODELS.build``; ``None`` leaves
            ``self.transformer`` unset (``None``).
        positional_encoding: Config dict passed to ``MODELS.build``; ``None``
            leaves ``self.positional_encoding`` unset (``None``).
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 *args,
                 pc_range,
                 embed_dims,
                 bev_h=180,
                 bev_w=180,
                 transformer: Optional[dict] = None,
                 positional_encoding: Optional[dict] = None,
                 **kwargs):
        super().__init__()
        self.pc_range = pc_range
        self.bev_h = bev_h
        self.bev_w = bev_w
        self.embed_dims = embed_dims
        # Extent covered by the BEV grid along each axis, in the units of
        # pc_range; forward() divides these by the grid size.
        self.real_w = self.pc_range[3] - self.pc_range[0]
        self.real_h = self.pc_range[4] - self.pc_range[1]

        self.transformer = (
            MODELS.build(transformer) if transformer is not None else None)

        self.positional_encoding = (
            MODELS.build(positional_encoding)
            if positional_encoding is not None else None)

        # One learned query vector per BEV cell.
        self.bev_embedding = nn.Embedding(
            self.bev_h * self.bev_w, self.embed_dims)

    def init_weights(self) -> None:
        """Initialise the transformer weights, if a transformer was built."""
        # ``transformer`` is optional in ``__init__``; skip instead of
        # failing with AttributeError on ``None``.
        if self.transformer is not None:
            self.transformer.init_weights()

    # Temporal attention from the original BEVFormer has been removed, so no
    # previous-frame BEV is taken as input here.
    def forward(self,
                img,
                points,
                lidar2image,
                camera_intrinsics,
                camera2lidar,
                img_aug_matrix,
                lidar_aug_matrix,
                metas,
                **kwargs) -> torch.Tensor:
        """Compute a BEV feature map from image features.

        Both ``self.transformer`` and ``self.positional_encoding`` must have
        been configured at construction time; they are called
        unconditionally.

        Args:
            img: Indexable collection whose first element is a 5-D tensor.
                The first dimension of ``img[0]`` is used as the batch size
                and its dtype is used for the queries, mask and positional
                encoding. Passed as-is to the transformer.
                NOTE(review): presumably multi-level camera features of
                shape ``(B, N, C, H, W)`` -- confirm against the caller.
            points: Unused by this module.
            lidar2image: Forwarded unchanged to
                ``transformer.get_bev_features``.
            camera_intrinsics: Unused by this module.
            camera2lidar: Unused by this module.
            img_aug_matrix: Unused by this module.
            lidar_aug_matrix: Forwarded unchanged to
                ``transformer.get_bev_features``.
            metas: Forwarded unchanged to ``transformer.get_bev_features``.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor: The transformer output with its last two dimensions
            swapped and reshaped to ``(bs, -1, bev_h, bev_w)``.
            NOTE(review): this assumes ``get_bev_features`` returns
            ``(bs, bev_h * bev_w, embed_dims)`` -- confirm.
        """
        # The 5-way unpack also rejects an ``img[0]`` that is not 5-D; only
        # the batch size is needed afterwards.
        bs, _, _, _, _ = img[0].shape
        dtype = img[0].dtype
        bev_queries = self.bev_embedding.weight.to(dtype)

        # All-zero mask over the BEV grid, fed to the positional encoding.
        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                               device=bev_queries.device,
                               dtype=dtype)
        bev_pos = self.positional_encoding(bev_mask).to(dtype)

        bev_embed = self.transformer.get_bev_features(
            img,
            bev_queries,
            self.bev_h,
            self.bev_w,
            # Size of one BEV cell as (real_h / bev_h, real_w / bev_w).
            grid_length=(self.real_h / self.bev_h,
                         self.real_w / self.bev_w),
            bev_pos=bev_pos,
            lidar2image=lidar2image,
            lidar_aug_matrix=lidar_aug_matrix,
            metas=metas,
        )

        return bev_embed.permute(0, 2, 1).reshape(
            bs, -1, self.bev_h, self.bev_w)
