Skip to content

llmcompressor.modifiers.stage

StageModifiers

Bases: ModifierInterface, BaseModel

Represents a collection of modifiers that are applied together as a stage.

Parameters:

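| Name | Type | Description | Default |
|------|------|-------------|---------|
| `modifiers` | `List[Modifier]` | The modifiers to apply as a stage | `[]` (a new empty list) |
| `index` | `Optional[int]` | The index of the stage, if applicable | `None` |
| `group` | `Optional[str]` | The group name of the stage, if applicable | `None` |
| `applied` | `bool` | Flag indicating whether this stage has already been applied to the model through finalization | `False` |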
Source code in src/llmcompressor/modifiers/stage.py
class StageModifiers(ModifierInterface, BaseModel):
    """
    Represents a collection of modifiers that are applied together as a stage.

    :param modifiers: The modifiers to apply as a stage
    :param index: The index of the stage, if applicable
    :param group: The group name of the stage, if applicable
    :param applied: Flag for indicating if this stage has already been
        applied to the model through finalization
    """

    # default_factory gives every instance its own fresh list
    modifiers: List["Modifier"] = Field(default_factory=list)
    index: Optional[int] = None
    group: Optional[str] = None
    # set to True by finalize(); once set, initialize/finalize/update_event
    # all become no-ops for this stage
    applied: bool = False

    @property
    def initialized(self) -> bool:
        """
        :return: True if all of the stage modifiers have been initialized,
            False otherwise (a stage with no modifiers reports True)
        """
        return all(mod.initialized for mod in self.modifiers)

    @property
    def finalized(self) -> bool:
        """
        :return: True if all of the stage modifiers have been finalized,
            False otherwise (a stage with no modifiers reports True)
        """
        return all(mod.finalized for mod in self.modifiers)

    @property
    def unique_id(self) -> str:
        """
        :return: ID for the stage in the form ``"<group>_<index>"``; a group
            or index that is None is rendered as ``"None"`` instead of
            raising
        """
        # group is Optional[str] and defaults to None; the previous
        # ``self.group + "_" + str(self.index)`` raised TypeError in that
        # case. Format both parts the same way index was already handled.
        return f"{self.group}_{self.index}"

    def initialize(self, state: "State", **kwargs):
        """
        Initialize all the stage modifiers that are not initialized yet.
        Does nothing if the stage has already been applied.

        :param state: The state of current session
        :param kwargs: Additional kwargs to pass to the modifier(s)
            initialize method; an optional ``accelerator`` entry has its
            ``wait_for_everyone()`` called after each modifier
        """

        if self.applied:
            return

        accelerator = kwargs.get("accelerator", None)
        for modifier in self.modifiers:
            if not modifier.initialized:
                modifier.initialize(state, **kwargs)
            # NOTE(review): presumably a cross-process barrier (e.g. HF
            # Accelerate) so all ranks stay in step per modifier — confirm
            if accelerator:
                accelerator.wait_for_everyone()
        state.loggers.system.info(tag="stage", string="Modifiers initialized")

    def finalize(self, state: "State", **kwargs):
        """
        Finalize all the stage modifiers and mark the stage as applied.
        Does nothing if the stage has already been applied.

        :param state: The state of current session
        :param kwargs: Additional kwargs to pass to the modifier(s)
            finalize method; an optional ``accelerator`` entry has its
            ``wait_for_everyone()`` called after each modifier
        """

        if self.applied:
            return

        accelerator = kwargs.get("accelerator", None)
        for modifier in self.modifiers:
            modifier.finalize(state, **kwargs)
            # NOTE(review): presumably a cross-process barrier — confirm
            if accelerator:
                accelerator.wait_for_everyone()

        self.applied = True
        state.loggers.system.info(tag="stage", string="Modifiers finalized")

    def update_event(self, state: "State", event: "Event", **kwargs):
        """
        Propagate the event to all the stage modifiers, in order.
        Does nothing if the stage has already been applied.

        :param state: The state of current session
        :param event: The event to propagate
        :param kwargs: Additional kwargs to pass to the modifier(s)
            update_event method
        """

        if self.applied:
            return

        for modifier in self.modifiers:
            modifier.update_event(state, event, **kwargs)

finalized property

Returns:

| Type | Description |
|------|-------------|
| `bool` | `True` if all of the stage modifiers have been finalized, `False` otherwise |

initialized property

Returns:

| Type | Description |
|------|-------------|
| `bool` | `True` if all of the stage modifiers have been initialized, `False` otherwise |

unique_id property

Returns:

| Type | Description |
|------|-------------|
| `str` | ID for the stage, formed from its group name and index joined by an underscore (`<group>_<index>`) |

finalize(state, **kwargs)

Finalize all the stage modifiers and mark the stage as applied

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `State` | The state of the current session | required |
| `kwargs` | | Additional kwargs to pass to the modifier(s) `finalize` method | `{}` |
Source code in src/llmcompressor/modifiers/stage.py
def finalize(self, state: "State", **kwargs):
    """
    Finalize every modifier of the stage in order, then flag the stage as
    applied. A stage that is already applied is left untouched.

    :param state: The state of current session
    :param kwargs: Additional kwargs to pass to the modifier(s)
        finalize method; when an ``accelerator`` entry is truthy, its
        ``wait_for_everyone()`` is called after each modifier
    """
    if not self.applied:
        sync = kwargs.get("accelerator")
        for mod in self.modifiers:
            mod.finalize(state, **kwargs)
            if sync:
                sync.wait_for_everyone()
        self.applied = True
        state.loggers.system.info(tag="stage", string="Modifiers finalized")

initialize(state, **kwargs)

Initialize all the stage modifiers

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `State` | The state of the current session | required |
| `kwargs` | | Additional kwargs to pass to the modifier(s) `initialize` method | `{}` |
Source code in src/llmcompressor/modifiers/stage.py
def initialize(self, state: "State", **kwargs):
    """
    Initialize each stage modifier that reports it is not yet initialized.
    A stage that is already applied is skipped entirely.

    :param state: The state of current session
    :param kwargs: Additional kwargs to pass to the modifier(s)
        initialize method; when an ``accelerator`` entry is truthy, its
        ``wait_for_everyone()`` is called after every modifier, including
        ones that were already initialized
    """
    if self.applied:
        return

    sync = kwargs.get("accelerator")
    for mod in self.modifiers:
        needs_init = not mod.initialized
        if needs_init:
            mod.initialize(state, **kwargs)
        if sync:
            sync.wait_for_everyone()

    state.loggers.system.info(tag="stage", string="Modifiers initialized")

update_event(state, event, **kwargs)

Propagate the event to all the stage modifiers

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state` | `State` | The state of the current session | required |
| `event` | `Event` | The event to propagate | required |
| `kwargs` | | Additional kwargs to pass to the modifier(s) `update_event` method | `{}` |
Source code in src/llmcompressor/modifiers/stage.py
def update_event(self, state: "State", event: "Event", **kwargs):
    """
    Forward the event to each stage modifier in order, unless the stage
    has already been applied, in which case nothing is forwarded.

    :param state: The state of current session
    :param event: The event to propagate
    :param kwargs: Additional kwargs to pass to the modifier(s)
        update_event method
    """
    if not self.applied:
        for mod in self.modifiers:
            mod.update_event(state, event, **kwargs)