""" Reduce operators"""
from .base import *


class BoundReduceMax(Bound):
    """Bound operator for a max-reduction over a single axis.

    Backward bound propagation is only supported under the fixed-max-index
    assumption (option ``fixed_reducemax_index``); otherwise
    ``bound_backward`` raises ``NotImplementedError``.
    """

    def __init__(self, attr, inputs, output_index, options):
        super().__init__(attr, inputs, output_index, options)
        self.axis = attr['axes']
        # for torch.max, `dim` must be an int
        if isinstance(self.axis, list):
            assert len(self.axis) == 1
            self.axis = self.axis[0]
        self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True
        self.use_default_ibp = True

        # Assume that the indexes with the maximum values are not perturbed.
        # This generally doesn't hold true, but can still be used for the
        # input shift in Softmax of Transformers.
        self.fixed_max_index = options.get('fixed_reducemax_index', False)

    def forward(self, x):
        """Return the max of ``x`` along ``self.axis``.

        The argmax positions are cached in ``self.indices`` so that
        ``bound_backward`` can route coefficients back to them.
        """
        self.axis = self.make_axis_non_negative(self.axis)
        # Reducing over axis 0 is not supported (presumably the batch dim).
        assert self.axis > 0
        res = torch.max(x, dim=self.axis, keepdim=self.keepdim)
        self.indices = res.indices
        return res.values

    def bound_backward(self, last_lA, last_uA, x):
        """Propagate linear coefficients back through the max.

        With ``fixed_max_index`` enabled, the max is treated as selecting the
        element at the argmax recorded by ``forward``: each coefficient of
        ``last_A`` is placed at that position and all other positions get 0.

        Returns ``[(lA, uA)], 0, 0``; the trailing zeros are presumably the
        lower/upper bias terms, none of which are introduced here.

        Raises:
            NotImplementedError: if ``fixed_max_index`` is disabled.
        """
        if not self.fixed_max_index:
            raise NotImplementedError('`bound_backward` for BoundReduceMax with perturbed maximum indexes is not implemented.')

        def _bound_oneside(last_A):
            if last_A is None:
                return None
            # Prepend a singleton dim so the indices line up with the leading
            # (specification) dim of last_A.
            indices = self.indices.unsqueeze(0)
            if not self.keepdim:
                assert self.from_input
                # Restore the reduced axis as a size-1 dim (offset by 1 for
                # the leading dim of last_A).
                last_A = last_A.unsqueeze(self.axis + 1)
                indices = indices.unsqueeze(self.axis + 1)
            shape = list(last_A.shape)
            shape[self.axis + 1] *= self.input_shape[self.axis]
            # new_zeros keeps last_A's dtype and device; scatter_ requires the
            # destination and src to share a dtype.
            A = last_A.new_zeros(shape)
            # Broadcast indices across the leading dim of last_A. scatter_ only
            # visits positions present in `index`, so without this only the
            # first leading slice of A would be written.
            indices = indices.expand(*last_A.shape)
            A.scatter_(dim=self.axis + 1, index=indices, src=last_A)
            return A

        return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0


