Skip to content

llmcompressor.modifiers.modifier

Modifier

Bases: ModifierInterface, HooksMixin

A base class for all modifiers to inherit from. Modifiers are used to modify the training process for a model. Defines base attributes and methods available to all modifiers

Lifecycle: 1. initialize 2. on_event -> * on_start if self.start <= event.current_index * on_end if self.end <= event.current_index 3. finalize

Parameters:

Name Type Description Default
index Optional[int]

The index of the modifier in the list of modifiers for the model

None
group Optional[str]

The group name for the modifier

None
start Optional[float]

The start step for the modifier

None
end Optional[float]

The end step for the modifier

None
update Optional[float]

The update step for the modifier

None
Source code in src/llmcompressor/modifiers/modifier.py
class Modifier(ModifierInterface, HooksMixin):
    """
    A base class for all modifiers to inherit from.
    Modifiers are used to modify the training process for a model.
    Defines base attributes and methods available to all modifiers.

    Lifecycle:
    1. initialize -> on_initialize, then on_start if should_start is
       already satisfied for a synthetic BATCH_START event at global step 0
    2. update_event -> on_event for every event, then
        * on_start (followed by on_update) on a BATCH_START event if not
          yet started and self.start <= event.current_index
          (and event.current_index < self.end when end is set)
        * on_end (followed by on_update) on a BATCH_END event if not
          yet ended and self.end <= event.current_index
        * otherwise on_update while started and not yet ended
    3. finalize -> on_finalize

    :param index: The index of the modifier in the list of modifiers
        for the model
    :param group: The group name for the modifier
    :param start: The start step for the modifier; if None, should_start
        always returns False
    :param end: The end step for the modifier; if None, should_end always
        returns False
    :param update: The update step for the modifier
    """

    index: Optional[int] = None
    group: Optional[str] = None
    start: Optional[float] = None
    end: Optional[float] = None
    update: Optional[float] = None

    # lifecycle flags, managed by initialize / finalize / update_event
    initialized_: bool = False
    finalized_: bool = False
    started_: bool = False
    ended_: bool = False

    @property
    def initialized(self) -> bool:
        """
        :return: True if the modifier has been initialized
        """
        return self.initialized_

    @property
    def finalized(self) -> bool:
        """
        :return: True if the modifier has been finalized
        """
        return self.finalized_

    def initialize(self, state: State, **kwargs):
        """
        Initialize the modifier for the given model and state.

        Calls on_initialize and stores its result in ``initialized_``, then
        calls on_start (and sets ``started_``) if should_start is satisfied
        for a synthetic BATCH_START event at global step 0.

        :raises RuntimeError: if the modifier has already been initialized
            or has already been finalized
        :param state: The current state of the model
        :param kwargs: Additional arguments for initializing the modifier
        """
        if self.initialized_:
            raise RuntimeError(
                "Cannot initialize a modifier that has already been initialized"
            )

        if self.finalized_:
            raise RuntimeError(
                "Cannot initialize a modifier that has already been finalized"
            )

        self.initialized_ = self.on_initialize(state=state, **kwargs)

        # trigger starts: a synthetic BATCH_START event at global_step=0 lets
        # modifiers whose start bound is already met begin before any real
        # event arrives (assumes Event.current_index reflects global_step for
        # this event -- confirm).
        # NOTE(review): this runs even when on_initialize returned False;
        # confirm that is intended.
        fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
        if self.should_start(fake_start_event):
            self.on_start(state, fake_start_event, **kwargs)
            self.started_ = True

    def finalize(self, state: State, **kwargs):
        """
        Finalize the modifier for the given model and state.

        Calls on_finalize and stores its result in ``finalized_``.

        :raises RuntimeError: if the modifier has already been finalized
            or has not been initialized
        :param state: The current state of the model
        :param kwargs: Additional arguments for finalizing the modifier
        """
        if self.finalized_:
            raise RuntimeError("cannot finalize a modifier twice")

        if not self.initialized_:
            raise RuntimeError("cannot finalize an uninitialized modifier")

        # TODO: all finalization should succeed
        self.finalized_ = self.on_finalize(state=state, **kwargs)

    def update_event(self, state: State, event: Event, **kwargs):
        """
        Update the modifier based on the given event.

        Always calls on_event first. Then:
        * on a BATCH_START event, if not yet started and should_start is
          True: calls on_start, marks the modifier started, calls on_update
        * on a BATCH_END event, if not yet ended and should_end is True:
          calls on_end, marks the modifier ended, calls on_update
        * otherwise, calls on_update only while started and not yet ended

        :raises RuntimeError: if the modifier has not been initialized or
            has already been finalized
        :param state: The current state of sparsification
        :param event: The event to update the modifier with
        :param kwargs: Additional arguments for updating the modifier
        """
        if not self.initialized_:
            raise RuntimeError("Cannot update an uninitialized modifier")

        if self.finalized_:
            raise RuntimeError("Cannot update a finalized modifier")

        self.on_event(state, event, **kwargs)

        # handle starting the modifier if needed
        if (
            event.type_ == EventType.BATCH_START
            and not self.started_
            and self.should_start(event)
        ):
            self.on_start(state, event, **kwargs)
            self.started_ = True
            self.on_update(state, event, **kwargs)

            return

        # handle ending the modifier if needed
        if (
            event.type_ == EventType.BATCH_END
            and not self.ended_
            and self.should_end(event)
        ):
            self.on_end(state, event, **kwargs)
            self.ended_ = True
            # one final update is delivered together with the end event
            self.on_update(state, event, **kwargs)

            return

        if self.started_ and not self.ended_:
            self.on_update(state, event, **kwargs)

    def should_start(self, event: Event) -> bool:
        """
        :param event: The event to check if the modifier should start
        :return: True if ``start`` is set, ``start <= event.current_index``,
            and (when ``end`` is set) ``event.current_index < end``;
            False otherwise
        """
        if self.start is None:
            return False

        current = event.current_index

        return self.start <= current and (self.end is None or current < self.end)

    def should_end(self, event: Event) -> bool:
        """
        :param event: The event to check if the modifier should end
        :return: True if ``end`` is set and ``event.current_index >= end``;
            False otherwise
        """
        current = event.current_index

        return self.end is not None and current >= self.end

    @abstractmethod
    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        on_initialize is called on modifier initialization and
        must be implemented by the inheriting modifier.

        :param state: The current state of the model
        :param kwargs: Additional arguments for initializing the modifier
        :return: True if the modifier was initialized successfully,
            False otherwise
        """
        raise NotImplementedError()

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        on_finalize is called on modifier finalization. The default
        implementation does no work and returns True; inheriting
        modifiers may override it.

        :param state: The current state of the model
        :param kwargs: Additional arguments for finalizing the modifier
        :return: True if the modifier was finalized successfully,
            False otherwise
        """
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        """
        on_start is called when the modifier starts. The default
        implementation does nothing; inheriting modifiers may override it.

        :param state: The current state of the model
        :param event: The event that triggered the start
        :param kwargs: Additional arguments for starting the modifier
        """
        pass

    def on_update(self, state: State, event: Event, **kwargs):
        """
        on_update is called when the model in question must be
        updated based on the passed-in event. The default implementation
        does nothing; inheriting modifiers may override it.

        :param state: The current state of the model
        :param event: The event that triggered the update
        :param kwargs: Additional arguments for updating the model
        """
        pass

    def on_end(self, state: State, event: Event, **kwargs):
        """
        on_end is called when the modifier ends. The default
        implementation does nothing; inheriting modifiers may override it.

        :param state: The current state of the model
        :param event: The event that triggered the end
        :param kwargs: Additional arguments for ending the modifier
        """
        pass

    def on_event(self, state: State, event: Event, **kwargs):
        """
        on_event is called by update_event for every event, before any
        start/end/update handling. The default implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the update
        :param kwargs: Additional arguments for updating the model
        """
        pass

