# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmrazor.registry import HOOKS


# Fine-grained variant: stop mask learning based on the training iteration.
@HOOKS.register_module()
class StopMaskLearningIterHook(Hook):
    """Stop mask learning once training reaches a given iteration.

    Before every training iteration this hook compares ``runner.iter`` with
    ``stop_iter``. Once ``runner.iter >= stop_iter``, it sets
    ``mask_learning_stopped = True`` on the model's ``distiller``. It logs a
    message only on the call where the flag actually flips.

    Args:
        stop_iter (int): Stop mask learning at this iter. The flag is
            already set before the iteration whose ``runner.iter`` equals
            ``stop_iter``.
    """

    priority = 'HIGH'

    def __init__(self, stop_iter: int) -> None:
        self.stop_iter = stop_iter

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch=None) -> None:
        """Stop mask learning once ``stop_iter`` has been reached.

        Args:
            runner: The runner driving training; its ``iter``, ``model`` and
                ``logger`` attributes are used.
            batch_idx (int): Index of the current batch. Unused here.
            data_batch: The current data batch. Unused here.

        Raises:
            AttributeError: If the (unwrapped) model's distiller has no
                ``mask_learning_stopped`` attribute.
        """
        # Nothing to do until the stop iteration is reached.
        if runner.iter < self.stop_iter:
            return

        model = runner.model
        # TODO: refactor after mmengine using model wrapper
        if is_model_wrapper(model):
            model = model.module
        distiller = model.distiller

        # Explicit check rather than ``assert`` so the misconfiguration is
        # still reported when Python runs with ``-O`` (asserts stripped).
        if not hasattr(distiller, 'mask_learning_stopped'):
            raise AttributeError(
                f'{type(distiller).__name__} has no attribute '
                "'mask_learning_stopped'; StopMaskLearningIterHook requires "
                'a distiller that supports stopping mask learning.')

        # Only log and flip the flag once; later iterations are no-ops.
        if not distiller.mask_learning_stopped:
            runner.logger.info('Mask learning has been stopped!')
            distiller.mask_learning_stopped = True
