from __future__ import division

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import normal_init
from mmdet.ops import DeformConv

from mmdet.core import (PointGenerator, point_target, point_hm_target, multi_apply, multiclass_nms)
from ..builder import build_loss
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule, TLPool, BRPool
from ..utils.norm import build_norm_layer


@HEADS.register_module
class RepPointsV2Head(nn.Module):
    """RepPoints v2 head.

    Predicts, for every feature level:

    * an initial set of RepPoints and a refined set produced with deformable
      convolutions,
    * a per-class classification map,
    * top-left / bottom-right corner heatmaps with their sub-pixel offsets,
    * a per-class semantic map.

    The corner heatmaps and offsets are concatenated onto the classification
    and refinement features. At test time they are also used to snap the
    predicted box corners.

    Args:
        num_classes (int): Number of classes including background. The head
            predicts ``num_classes - 1`` sigmoid channels.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of channels of the intermediate feature maps.
        stacked_convs (int): Number of conv layers in each of the cls / reg towers.
        shared_stacked_convs (int): Number of conv layers applied on top of the
            reg tower to build the features shared by the semantic and
            corner-heatmap branches.
        num_points (int): Number of RepPoints. Must be an odd square number
            because it fixes the deformable conv kernel size.
        gradient_mul (float): The multiplier to gradients from points
            refinement and recognition.
        point_strides (Iterable[int]): Stride of each feature level.
        point_base_scale (int): bbox scale for assigning labels and for
            normalizing the point losses.
        conv_cfg (dict | None): Conv config forwarded to ``ConvModule``.
        norm_cfg (dict | None): Norm config forwarded to ``ConvModule``.
        loss_cls, loss_bbox_init, loss_bbox_refine, loss_heatmap, loss_offset,
        loss_sem (dict): Configs passed to ``build_loss`` for each loss term.
        use_grid_points (bool): If True, regress 4 box values per location and
            expand them into a regular grid of points, instead of regressing
            ``2 * num_points`` free offsets.
        center_init (bool): If True and ``use_grid_points`` is False, the
            initial points are predicted as offsets from the location itself.
            Otherwise they are offsets from a fixed square grid of half-size
            ``point_base_scale / 2``.
        transform_method (str): How RepPoints are converted into a bbox. One of
            'minmax', 'partial_minmax', 'moment' or 'exact_minmax'.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=3,
                 shared_stacked_convs=1,
                 num_points=9,
                 gradient_mul=0.1,
                 point_strides=[8, 16, 32, 64, 128],
                 point_base_scale=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 loss_cls=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0),
                 loss_bbox_init=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
                 loss_bbox_refine=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 loss_heatmap=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=0.5),
                 loss_offset=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 loss_sem=dict(type='FocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=0.5),
                 use_grid_points=False,
                 center_init=True,
                 transform_method='moment',
                 **kwargs):
        super(RepPointsV2Head, self).__init__()
        self.num_classes = num_classes
        # Sigmoid classification: no explicit background channel (a zero
        # background column is prepended only at inference, in get_bboxes_single).
        self.cls_out_channels = self.num_classes - 1
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.shared_stacked_convs = shared_stacked_convs
        self.num_points = num_points
        self.gradient_mul = gradient_mul
        self.point_base_scale = point_base_scale
        # Copy so that instances never alias the shared default list.
        self.point_strides = list(point_strides)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox_init = build_loss(loss_bbox_init)
        self.loss_bbox_refine = build_loss(loss_bbox_refine)
        self.loss_heatmap = build_loss(loss_heatmap)
        self.loss_offset = build_loss(loss_offset)
        self.loss_sem = build_loss(loss_sem)
        self.use_grid_points = use_grid_points
        self.center_init = center_init
        self.transform_method = transform_method
        if self.transform_method == 'moment':
            # Learnable log-scale factors (width, height) applied to the point
            # standard deviations in transform_box.
            self.moment_transfer = nn.Parameter(data=torch.zeros(2), requires_grad=True)

        self.point_generators = [PointGenerator() for _ in self.point_strides]
        self._init_layers()

    def _init_dcn_offset(self, num_points):
        """Set up the deformable-conv kernel size, padding and base offsets.

        ``self.dcn_base_offset`` holds the (y, x) positions of a regular
        ``dcn_kernel x dcn_kernel`` grid centred at 0. It is flattened to shape
        ``(1, 2 * num_points, 1, 1)``. Subtracting it from the predicted point
        offsets yields the offsets fed to ``DeformConv``.
        """
        self.dcn_kernel = int(np.sqrt(num_points))
        assert self.dcn_kernel * self.dcn_kernel == num_points, "The points number should be a square number."
        assert self.dcn_kernel % 2 == 1, "The points number should be an odd square number."
        self.dcn_pad = (self.dcn_kernel - 1) // 2
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        dcn_base = np.arange(-self.dcn_pad, self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1))
        self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)

    def _init_layers(self):
        """Build all sub-modules of the head."""
        self._init_dcn_offset(self.num_points)
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.shared_convs = nn.ModuleList()
        self.lateral_convs = nn.ModuleList()

        # Separate classification and regression towers.
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))

        # Convs on top of the reg tower feeding the semantic / corner branches.
        for _ in range(self.shared_stacked_convs):
            self.shared_convs.append(
                ConvModule(
                    self.feat_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))

        # Corner pooling for the top-left and bottom-right heatmaps.
        self.hem_tl = TLPool(self.feat_channels, self.conv_cfg, self.norm_cfg)
        self.hem_br = BRPool(self.feat_channels, self.conv_cfg, self.norm_cfg)

        pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points

        # +6 input channels: 2 corner-heatmap scores + 4 corner offsets are
        # concatenated onto the features before the deformable convs.
        cls_in_channels = self.feat_channels + 6
        self.reppoints_cls_conv = DeformConv(cls_in_channels, self.feat_channels, self.dcn_kernel, 1,
                                             self.dcn_pad)
        self.reppoints_cls_out = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1, 1, 0)

        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels, self.feat_channels, 3, 1, 1)
        self.reppoints_pts_init_out = nn.Conv2d(self.feat_channels, pts_out_dim, 1, 1, 0)
        pts_in_channels = self.feat_channels + 6
        self.reppoints_pts_refine_conv = DeformConv(pts_in_channels, self.feat_channels, self.dcn_kernel, 1,
                                                    self.dcn_pad)
        self.reppoints_pts_refine_out = nn.Conv2d(self.feat_channels, pts_out_dim, 1, 1, 0)

        self.reppoints_hem_tl_score_out = nn.Conv2d(self.feat_channels, 1, 3, 1, 1)
        self.reppoints_hem_br_score_out = nn.Conv2d(self.feat_channels, 1, 3, 1, 1)
        self.reppoints_hem_tl_offset_out = nn.Conv2d(self.feat_channels, 2, 3, 1, 1)
        self.reppoints_hem_br_offset_out = nn.Conv2d(self.feat_channels, 2, 3, 1, 1)

        self.reppoints_sem_out = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1, 1, 0)
        self.reppoints_sem_embedding = ConvModule(
            self.feat_channels,
            self.feat_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

    def init_weights(self):
        """Initialize conv weights.

        Sigmoid score outputs (cls, corner heatmaps, semantic map) get a bias
        from ``bias_init_with_prob(0.01)``.
        """
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)
        for m in self.shared_convs:
            normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.reppoints_cls_conv, std=0.01)
        normal_init(self.reppoints_cls_out, std=0.01, bias=bias_cls)
        normal_init(self.reppoints_pts_init_conv, std=0.01)
        normal_init(self.reppoints_pts_init_out, std=0.01)
        normal_init(self.reppoints_pts_refine_conv, std=0.01)
        normal_init(self.reppoints_pts_refine_out, std=0.01)
        normal_init(self.reppoints_hem_tl_score_out, std=0.01, bias=bias_cls)
        normal_init(self.reppoints_hem_tl_offset_out, std=0.01)
        normal_init(self.reppoints_hem_br_score_out, std=0.01, bias=bias_cls)
        normal_init(self.reppoints_hem_br_offset_out, std=0.01)
        normal_init(self.reppoints_sem_out, std=0.01, bias=bias_cls)

    def transform_box(self, pts, y_first=True):
        """Convert a set of RepPoints into a bounding box.

        Args:
            pts (Tensor): Point coordinates with the point pairs along dim 1,
                e.g. ``(N, 2 * num_points, H, W)`` or ``(M, 2 * num_points)``.
            y_first (bool): Whether each pair is ordered (y, x). If False it
                is (x, y).

        Returns:
            Tensor: Boxes ``(x1, y1, x2, y2)`` stacked along dim 1, with the
            trailing dims of ``pts`` preserved.

        Raises:
            NotImplementedError: If ``self.transform_method`` is unknown.
        """
        method = self.transform_method
        if method not in ('minmax', 'partial_minmax', 'moment', 'exact_minmax'):
            raise NotImplementedError
        pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
        if method == 'partial_minmax':
            # only the first 4 points define the box
            pts_reshape = pts_reshape[:, :4, ...]
        elif method == 'exact_minmax':
            # the first 2 points are the top-left and bottom-right corners
            pts_reshape = pts_reshape[:, :2, ...]
        y_idx, x_idx = (0, 1) if y_first else (1, 0)
        pts_y = pts_reshape[:, :, y_idx, ...]
        pts_x = pts_reshape[:, :, x_idx, ...]

        if method == 'moment':
            pts_y_mean = pts_y.mean(dim=1, keepdim=True)
            pts_x_mean = pts_x.mean(dim=1, keepdim=True)
            pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
            pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
            # Only 1% of the gradient flows into the learnable transfer.
            moment_transfer = self.moment_transfer * 0.01 + self.moment_transfer.detach() * 0.99
            half_width = pts_x_std * torch.exp(moment_transfer[0])
            half_height = pts_y_std * torch.exp(moment_transfer[1])
            return torch.cat([pts_x_mean - half_width, pts_y_mean - half_height,
                              pts_x_mean + half_width, pts_y_mean + half_height], dim=1)
        if method == 'exact_minmax':
            return torch.cat([pts_x[:, 0:1, ...], pts_y[:, 0:1, ...],
                              pts_x[:, 1:2, ...], pts_y[:, 1:2, ...]], dim=1)
        # 'minmax' / 'partial_minmax': tight box around the selected points
        bbox_left = pts_x.min(dim=1, keepdim=True)[0]
        bbox_right = pts_x.max(dim=1, keepdim=True)[0]
        bbox_up = pts_y.min(dim=1, keepdim=True)[0]
        bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
        return torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], dim=1)

    def gen_grid_from_reg(self, reg, previous_boxes):
        """Decode box deltas and expand the result into a regular point grid.

        Args:
            reg (Tensor): ``(B, 4, H, W)`` deltas ``(tx, ty, tw, th)``. Centre
                deltas are relative to the box size; size deltas are in log
                space.
            previous_boxes (Tensor): ``(B or 1, 4, H, W)`` reference boxes
                ``(x1, y1, x2, y2)``.

        Returns:
            tuple[Tensor, Tensor]:
                * ``(B, 2 * dcn_kernel**2, H, W)`` grid points, (y, x)
                  interleaved.
                * ``(B, 4, H, W)`` decoded boxes.
        """
        b, _, h, w = reg.shape
        tx = reg[:, [0], ...]
        ty = reg[:, [1], ...]
        tw = reg[:, [2], ...]
        th = reg[:, [3], ...]
        bx = (previous_boxes[:, [0], ...] + previous_boxes[:, [2], ...]) / 2.
        by = (previous_boxes[:, [1], ...] + previous_boxes[:, [3], ...]) / 2.
        # clamp avoids zero-sized reference boxes collapsing the grid
        bw = (previous_boxes[:, [2], ...] - previous_boxes[:, [0], ...]).clamp(min=1e-6)
        bh = (previous_boxes[:, [3], ...] - previous_boxes[:, [1], ...]).clamp(min=1e-6)
        grid_left = bx + bw * tx - 0.5 * bw * torch.exp(tw)
        grid_width = bw * torch.exp(tw)
        grid_up = by + bh * ty - 0.5 * bh * torch.exp(th)
        grid_height = bh * torch.exp(th)
        interval = torch.linspace(0., 1., self.dcn_kernel).view(1, self.dcn_kernel, 1, 1).type_as(reg)
        grid_x = grid_left + grid_width * interval
        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
        grid_x = grid_x.view(b, -1, h, w)
        grid_y = grid_up + grid_height * interval
        grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
        grid_y = grid_y.view(b, -1, h, w)
        grid_yx = torch.stack([grid_y, grid_x], dim=2)
        grid_yx = grid_yx.view(b, -1, h, w)
        regressed_bbox = torch.cat([grid_left, grid_up, grid_left + grid_width, grid_up + grid_height], 1)
        return grid_yx, regressed_bbox

    def forward_single(self, x):
        """Run the head on one feature level.

        Args:
            x (Tensor): ``(N, in_channels, H, W)`` feature map.

        Returns:
            tuple[Tensor]:
                * ``cls_out`` ``(N, cls_out_channels, H, W)``.
                * ``pts_out_init`` and ``pts_out_refine``: (y, x) point
                  offsets in units of the level stride, ``(N, 2*num_points, H, W)``.
                * ``hem_score_out`` ``(N, 2, H, W)``: top-left and bottom-right
                  corner heatmap logits.
                * ``hem_offset_out`` ``(N, 4, H, W)``: top-left and
                  bottom-right corner offsets.
                * ``sem_scores_out`` ``(N, cls_out_channels, H, W)``.
        """
        dcn_base_offset = self.dcn_base_offset.type_as(x)
        if self.use_grid_points or not self.center_init:
            scale = self.point_base_scale / 2
            points_init = dcn_base_offset / dcn_base_offset.max() * scale
            bbox_init = torch.tensor([-scale, -scale, scale, scale]).view(1, 4, 1, 1).type_as(x)
        else:
            points_init = 0

        cls_feat = x
        pts_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            pts_feat = reg_conv(pts_feat)

        shared_feat = pts_feat
        for shared_conv in self.shared_convs:
            shared_feat = shared_conv(shared_feat)

        # semantic branch; its embedding is added back to every branch
        sem_scores_out = self.reppoints_sem_out(shared_feat)
        sem_feat = self.reppoints_sem_embedding(shared_feat)

        cls_feat = cls_feat + sem_feat
        pts_feat = pts_feat + sem_feat
        hem_feat = shared_feat + sem_feat

        # generate corner heatmaps and offsets
        hem_tl_feat = self.hem_tl(hem_feat)
        hem_br_feat = self.hem_br(hem_feat)

        hem_tl_score_out = self.reppoints_hem_tl_score_out(hem_tl_feat)
        hem_tl_offset_out = self.reppoints_hem_tl_offset_out(hem_tl_feat)
        hem_br_score_out = self.reppoints_hem_br_score_out(hem_br_feat)
        hem_br_offset_out = self.reppoints_hem_br_offset_out(hem_br_feat)

        hem_score_out = torch.cat([hem_tl_score_out, hem_br_score_out], dim=1)
        hem_offset_out = torch.cat([hem_tl_offset_out, hem_br_offset_out], dim=1)

        # initialize reppoints
        pts_out_init = self.reppoints_pts_init_out(self.relu(self.reppoints_pts_init_conv(pts_feat)))
        if self.use_grid_points:
            pts_out_init, bbox_out_init = self.gen_grid_from_reg(pts_out_init, bbox_init.detach())
        else:
            pts_out_init = pts_out_init + points_init

        # Scale down the gradient that flows from the refine/cls branches
        # back into the initial points.
        pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init
        dcn_offset = pts_out_init_grad_mul - dcn_base_offset

        # 6-channel corner guidance appended to both deformable-conv inputs
        hem_guide = torch.cat([hem_score_out, hem_offset_out], dim=1)
        cls_feat = torch.cat([cls_feat, hem_guide], dim=1)
        pts_feat = torch.cat([pts_feat, hem_guide], dim=1)

        cls_out = self.reppoints_cls_out(self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
        pts_out_refine = self.reppoints_pts_refine_out(self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))

        if self.use_grid_points:
            pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(pts_out_refine, bbox_out_init.detach())
        else:
            # refinement predicts a residual on top of the (detached) init points
            pts_out_refine = pts_out_refine + pts_out_init.detach()

        return cls_out, pts_out_init, pts_out_refine, hem_score_out, hem_offset_out, sem_scores_out

    def forward(self, feats):
        """Apply ``forward_single`` to every level via ``multi_apply``.

        Returns:
            tuple: The six outputs of ``forward_single``, each gathered
            across levels.
        """
        cls_out_list, pts_out_init_list, pts_out_refine_list, hem_score_out_list, hem_offset_out_list, sem_scores_out \
            = multi_apply(self.forward_single, feats)

        return cls_out_list, pts_out_init_list, pts_out_refine_list, \
               hem_score_out_list, hem_offset_out_list, sem_scores_out

    def get_points(self, featmap_sizes, img_metas):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info. ``'pad_shape'`` is used
                to mark points outside the padded image as invalid.

        Returns:
            tuple: Points of each image (a fresh clone per image), and the
            valid flags of each image.
        """
        num_imgs = len(img_metas)
        num_levels = len(featmap_sizes)

        # since feature map sizes of all images are the same, we only compute
        # points center for one time
        multi_level_points = [
            self.point_generators[i].grid_points(featmap_sizes[i], self.point_strides[i])
            for i in range(num_levels)
        ]
        points_list = [[point.clone() for point in multi_level_points] for _ in range(num_imgs)]

        # for each image, we compute valid flags of multi level grids
        valid_flag_list = []
        for img_meta in img_metas:
            h, w, _ = img_meta['pad_shape']
            multi_level_flags = []
            for i in range(num_levels):
                point_stride = self.point_strides[i]
                feat_h, feat_w = featmap_sizes[i]
                valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
                valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
                flags = self.point_generators[i].valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w))
                multi_level_flags.append(flags)
            valid_flag_list.append(multi_level_flags)

        return points_list, valid_flag_list

    def centers_to_bboxes(self, point_list):
        """Get bboxes according to center points. Only used in MaxIOUAssigner.

        Each point becomes a square box of side
        ``point_base_scale * stride`` centred on it.
        """
        bbox_list = []
        for point in point_list:
            bbox = []
            for i_lvl in range(len(self.point_strides)):
                scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
                bbox_shift = torch.Tensor([-scale, -scale, scale, scale]).view(1, 4).type_as(point[0])
                bbox_center = torch.cat([point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center + bbox_shift)
            bbox_list.append(bbox)
        return bbox_list

    def yx_to_xy(self, pts):
        """Change the points offset from y first to x first.

        Works on the last dimension: ``(y0, x0, y1, x1, ...)`` becomes
        ``(x0, y0, x1, y1, ...)``.
        """
        pts_y = pts[..., 0::2]
        pts_x = pts[..., 1::2]
        pts_xy = torch.stack([pts_x, pts_y], -1)
        return pts_xy.view(*pts.shape[:-1], -1)

    def offset_to_pts(self, center_list, pred_list):
        """Change from point offset to point coordinate.

        Args:
            center_list (list[list[Tensor]]): Per image, per level points.
                Columns 0:2 are the (x, y) centre.
            pred_list (list[Tensor]): Per level ``(N, 2*num_points, H, W)``
                (y, x) offsets in stride units.

        Returns:
            list[Tensor]: Per level ``(N, H*W, 2*num_points)`` absolute
            (x, y) point coordinates.
        """
        pts_list = []
        for i_lvl in range(len(self.point_strides)):
            pts_lvl = []
            for i_img in range(len(center_list)):
                pts_center = center_list[i_img][i_lvl][:, :2].repeat(1, self.num_points)
                pts_shift = pred_list[i_lvl][i_img]
                yx_pts_shift = pts_shift.permute(1, 2, 0).view(-1, 2 * self.num_points)
                xy_pts_shift = self.yx_to_xy(yx_pts_shift)
                pts_lvl.append(xy_pts_shift * self.point_strides[i_lvl] + pts_center)
            pts_list.append(torch.stack(pts_lvl, 0))
        return pts_list

    def loss(self,
             cls_scores,
             pts_preds_init,
             pts_preds_refine,
             hm_scores,
             hm_offsets,
             sem_scores,
             gt_bboxes,
             gt_sem_map,
             gt_labels,
             img_metas,
             cfg,
             gt_sem_weights=None,
             gt_bboxes_ignore=None):
        """Compute all training losses.

        Args:
            cls_scores, pts_preds_init, pts_preds_refine, hm_scores,
            hm_offsets, sem_scores (list[Tensor]): Per-level outputs of
                ``forward``.
            gt_bboxes (list[Tensor]): Ground-truth boxes of each image.
            gt_sem_map (Tensor): Semantic target map. It is resized to each
                level with ``F.interpolate``, and after flattening must match the
                flattened per-level ``sem_scores`` (presumably
                ``(N, cls_out_channels, H, W)`` — confirm with the data pipeline).
            gt_labels (list[Tensor]): Ground-truth labels of each image.
            img_metas (list[dict]): Image meta info.
            cfg: Train config with ``init``, ``heatmap`` and ``refine`` entries.
            gt_sem_weights (Tensor | None): Per-pixel weights laid out like
                ``gt_sem_map``. ``None`` means unweighted.
            gt_bboxes_ignore: Unused.

        Returns:
            dict: Loss terms. The per-level terms are lists.
        """
        assert len(cls_scores) == len(self.point_generators)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]

        # target for initial stage
        # NOTE(review): the points are regenerated before every target call,
        # presumably because the target helpers modify the lists in place.
        proposal_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
        pts_coordinate_preds_init = self.offset_to_pts(proposal_list, pts_preds_init)
        if cfg.init.assigner['type'] == 'MaxIoUAssigner':
            proposal_list = self.centers_to_bboxes(proposal_list)
        cls_reg_targets_init = point_target(
            proposal_list,
            valid_flag_list,
            gt_bboxes,
            gt_labels,
            cfg.init)

        (*_, bbox_gt_list_init, bbox_weights_list_init,
         num_total_pos_init, num_total_neg_init) = cls_reg_targets_init

        # target for heatmap in initial stage
        proposal_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
        heatmap_targets = point_hm_target(
            proposal_list,
            valid_flag_list,
            gt_bboxes,
            gt_labels,
            cfg.heatmap)
        (gt_hm_tl_list, gt_offset_tl_list, gt_hm_tl_weight_list, gt_offset_tl_weight_list,
         gt_hm_br_list, gt_offset_br_list, gt_hm_br_weight_list, gt_offset_br_weight_list,
         num_total_pos_tl, num_total_neg_tl, num_total_pos_br, num_total_neg_br) = heatmap_targets

        # target for refinement stage: candidate boxes are the (detached)
        # boxes decoded from the initial points, placed at each location
        proposal_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
        pts_coordinate_preds_refine = self.offset_to_pts(proposal_list, pts_preds_refine)
        # decode each level once for the whole batch, not once per image
        bbox_preds_init = [self.transform_box(pts_pred.detach()) for pts_pred in pts_preds_init]
        bbox_list = []
        for i_img, point in enumerate(proposal_list):
            bbox = []
            for i_lvl in range(len(pts_preds_refine)):
                bbox_shift = bbox_preds_init[i_lvl] * self.point_strides[i_lvl]
                bbox_center = torch.cat([point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
                bbox.append(bbox_center + bbox_shift[i_img].permute(1, 2, 0).contiguous().view(-1, 4))
            bbox_list.append(bbox)
        cls_reg_targets_refine = point_target(
            bbox_list,
            valid_flag_list,
            gt_bboxes,
            gt_labels,
            cfg.refine)

        (labels_list, label_weights_list, bbox_gt_list_refine, bbox_weights_list_refine,
         num_total_pos_refine, num_total_neg_refine) = cls_reg_targets_refine

        # compute loss
        losses_cls, losses_pts_init, losses_pts_refine, losses_heatmap, losses_offset = multi_apply(
            self.loss_single,
            cls_scores,
            pts_coordinate_preds_init,
            pts_coordinate_preds_refine,
            hm_scores,
            hm_offsets,
            labels_list,
            label_weights_list,
            bbox_gt_list_init,
            bbox_weights_list_init,
            bbox_gt_list_refine,
            bbox_weights_list_refine,
            gt_hm_tl_list,
            gt_offset_tl_list,
            gt_hm_tl_weight_list,
            gt_offset_tl_weight_list,
            gt_hm_br_list,
            gt_offset_br_list,
            gt_hm_br_weight_list,
            gt_offset_br_weight_list,
            self.point_strides,
            num_total_samples_init=num_total_pos_init,
            num_total_samples_refine=num_total_pos_refine,
            num_total_samples_tl=num_total_pos_tl,
            num_total_samples_br=num_total_pos_br)

        # Semantic loss, computed jointly over every level.
        # Explicit accumulation replaces a bare try/except that could hide
        # real cat failures and silently drop levels.
        sem_score_chunks, sem_map_chunks, sem_weight_chunks = [], [], []
        for sem_score in sem_scores:
            lvl_size = sem_score.shape[-2:]
            sem_score_chunks.append(sem_score.reshape(-1))
            sem_map_chunks.append(F.interpolate(gt_sem_map, lvl_size).reshape(-1))
            if gt_sem_weights is not None:
                sem_weight_chunks.append(F.interpolate(gt_sem_weights, lvl_size).reshape(-1))
        concat_sem_scores = torch.cat(sem_score_chunks)
        concat_gt_sem_map = torch.cat(sem_map_chunks)
        concat_gt_sem_weights = torch.cat(sem_weight_chunks) if gt_sem_weights is not None else None
        # Clamp to at least 1 so that a batch without foreground cannot
        # divide by zero.
        num_sem_pos = (concat_gt_sem_map > 0).sum().clamp(min=1)
        loss_sem = self.loss_sem(concat_sem_scores, concat_gt_sem_map, concat_gt_sem_weights,
                                 avg_factor=num_sem_pos)

        loss_dict_all = {'loss_cls': losses_cls,
                         'loss_pts_init': losses_pts_init,
                         'loss_pts_refine': losses_pts_refine,
                         'loss_heatmap': losses_heatmap,
                         'loss_offset': losses_offset,
                         'loss_sem': loss_sem,
                         }
        return loss_dict_all

    def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, hm_score, hm_offset,
                    labels, label_weights,
                    bbox_gt_init, bbox_weights_init,
                    bbox_gt_refine, bbox_weights_refine,
                    gt_hm_tl, gt_offset_tl, gt_hm_tl_weight, gt_offset_tl_weight,
                    gt_hm_br, gt_offset_br, gt_hm_br_weight, gt_offset_br_weight,
                    stride,
                    num_total_samples_init, num_total_samples_refine,
                    num_total_samples_tl, num_total_samples_br):
        """Compute the losses of a single feature level.

        ``pts_pred_init`` and ``pts_pred_refine`` are absolute (x, y) point
        coordinates, as produced by ``offset_to_pts``. Box losses are normalized
        by ``point_base_scale * stride``.

        Returns:
            tuple: ``(loss_cls, loss_pts_init, loss_pts_refine, loss_heatmap,
            loss_offset)``. The heatmap and offset terms are the mean of the
            top-left and bottom-right corner losses.
        """
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples_refine)

        # points loss (points are x-first after offset_to_pts)
        bbox_gt_init = bbox_gt_init.reshape(-1, 4)
        bbox_weights_init = bbox_weights_init.reshape(-1, 4)
        bbox_pred_init = self.transform_box(pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
        bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
        bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
        bbox_pred_refine = self.transform_box(pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
        normalize_term = self.point_base_scale * stride
        loss_pts_init = self.loss_bbox_init(
            bbox_pred_init / normalize_term,
            bbox_gt_init / normalize_term,
            bbox_weights_init,
            avg_factor=num_total_samples_init)
        loss_pts_refine = self.loss_bbox_refine(
            bbox_pred_refine / normalize_term,
            bbox_gt_refine / normalize_term,
            bbox_weights_refine,
            avg_factor=num_total_samples_refine)

        # heatmap cls loss (channel 0: top-left, channel 1: bottom-right)
        hm_score = hm_score.permute(0, 2, 3, 1).reshape(-1, 2)
        hm_score_tl, hm_score_br = torch.chunk(hm_score, 2, dim=-1)

        gt_hm_tl = gt_hm_tl.reshape(-1)
        gt_hm_tl_weight = gt_hm_tl_weight.reshape(-1)
        gt_hm_br = gt_hm_br.reshape(-1)
        gt_hm_br_weight = gt_hm_br_weight.reshape(-1)

        loss_heatmap = 0
        loss_heatmap += self.loss_heatmap(
            hm_score_tl, gt_hm_tl, gt_hm_tl_weight, avg_factor=num_total_samples_tl
        )
        loss_heatmap += self.loss_heatmap(
            hm_score_br, gt_hm_br, gt_hm_br_weight, avg_factor=num_total_samples_br
        )
        loss_heatmap /= 2.0

        # heatmap offset loss (channels 0:2 top-left, 2:4 bottom-right)
        hm_offset = hm_offset.permute(0, 2, 3, 1).reshape(-1, 4)
        hm_offset_tl, hm_offset_br = torch.chunk(hm_offset, 2, dim=-1)

        gt_offset_tl = gt_offset_tl.reshape(-1, 2)
        gt_offset_tl_weight = gt_offset_tl_weight.reshape(-1, 2)
        gt_offset_br = gt_offset_br.reshape(-1, 2)
        gt_offset_br_weight = gt_offset_br_weight.reshape(-1, 2)

        loss_offset = 0
        loss_offset += self.loss_offset(
            hm_offset_tl, gt_offset_tl, gt_offset_tl_weight,
            avg_factor=num_total_samples_tl
        )
        loss_offset += self.loss_offset(
            hm_offset_br, gt_offset_br, gt_offset_br_weight,
            avg_factor=num_total_samples_br
        )
        loss_offset /= 2.0

        return loss_cls, loss_pts_init, loss_pts_refine, loss_heatmap, loss_offset

    def get_bboxes(self,
                   cls_scores,
                   pts_preds_init,
                   pts_preds_refine,
                   hm_scores,
                   hm_offsets,
                   sem_scores,
                   img_metas,
                   cfg,
                   rescale=False,
                   nms=True):
        """Decode detections for every image in the batch.

        Boxes come from the refined points. ``pts_preds_init`` and
        ``sem_scores`` are accepted for interface symmetry but are not used.

        Returns:
            list: One ``get_bboxes_single`` result per image.
        """
        assert len(cls_scores) == len(pts_preds_refine)
        num_levels = len(cls_scores)

        bbox_preds_refine = [self.transform_box(pts_pred_refine) for pts_pred_refine in pts_preds_refine]
        mlvl_points = [
            self.point_generators[i].grid_points(cls_scores[i].size()[-2:],
                                                 self.point_strides[i])
            for i in range(num_levels)
        ]
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [cls_scores[i][img_id].detach() for i in range(num_levels)]
            hm_scores_list = [hm_scores[i][img_id].detach() for i in range(num_levels)]
            hm_offsets_list = [hm_offsets[i][img_id].detach() for i in range(num_levels)]
            bbox_pred_list = [bbox_preds_refine[i][img_id].detach() for i in range(num_levels)]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self.get_bboxes_single(cls_score_list, hm_scores_list, hm_offsets_list, bbox_pred_list,
                                               mlvl_points, img_shape, scale_factor, cfg, rescale, nms)
            result_list.append(proposals)
        return result_list

    def get_bboxes_single(self,
                          cls_scores,
                          hm_scores,
                          hm_offsets,
                          bbox_preds,
                          mlvl_points,
                          img_shape,
                          scale_factor,
                          cfg,
                          rescale=False,
                          nms=True):
        """Decode detections for one image.

        Args:
            cls_scores (list[Tensor]): Per level ``(C, H, W)`` logits.
            hm_scores (list[Tensor]): Per level ``(2, H, W)`` corner heatmap
                logits (0: top-left, 1: bottom-right).
            hm_offsets (list[Tensor]): Per level ``(4, H, W)`` corner offsets
                (0:2 top-left, 2:4 bottom-right).
            bbox_preds (list[Tensor]): Per level ``(4, H, W)`` boxes in stride
                units, relative to each location.
            mlvl_points (list[Tensor]): Per level points. Columns 0:2 are (x, y).
            img_shape (tuple): Image shape used to clip boxes.
            scale_factor: Divides the boxes when ``rescale`` is True.
            cfg: Test config (``nms_pre``, ``score_thr``, ``nms``, ``max_per_img``).

        Returns:
            tuple: ``(det_bboxes, det_labels)`` when ``nms`` is True.
            Otherwise ``(mlvl_bboxes, mlvl_scores)``, where the scores have a
            zero background column prepended.
        """
        def select(score_map, x, y, ks=2, lvl=0):
            """Snap (x, y) to the heatmap peak inside a ks x ks window.

            The heatmap is sigmoid-activated and max-pooled with stride 1.
            The pooled cell at the rounded location of (x, y), in level-``lvl``
            grid units, yields the argmax position.

            Returns:
                tuple: ``(new_x, new_y, score)``. The positions are float
                heatmap grid coordinates.
            """
            W = score_map.shape[-1]
            score_map = score_map.sigmoid()
            score_map_original = score_map.clone()

            score_map, indices = F.max_pool2d_with_indices(score_map.unsqueeze(0), kernel_size=ks, stride=1,
                                                           padding=(ks - 1) // 2)

            indices = indices.squeeze(0).squeeze(0)

            # even kernels anchor the window at the floor position
            round_func = torch.floor if ks % 2 == 0 else torch.round

            x_round = round_func((x / self.point_strides[lvl]).clamp(min=0, max=score_map.shape[-1] - 1))
            y_round = round_func((y / self.point_strides[lvl]).clamp(min=0, max=score_map.shape[-2] - 1))

            # indices are flat positions in the un-pooled H x W map
            select_indices = indices[y_round.to(torch.long), x_round.to(torch.long)]
            new_x = select_indices % W
            new_y = select_indices // W

            score = score_map_original.squeeze(0)[new_y, new_x]

            return new_x.to(torch.float), new_y.to(torch.float), score

        assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
        mlvl_bboxes = []
        mlvl_scores = []
        for i_lvl, (cls_score, bbox_pred, points) in enumerate(zip(cls_scores, bbox_preds, mlvl_points)):
            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
            scores = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels).sigmoid()
            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            nms_pre = cfg.get('nms_pre', -1)
            if nms_pre > 0 and scores.shape[0] > nms_pre:
                max_scores, _ = scores.max(dim=1)
                _, topk_inds = max_scores.topk(nms_pre)
                points = points[topk_inds, :]
                bbox_pred = bbox_pred[topk_inds, :]
                scores = scores[topk_inds, :]
            bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
            bboxes = bbox_pred * self.point_strides[i_lvl] + bbox_pos_center
            x1 = bboxes[:, 0].clamp(min=0, max=img_shape[1])
            y1 = bboxes[:, 1].clamp(min=0, max=img_shape[0])
            x2 = bboxes[:, 2].clamp(min=0, max=img_shape[1])
            y2 = bboxes[:, 3].clamp(min=0, max=img_shape[0])

            if i_lvl > 0:
                # Corner refinement: levels 1-2 use the level-0 heatmap and
                # deeper levels use the level-1 heatmap. Level 0 is not refined.
                hm_lvl = 0 if i_lvl in (1, 2) else 1

                x1_new, y1_new, _ = select(hm_scores[hm_lvl][0, ...], x1, y1, 2, hm_lvl)
                x2_new, y2_new, _ = select(hm_scores[hm_lvl][1, ...], x2, y2, 2, hm_lvl)

                hm_offset = hm_offsets[hm_lvl].permute(1, 2, 0)
                point_stride = self.point_strides[hm_lvl]
                y1_idx, x1_idx = y1_new.to(torch.long), x1_new.to(torch.long)
                y2_idx, x2_idx = y2_new.to(torch.long), x2_new.to(torch.long)

                x1 = ((x1_new + hm_offset[y1_idx, x1_idx, 0]) * point_stride).clamp(min=0, max=img_shape[1])
                y1 = ((y1_new + hm_offset[y1_idx, x1_idx, 1]) * point_stride).clamp(min=0, max=img_shape[0])
                x2 = ((x2_new + hm_offset[y2_idx, x2_idx, 2]) * point_stride).clamp(min=0, max=img_shape[1])
                y2 = ((y2_new + hm_offset[y2_idx, x2_idx, 3]) * point_stride).clamp(min=0, max=img_shape[0])

            mlvl_bboxes.append(torch.stack([x1, y1, x2, y2], dim=-1))
            mlvl_scores.append(scores)
        mlvl_bboxes = torch.cat(mlvl_bboxes)

        if rescale:
            mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
        mlvl_scores = torch.cat(mlvl_scores)
        # prepend a zero background column, as expected by multiclass_nms
        padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
        mlvl_scores = torch.cat([padding, mlvl_scores], dim=1)
        if nms:
            det_bboxes, det_labels = multiclass_nms(
                mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
            return det_bboxes, det_labels
        return mlvl_bboxes, mlvl_scores