finalized property

Returns:

Type Description
bool

True if the modifier has been finalized

initialized property

Returns:

Type Description
bool

True if the modifier has been initialized

finalize(state, **kwargs)

Finalize the modifier for the given model and state.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for finalizing the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has not been initialized or has already been finalized

Source code in src/llmcompressor/modifiers/modifier.py
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier for the given model and state.

    Calls on_finalize and stores its result in ``finalized_``.

    :raises RuntimeError: if the modifier has already been finalized
        or has not been initialized
    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    """
    if self.finalized_:
        raise RuntimeError("cannot finalize a modifier twice")

    if not self.initialized_:
        raise RuntimeError("cannot finalize an uninitialized modifier")

    # TODO: all finalization should succeed
    self.finalized_ = self.on_finalize(state=state, **kwargs)

initialize(state, **kwargs)

Initialize the modifier for the given model and state.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for initializing the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has already been initialized or has already been finalized

Source code in src/llmcompressor/modifiers/modifier.py
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier for the given model and state.

    Calls on_initialize and stores its result in ``initialized_``, then
    calls on_start (and sets ``started_``) if should_start is satisfied
    for a synthetic BATCH_START event at global step 0.

    :raises RuntimeError: if the modifier has already been initialized
        or has already been finalized
    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    """
    if self.initialized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been initialized"
        )

    if self.finalized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been finalized"
        )

    self.initialized_ = self.on_initialize(state=state, **kwargs)

    # trigger starts: a synthetic BATCH_START event at global_step=0 lets
    # modifiers whose start bound is already met begin before any real
    # event arrives (assumes Event.current_index reflects global_step for
    # this event -- confirm).
    # NOTE(review): this runs even when on_initialize returned False;
    # confirm that is intended.
    fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
    if self.should_start(fake_start_event):
        self.on_start(state, fake_start_event, **kwargs)
        self.started_ = True

