Skip to content

llmcompressor.modifiers.utils.hooks

HooksMixin

Bases: BaseModel

Mixin to manage hook registration, disabling, and removal. Modifiers should use self.register_hook(module, hook, hook_type) for hook registration and self.remove_hooks() for removal.

Modifiers which implement hooks should register them using self.register_hook(module, hook, hook_type) rather than the usual module.register_..._hook(hook). Modifiers should remove hooks with self.remove_hooks().

Hooks can be applied to modules or parameters

Typical example

>>> modifier.register_hook(module, hook, "forward")
>>> with HooksMixin.disable_hooks():
...     model.forward(...)
>>> modifier.remove_hooks()

Example of activating only a specific subset of hooks

>>> hooks = {modifier.register_hook(module, hook, "forward") for module in ...}
>>> with HooksMixin.disable_hooks(keep=hooks):
...     model.forward(...)
>>> modifier.remove_hooks(hooks)

Source code in src/llmcompressor/modifiers/utils/hooks.py
class HooksMixin(BaseModel):
    """
    Mixin to manage hook registration, disabling, and removal.
    Modifiers should use `self.register_hook(module, hook, hook_type)`
    for hook registration and `self.remove_hooks()` for removal.

    Modifiers which implement hooks should register them using
    `self.register_hook(module, hook, hook_type)` rather than the usual
    `module.register_..._hook(hook)`. Modifiers should remove hooks with
    `self.remove_hooks()`.

    Hooks can be applied to modules or parameters.

    Typical example
    >>> modifier.register_hook(module, hook, "forward")
    >>> with HooksMixin.disable_hooks():
    ...     model.forward(...)
    >>> modifier.remove_hooks()

    Example of activating only a specific subset of hooks
    >>> hooks = {modifier.register_hook(module, hook, "forward") for module in ...}
    >>> with HooksMixin.disable_hooks(keep=hooks):
    ...     model.forward(...)
    >>> modifier.remove_hooks(hooks)
    """

    # attached to global HooksMixin class; always read and written through
    # `HooksMixin` itself (never a subclass) so every wrapped hook sees it
    _HOOKS_DISABLED: ClassVar[bool] = False
    _HOOKS_KEEP_ENABLED: ClassVar[Set[RemovableHandle]] = set()

    # attached to local subclasses; handles registered via `register_hook`
    _hooks: Set[RemovableHandle] = set()

    @classmethod
    @contextlib.contextmanager
    def disable_hooks(cls, keep: Set[RemovableHandle] = frozenset()):
        """
        Disable all hooks across all modifiers. Composing multiple contexts is
        equivalent to the union of `keep` arguments.

        The state is stored on `HooksMixin` itself, which is the class the
        wrapped hooks consult. Calling this through a subclass or an instance
        therefore behaves the same as `HooksMixin.disable_hooks()`. On exit, the
        state that was active on entry is restored. Leaving an inner context
        therefore neither re-enables hooks that an enclosing context disabled
        nor drops handles that the enclosing context keeps.

        :param keep: optional collection of handles (e.g. a set or list) to keep
            enabled while inside the context
        """
        previous_disabled = HooksMixin._HOOKS_DISABLED
        previous_keep = HooksMixin._HOOKS_KEEP_ENABLED
        try:
            HooksMixin._HOOKS_DISABLED = True
            # build a new set rather than mutating in place so the enclosing
            # state can be restored exactly, even when `keep` sets overlap
            HooksMixin._HOOKS_KEEP_ENABLED = previous_keep | set(keep)
            yield
        finally:
            HooksMixin._HOOKS_DISABLED = previous_disabled
            HooksMixin._HOOKS_KEEP_ENABLED = previous_keep

    def register_hook(
        self,
        target: Union[torch.nn.Module, torch.nn.Parameter],
        hook: Callable[..., Any],
        hook_type: str,
        **kwargs,
    ) -> RemovableHandle:
        """
        Registers a hook on a specified module/parameter with the option to disable it
        with HooksMixin.disable_hooks()

        :param target: the module or parameter on which the hook should be registered
        :param hook: the hook to register
        :param hook_type: the type of hook to register corresponding to the
            `register_{hook_type}_hook` attribute on the target.
            Ex. "forward", "forward_pre", "full_backward", "state_dict_post".
            An empty string selects the unprefixed `register_hook` method
            (e.g. `torch.nn.Parameter.register_hook`)
        :param kwargs: keyword arguments to pass to register hook method
        :return: the handle returned by the target's registration method; pass it
            to `remove_hooks` or to `disable_hooks(keep=...)`
        """
        handle = None

        @wraps(hook)
        def wrapped_hook(*hook_args, **hook_kwargs):
            # `handle` is read at call time, after registration has assigned it
            if (
                HooksMixin._HOOKS_DISABLED
                and handle not in HooksMixin._HOOKS_KEEP_ENABLED
            ):
                return None

            return hook(*hook_args, **hook_kwargs)

        # "" maps to `register_hook` rather than the nonexistent `register__hook`
        method_name = f"register_{hook_type}_hook" if hook_type else "register_hook"
        register_function = getattr(target, method_name)
        handle = register_function(wrapped_hook, **kwargs)
        self._hooks.add(handle)
        logger.debug(f"{self} added {handle}")

        return handle

    def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
        """
        Removes hooks registered by this modifier

        :param handles: optional collection of handles to remove (e.g. a set or
            list), defaults to all hooks registered by this modifier
        """
        # Snapshot into a new set. This accepts lists and tuples
        # (`set -= list` raises TypeError) and avoids aliasing `self._hooks`
        # when removing everything.
        handles = set(self._hooks if handles is None else handles)

        for handle in handles:
            handle.remove()

        self._hooks -= handles

