import torch

from ..bbox import build_assigner, PseudoSampler
from ..utils import multi_apply


def point_target(proposals_list,
                 valid_flag_list,
                 gt_bboxes_list,
                 gt_labels_list,
                 cfg,
                 unmap_outputs=True):
    """Compute refinement and classification targets for points.

    Args:
        proposals_list (list[list[Tensor]]): Multi-level proposals (points or
            boxes) of each image. NOTE: modified in place — each image's
            entry is replaced by the concatenation of its levels.
        valid_flag_list (list[list[Tensor]]): Multi-level valid flags of each
            image. Modified in place in the same way as ``proposals_list``.
        gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
        gt_labels_list (list[Tensor] | None): Ground truth labels of each
            image. If None, every positive sample gets label 1.
        cfg (dict): Train sample config. ``cfg.assigner`` and
            ``cfg.pos_weight`` are read by ``point_target_single``.
        unmap_outputs (bool): Whether to scatter per-image targets back to
            the full (unfiltered) set of proposals.

    Returns:
        tuple: ``(labels_list, label_weights_list, bbox_gt_list,
        bbox_weights_list, num_total_pos, num_total_neg)``.

        - The first four are lists with one entry per feature level, built
          by ``images_to_levels``.
        - ``num_total_pos`` / ``num_total_neg`` are the positive / negative
          sample counts summed over images, with each image counting as at
          least 1.
    """
    num_imgs = len(proposals_list)
    assert len(proposals_list) == len(valid_flag_list)
    if gt_labels_list is None:
        # point_target_single accepts gt_labels=None, but multi_apply is
        # called with one entry per image for every list argument, so None
        # has to be expanded to a per-image list.
        gt_labels_list = [None] * num_imgs

    # Number of proposals on each level, taken from the first image and
    # reused for every image (all images are assumed to share the layout).
    num_level_proposals = [points.size(0) for points in proposals_list[0]]
    num_level_proposals_list = [num_level_proposals] * num_imgs

    # Concat all level proposals and flags of each image into a single
    # tensor. Replacing the current element does not disturb iteration.
    for i, (img_proposals, img_flags) in enumerate(
            zip(proposals_list, valid_flag_list)):
        assert len(img_proposals) == len(img_flags)
        proposals_list[i] = torch.cat(img_proposals)
        valid_flag_list[i] = torch.cat(img_flags)

    (all_labels, all_label_weights, all_bbox_gt, all_bbox_weights,
     pos_inds_list, neg_inds_list) = multi_apply(
        point_target_single,
        proposals_list,
        valid_flag_list,
        num_level_proposals_list,
        gt_bboxes_list,
        gt_labels_list,
        cfg=cfg,
        unmap_outputs=unmap_outputs)

    # Sampled points of all images. Each image contributes at least 1 so
    # that the totals can safely be used as normalizers.
    num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
    num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
    # NOTE(review): with unmap_outputs=False the per-image targets cover only
    # the valid proposals, so slicing them by the full per-level counts below
    # may not line up with levels — confirm callers always unmap.
    labels_list = images_to_levels(all_labels, num_level_proposals)
    label_weights_list = images_to_levels(all_label_weights, num_level_proposals)
    bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
    bbox_weights_list = images_to_levels(all_bbox_weights, num_level_proposals)
    return (labels_list, label_weights_list, bbox_gt_list, bbox_weights_list,
            num_total_pos, num_total_neg)


def images_to_levels(target, num_level_grids):
    """Regroup per-image targets into per-level targets.

    The per-image tensors are stacked along a new leading dimension. The
    second dimension is then cut into consecutive chunks whose sizes come
    from ``num_level_grids``::

        [target_img0, target_img1] -> [target_level0, target_level1, ...]

    When only one image is given, the leading image dimension is squeezed
    away from every level chunk.
    """
    stacked = torch.stack(target, 0)
    # Running boundaries: level k spans [offsets[k], offsets[k + 1]).
    offsets = [0]
    for count in num_level_grids:
        offsets.append(offsets[-1] + count)
    return [
        stacked[:, lo:hi].squeeze(0)
        for lo, hi in zip(offsets[:-1], offsets[1:])
    ]