on_end(state, event, **kwargs)

on_end is called when the modifier ends. The default implementation does nothing; inheriting modifiers may override it.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the end

required
kwargs

Additional arguments for ending the modifier

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Hook invoked when the modifier reaches its end.

    The base implementation is a no-op; subclasses override it to perform
    any work that should happen when the modifier stops.

    :param state: the current state of the model
    :param event: the event that caused the modifier to end
    :param kwargs: any extra keyword arguments supplied by the caller
    """
    return None

on_event(state, event, **kwargs)

on_event is called whenever an event is triggered

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the update

required
kwargs

Additional arguments for updating the model

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_event(self, state: State, event: Event, **kwargs):
    """
    Hook called by update_event for each event, before any start, end,
    or update handling takes place.

    The base implementation is a no-op.

    :param state: the current state of the model
    :param event: the event being dispatched
    :param kwargs: any extra keyword arguments supplied by the caller
    """
    return None

on_finalize(state, **kwargs)

on_finalize is called on modifier finalization. The default implementation does no work and returns True; inheriting modifiers may override it.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for finalizing the modifier

{}

Returns:

Type Description
bool

True if the modifier was finalized successfully, False otherwise

Source code in src/llmcompressor/modifiers/modifier.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    Subclass hook that performs the modifier's teardown.

    Called by finalize, which stores the result as the finalized flag.
    The base implementation has nothing to clean up and reports success;
    override it when a modifier needs custom finalization.

    :param state: the current state of the model
    :param kwargs: extra keyword arguments passed through from finalize
    :return: True when teardown succeeded, False otherwise
    """
    succeeded = True
    return succeeded

on_initialize(state, **kwargs) abstractmethod

on_initialize is called on modifier initialization and must be implemented by the inheriting modifier.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for initializing the modifier

{}

Returns:

Type Description
bool

True if the modifier was initialized successfully, False otherwise

Source code in src/llmcompressor/modifiers/modifier.py
@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Subclass hook that performs the modifier's setup.

    Called by initialize, which stores the result as the initialized
    flag. Every concrete modifier has to override this method.

    :param state: the current state of the model
    :param kwargs: extra keyword arguments passed through from initialize
    :return: True when setup succeeded, False otherwise
    """
    # abstract: concrete modifiers supply the real implementation
    raise NotImplementedError()

on_start(state, event, **kwargs)

on_start is called when the modifier starts. The default implementation does nothing; inheriting modifiers may override it.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the start

required
kwargs

Additional arguments for starting the modifier

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    Hook invoked when the modifier begins.

    The base implementation is a no-op; subclasses override it to perform
    any work that should happen when the modifier starts.

    :param state: the current state of the model
    :param event: the event that caused the modifier to start
    :param kwargs: any extra keyword arguments supplied by the caller
    """
    return None