disable_hooks(keep=frozenset()) classmethod

Disable all hooks across all modifiers. Composing multiple contexts is equivalent to the union of keep arguments

Parameters:

Name Type Description Default
keep Set[RemovableHandle]

optional set of handles to keep enabled

frozenset()
Source code in src/llmcompressor/modifiers/utils/hooks.py
@classmethod
@contextlib.contextmanager
def disable_hooks(cls, keep: Set[RemovableHandle] = frozenset()):
    """
    Disable all hooks across all modifiers. Composing multiple contexts is
    equivalent to the union of `keep` arguments.

    The state is stored on `HooksMixin` itself, which is the class the wrapped
    hooks consult. Calling this through a subclass or an instance therefore
    behaves the same as `HooksMixin.disable_hooks()`. On exit, the state that
    was active on entry is restored. Leaving an inner context therefore neither
    re-enables hooks that an enclosing context disabled nor drops handles that
    the enclosing context keeps.

    :param keep: optional collection of handles (e.g. a set or list) to keep
        enabled while inside the context
    """
    previous_disabled = HooksMixin._HOOKS_DISABLED
    previous_keep = HooksMixin._HOOKS_KEEP_ENABLED
    try:
        HooksMixin._HOOKS_DISABLED = True
        # build a new set rather than mutating in place so the enclosing
        # state can be restored exactly, even when `keep` sets overlap
        HooksMixin._HOOKS_KEEP_ENABLED = previous_keep | set(keep)
        yield
    finally:
        HooksMixin._HOOKS_DISABLED = previous_disabled
        HooksMixin._HOOKS_KEEP_ENABLED = previous_keep

register_hook(target, hook, hook_type, **kwargs)

Registers a hook on a specified module/parameter with the option to disable it with HooksMixin.disable_hooks()

Parameters:

Name Type Description Default
target Union[Module, Parameter]

the module or parameter on which the hook should be registered

required
hook Callable[[Any], Any]

the hook to register

required
hook_type str

the type of hook to register corresponding to the register_{hook_type}_hook attribute on torch.nn.Module. Ex. "forward", "forward_pre", "full_backward", "state_dict_post", ""

required
kwargs

keyword arguments to pass to register hook method

{}
Source code in src/llmcompressor/modifiers/utils/hooks.py
def register_hook(
    self,
    target: Union[torch.nn.Module, torch.nn.Parameter],
    hook: Callable[..., Any],
    hook_type: str,
    **kwargs,
) -> RemovableHandle:
    """
    Registers a hook on a specified module/parameter with the option to disable it
    with HooksMixin.disable_hooks()

    :param target: the module or parameter on which the hook should be registered
    :param hook: the hook to register
    :param hook_type: the type of hook to register corresponding to the
        `register_{hook_type}_hook` attribute on the target.
        Ex. "forward", "forward_pre", "full_backward", "state_dict_post".
        An empty string selects the unprefixed `register_hook` method
        (e.g. `torch.nn.Parameter.register_hook`)
    :param kwargs: keyword arguments to pass to register hook method
    :return: the handle returned by the target's registration method; pass it
        to `remove_hooks` or to `disable_hooks(keep=...)`
    """
    handle = None

    @wraps(hook)
    def wrapped_hook(*hook_args, **hook_kwargs):
        # `handle` is read at call time, after registration has assigned it
        if (
            HooksMixin._HOOKS_DISABLED
            and handle not in HooksMixin._HOOKS_KEEP_ENABLED
        ):
            return None

        return hook(*hook_args, **hook_kwargs)

    # "" maps to `register_hook` rather than the nonexistent `register__hook`
    method_name = f"register_{hook_type}_hook" if hook_type else "register_hook"
    register_function = getattr(target, method_name)
    handle = register_function(wrapped_hook, **kwargs)
    self._hooks.add(handle)
    logger.debug(f"{self} added {handle}")

    return handle

remove_hooks(handles=None)

Removes hooks registered by this modifier

Parameters:

Name Type Description Default
handles Optional[Set[RemovableHandle]]

optional set of handles to remove; defaults to all hooks registered by this modifier

None
Source code in src/llmcompressor/modifiers/utils/hooks.py
def remove_hooks(self, handles: Optional[Set[RemovableHandle]] = None):
    """
    Removes hooks registered by this modifier

    :param handles: optional collection of handles to remove (e.g. a set or
        list), defaults to all hooks registered by this modifier
    """
    # Snapshot into a new set. This accepts lists and tuples
    # (`set -= list` raises TypeError) and avoids aliasing `self._hooks`
    # when removing everything.
    handles = set(self._hooks if handles is None else handles)

    for handle in handles:
        handle.remove()

    self._hooks -= handles