import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import normal_init

from .anchor_head import AnchorHead
from ..registry import HEADS
from ..utils import bias_init_with_prob, ConvModule
from mmdet.core import (AnchorGenerator, anchor_target, delta2bbox,
                        multi_apply, multiclass_nms)
from ..builder import build_loss
from ..registry import HEADS


@HEADS.register_module
class RetinaHead(AnchorHead):
    """RetinaNet-style dense prediction head.

    Every input feature map is processed by two parallel towers of
    ``stacked_convs`` 3x3 ``ConvModule`` layers: one tower for
    classification and one for box regression. Each tower ends in a final
    3x3 ``nn.Conv2d`` prediction layer.

    The anchor scales are derived from ``octave_base_scale`` and
    ``scales_per_octave``. All remaining options are forwarded unchanged to
    ``AnchorHead`` through ``**kwargs``.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 octave_base_scale=4,
                 scales_per_octave=3,
                 conv_cfg=None,
                 norm_cfg=None,
                 **kwargs):
        """Store tower settings and derive the anchor scales.

        Args:
            num_classes: Forwarded to ``AnchorHead``.
            in_channels: Forwarded to ``AnchorHead``. ``_init_layers`` reads
                ``self.in_channels`` as the input width of the first conv in
                each tower (presumably set by ``AnchorHead``; confirm).
            stacked_convs (int): Number of 3x3 ``ConvModule`` layers in each
                of the classification and regression towers.
            octave_base_scale (int | float): Multiplier applied to every
                octave scale.
            scales_per_octave (int): Number of anchor scales. Scale ``i``
                is ``octave_base_scale * 2 ** (i / scales_per_octave)``, for
                ``i`` in ``range(scales_per_octave)``.
            conv_cfg (dict | None): Passed to every tower ``ConvModule``.
            norm_cfg (dict | None): Passed to every tower ``ConvModule``.
            **kwargs: Forwarded to ``AnchorHead.__init__``. Note that
                ``anchor_scales`` is always supplied by this class.
        """
        # These attributes are assigned before calling the parent
        # constructor. NOTE(review): presumably AnchorHead.__init__ calls
        # self._init_layers(), which reads them — confirm against AnchorHead.
        self.stacked_convs = stacked_convs
        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Scales evenly spaced in log2 space over one octave: 2**0 .. 2**((s-1)/s).
        octave_scales = np.array(
            [2**(i / scales_per_octave) for i in range(scales_per_octave)])
        anchor_scales = octave_scales * octave_base_scale
        super(RetinaHead, self).__init__(
            num_classes, in_channels, anchor_scales=anchor_scales, **kwargs)

    def _init_layers(self):
        """Build the two conv towers and the two prediction convs.

        Both towers have ``self.stacked_convs`` 3x3, stride-1, padding-1
        ``ConvModule`` layers. The first layer of each tower maps
        ``self.in_channels`` to ``self.feat_channels``. The remaining layers
        map ``self.feat_channels`` to ``self.feat_channels``.
        """
        # NOTE(review): ``self.relu`` is not used by ``forward_single``. It has
        # no parameters and is kept only so external references keep working.
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        # One group of cls_out_channels scores per anchor at each location.
        self.retina_cls = nn.Conv2d(
            self.feat_channels,
            self.num_anchors * self.cls_out_channels,
            3,
            padding=1)
        # Four box-regression values per anchor at each location.
        self.retina_reg = nn.Conv2d(
            self.feat_channels, self.num_anchors * 4, 3, padding=1)

    def init_weights(self):
        """Initialise all conv weights from a normal distribution (std 0.01).

        The classification conv bias is set to ``bias_init_with_prob(0.01)``.
        Presumably this makes initial sigmoid scores start near a 0.01 prior,
        as in the focal-loss setup — confirm against ``bias_init_with_prob``.
        """
        for m in self.cls_convs:
            normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.retina_cls, std=0.01, bias=bias_cls)
        normal_init(self.retina_reg, std=0.01)

    def forward_single(self, x):
        """Run one feature map through both towers and prediction convs.

        Args:
            x (Tensor): A single feature map with ``self.in_channels``
                channels. NOTE(review): assumed to be (N, C, H, W) — confirm.

        Returns:
            tuple: ``(cls_score, bbox_pred)``.

            - ``cls_score`` has ``num_anchors * cls_out_channels`` channels.
            - ``bbox_pred`` has ``num_anchors * 4`` channels.

            Because every conv is 3x3, stride 1, padding 1, both outputs keep
            the spatial size of ``x``.
        """
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.retina_cls(cls_feat)
        bbox_pred = self.retina_reg(reg_feat)
        return cls_score, bbox_pred