on_update(state, event, **kwargs)

on_update is called when the model in question must be updated based on the passed-in event. The default implementation does nothing; inheriting modifiers may override it.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the update

required
kwargs

Additional arguments for updating the model

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_update(self, state: State, event: Event, **kwargs):
    """
    Hook invoked whenever the model should be updated in response to an
    incoming event.

    The base implementation is a no-op; subclasses override it to apply
    their per-event changes to the model.

    :param state: the current state of the model
    :param event: the event that prompted the update
    :param kwargs: any extra keyword arguments supplied by the caller
    """
    return None

should_end(event)

Parameters:

Name Type Description Default
event Event

The event to check if the modifier should end

required

Returns:

Type Description
bool

True if the modifier should end based on the given event

Source code in src/llmcompressor/modifiers/modifier.py
def should_end(self, event: Event) -> bool:
    """
    Check whether the modifier should end at the given event.

    :param event: The event to check if the modifier should end
    :return: True if ``end`` is set and ``event.current_index >= end``;
        False if ``end`` is None or has not been reached yet
    """
    current = event.current_index

    return self.end is not None and current >= self.end

should_start(event)

Parameters:

Name Type Description Default
event Event

The event to check if the modifier should start

required

Returns:

Type Description
bool

True if the modifier should start based on the given event

Source code in src/llmcompressor/modifiers/modifier.py
def should_start(self, event: Event) -> bool:
    """
    Decide whether this modifier should begin at the given event.

    :param event: the event whose ``current_index`` is compared against
        the modifier's ``start`` and ``end`` bounds
    :return: True when ``start`` is set, ``current_index`` has reached
        ``start``, and (if ``end`` is set) ``current_index`` is still
        below ``end``; False otherwise
    """
    begin = self.start
    if begin is None:
        return False

    idx = event.current_index
    if not begin <= idx:
        return False

    stop = self.end
    return stop is None or idx < stop

update_event(state, event, **kwargs)

Update the modifier based on the given event. Always calls on_event first, then calls on_start, on_update, and on_end based on the event and the modifier settings. Raises RuntimeError if the modifier has not been initialized or has already been finalized.

Parameters:

Name Type Description Default
state State

The current state of sparsification

required
event Event

The event to update the modifier with

required
kwargs

Additional arguments for updating the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has not been initialized or has already been finalized

Source code in src/llmcompressor/modifiers/modifier.py
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update the modifier based on the given event.

    Always calls on_event first. Then:
    * on a BATCH_START event, if not yet started and should_start is
      True: calls on_start, marks the modifier started, calls on_update
    * on a BATCH_END event, if not yet ended and should_end is True:
      calls on_end, marks the modifier ended, calls on_update
    * otherwise, calls on_update only while started and not yet ended

    :raises RuntimeError: if the modifier has not been initialized or
        has already been finalized
    :param state: The current state of sparsification
    :param event: The event to update the modifier with
    :param kwargs: Additional arguments for updating the modifier
    """
    if not self.initialized_:
        raise RuntimeError("Cannot update an uninitialized modifier")

    if self.finalized_:
        raise RuntimeError("Cannot update a finalized modifier")

    self.on_event(state, event, **kwargs)

    # handle starting the modifier if needed
    if (
        event.type_ == EventType.BATCH_START
        and not self.started_
        and self.should_start(event)
    ):
        self.on_start(state, event, **kwargs)
        self.started_ = True
        self.on_update(state, event, **kwargs)

        return

    # handle ending the modifier if needed
    if (
        event.type_ == EventType.BATCH_END
        and not self.ended_
        and self.should_end(event)
    ):
        self.on_end(state, event, **kwargs)
        self.ended_ = True
        # one final update is delivered together with the end event
        self.on_update(state, event, **kwargs)

        return

    if self.started_ and not self.ended_:
        self.on_update(state, event, **kwargs)