class BoundReduceMean(Bound):
    """Bound operator for a mean-reduction over one or more axes."""

    def __init__(self, attr, inputs, output_index, options):
        super().__init__(attr, inputs, output_index, options)
        self.axis = attr['axes']
        self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True
        self.use_default_ibp = True

    def forward(self, x):
        """Return the mean of ``x`` over ``self.axis``."""
        return torch.mean(x, dim=self.axis, keepdim=self.keepdim)

    def bound_backward(self, last_lA, last_uA, x):
        """Propagate linear coefficients back through the mean.

        The mean is linear, so each output coefficient is broadcast over the
        reduced axis and divided by that axis' size.

        Returns ``[(lA, uA)], 0, 0``; the trailing zeros are presumably the
        lower/upper bias terms, none of which are introduced here.
        """
        # Normalize negative axes into a fresh list rather than mutating
        # attr['axes'] in place (which also fails if axes is a tuple).
        axes = []
        for axis in self.axis:
            if axis < 0:
                axis = self.make_axis_non_negative(axis)
                assert axis > 0
            axes.append(axis)
        # Ascending order is required when keepdim is False: reduced dims are
        # re-inserted one by one, and each unsqueeze index must be valid for
        # the tensor produced by the previous one.
        axes.sort()
        self.axis = axes

        def _bound_oneside(last_A):
            if last_A is None:
                return None
            if not self.keepdim:
                assert self.from_input
                for axis in axes:
                    if axis > 0:
                        # +1 accounts for the leading dim of last_A.
                        last_A = last_A.unsqueeze(axis + 1)
            for axis in axes:
                shape = list(last_A.shape)
                size_axis = self.input_shape[axis]
                shape[axis + 1] *= size_axis
                last_A = last_A.expand(*shape) / size_axis
            return last_A

        return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0

    def bound_forward(self, dim_in, x):
        """Forward-mode linear bounds for a single-axis, keepdim mean.

        The weights ``lw``/``uw`` carry one extra leading dim relative to the
        biases ``lb``/``ub`` (hence ``axis + 1`` vs ``axis``).
        """
        assert self.keepdim
        assert len(self.axis) == 1
        axis = self.make_axis_non_negative(self.axis[0])
        assert axis > 0
        size = self.input_shape[axis]
        lw = x.lw.sum(dim=axis + 1, keepdim=True) / size
        lb = x.lb.sum(dim=axis, keepdim=True) / size
        uw = x.uw.sum(dim=axis + 1, keepdim=True) / size
        ub = x.ub.sum(dim=axis, keepdim=True) / size
        return LinearBound(lw, lb, uw, ub)

class BoundReduceSum(Bound):
    """Bound operator for a sum-reduction over the given axes (or all axes)."""

    def __init__(self, attr, inputs, output_index, options):
        super().__init__(attr, inputs, output_index, options)
        self.axis = attr['axes'] if 'axes' in attr else None
        # Default to keepdim=True when the attribute is absent, matching the
        # other reduce operators (previously this raised KeyError).
        self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True
        self.use_default_ibp = True

    def forward(self, x):
        """Sum ``x`` over ``self.axis``, or over every element if it is None."""
        if self.axis is not None:
            return torch.sum(x, dim=self.axis, keepdim=self.keepdim)
        else:
            return torch.sum(x)

    def bound_backward(self, last_lA, last_uA, x):
        """Propagate linear coefficients back through the sum.

        The sum is linear, so each output coefficient is broadcast unchanged
        over the reduced axis.

        Returns ``[(lA, uA)], 0, 0``; the trailing zeros are presumably the
        lower/upper bias terms, none of which are introduced here.

        Raises:
            NotImplementedError: if no ``axes`` were given (full reduction).
        """
        if self.axis is None:
            raise NotImplementedError(
                '`bound_backward` for BoundReduceSum without `axes` is not implemented.')

        # Normalize negative axes into a fresh list rather than mutating
        # attr['axes'] in place (which also fails if axes is a tuple).
        axes = []
        for axis in self.axis:
            if axis < 0:
                axis = len(self.input_shape) + axis
                assert axis > 0
            axes.append(axis)
        # Ascending order is required when keepdim is False: reduced dims are
        # re-inserted one by one, and each unsqueeze index must be valid for
        # the tensor produced by the previous one.
        axes.sort()
        self.axis = axes

        def _bound_oneside(last_A):
            if last_A is None:
                return None
            if not self.keepdim:
                assert self.from_input
                for axis in axes:
                    if axis > 0:
                        # +1 accounts for the leading dim of last_A.
                        last_A = last_A.unsqueeze(axis + 1)
            for axis in axes:
                shape = list(last_A.shape)
                shape[axis + 1] *= self.input_shape[axis]
                last_A = last_A.expand(*shape)
            return last_A

        return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0

    def bound_forward(self, dim_in, x):
        """Forward-mode linear bounds for a single-axis sum.

        The weights ``lw``/``uw`` carry one extra leading dim relative to the
        biases ``lb``/``ub`` (hence ``axis + 1`` vs ``axis``).
        """
        assert len(self.axis) == 1
        axis = self.make_axis_non_negative(self.axis[0])
        assert axis > 0
        lw = x.lw.sum(dim=axis + 1, keepdim=self.keepdim)
        lb = x.lb.sum(dim=axis, keepdim=self.keepdim)
        uw = x.uw.sum(dim=axis + 1, keepdim=self.keepdim)
        ub = x.ub.sum(dim=axis, keepdim=self.keepdim)
        return LinearBound(lw, lb, uw, ub)