def point_target_single(flat_proposals,
                        inside_flags,
                        num_level_proposals,
                        gt_bboxes,
                        gt_labels,
                        cfg,
                        unmap_outputs=True):
    """Compute classification and regression targets for a single image.

    Args:
        flat_proposals (Tensor): Proposals of all levels concatenated along
            dim 0.
        inside_flags (Tensor): Flags selecting the valid proposals along
            dim 0 of ``flat_proposals``.
        num_level_proposals (list[int]): Number of proposals on each level.
            Only used when ``cfg.assigner.type == "ATSSAssigner"``.
        gt_bboxes (Tensor): Ground truth bboxes of the image.
        gt_labels (Tensor | None): Ground truth labels. If None, positives
            get label 1.
        cfg (dict): Train config; reads ``cfg.assigner`` and
            ``cfg.pos_weight``.
        unmap_outputs (bool): If True, scatter targets back to the size of
            ``flat_proposals`` (invalid entries filled with 0).

    Returns:
        tuple: ``(labels, label_weights, bbox_gt, bbox_weights, pos_inds,
        neg_inds)``.

        - ``bbox_gt`` / ``bbox_weights`` have 4 columns.
        - ``pos_inds`` / ``neg_inds`` index into the valid proposals. They
          are never unmapped.
    """
    # Assign gt and sample among the valid proposals only.
    proposals = flat_proposals[inside_flags, :]

    bbox_assigner = build_assigner(cfg.assigner)
    if cfg.assigner.type != "ATSSAssigner":
        # Third positional argument passed as None — presumably
        # gt_bboxes_ignore; confirm against the assigner API.
        assign_result = bbox_assigner.assign(proposals, gt_bboxes, None,
                                             gt_labels)
    else:
        # The ATSS assigner additionally takes the number of valid proposals
        # per level. These counts are computed only on this branch, because
        # each count is converted to a Python int, which may force a
        # host-device sync when tensors live on the GPU.
        num_level_proposals_inside = get_num_level_proposals_inside(
            num_level_proposals, inside_flags)
        assign_result = bbox_assigner.assign(proposals,
                                             num_level_proposals_inside,
                                             gt_bboxes, None, gt_labels)
    bbox_sampler = PseudoSampler()
    sampling_result = bbox_sampler.sample(assign_result, proposals, gt_bboxes)

    num_valid_proposals = proposals.shape[0]
    bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
    bbox_weights = proposals.new_zeros([num_valid_proposals, 4])
    labels = proposals.new_zeros(num_valid_proposals, dtype=torch.long)
    label_weights = proposals.new_zeros(num_valid_proposals, dtype=torch.float)

    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    if len(pos_inds) > 0:
        pos_gt_bboxes = sampling_result.pos_gt_bboxes
        bbox_gt[pos_inds, :] = pos_gt_bboxes
        bbox_weights[pos_inds, :] = 1.0
        if gt_labels is None:
            labels[pos_inds] = 1
        else:
            labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        # A non-positive pos_weight means "use the default weight of 1".
        if cfg.pos_weight <= 0:
            label_weights[pos_inds] = 1.0
        else:
            label_weights[pos_inds] = cfg.pos_weight
    if len(neg_inds) > 0:
        label_weights[neg_inds] = 1.0

    # Map up to the original set of grids.
    if unmap_outputs:
        num_total_proposals = flat_proposals.size(0)
        labels = unmap(labels, num_total_proposals, inside_flags)
        label_weights = unmap(label_weights, num_total_proposals, inside_flags)
        bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
        bbox_weights = unmap(bbox_weights, num_total_proposals, inside_flags)

    return labels, label_weights, bbox_gt, bbox_weights, pos_inds, neg_inds


def unmap(data, count, inds, fill=0):
    """Scatter ``data`` into a new tensor of length ``count`` along dim 0.

    Positions selected by ``inds`` receive the entries (rows) of ``data``.
    Every other position is set to ``fill``. The result has the same dtype
    and device as ``data``.
    """
    if data.dim() != 1:
        # Keep the trailing dimensions of data; only dim 0 is expanded.
        full_shape = (count,) + data.size()[1:]
        out = data.new_full(full_shape, fill)
        out[inds, :] = data
        return out
    out = data.new_full((count,), fill)
    out[inds] = data
    return out


def get_num_level_proposals_inside(num_level_proposals, inside_flags):
    """Count the valid proposals on each level.

    Args:
        num_level_proposals (list[int]): Number of proposals per level. The
            sizes must sum to the length of ``inside_flags`` along dim 0,
            as required by ``Tensor.split``.
        inside_flags (Tensor): Valid flags for the proposals of all levels,
            concatenated along dim 0.

    Returns:
        list[int]: Number of set flags within each level's chunk.
    """
    per_level_flags = inside_flags.split(num_level_proposals)
    return [int(level_flags.sum()) for level_flags in per_level_flags]
