Skip to content

llmcompressor.modifiers

Modifier

Bases: ModifierInterface, HooksMixin

A base class for all modifiers to inherit from. Modifiers are used to modify the training process for a model. Defines base attributes and methods available to all modifiers

Lifecycle: 1. initialize 2. on_event -> * on_start if self.start <= event.current_index * on_end if self.end <= event.current_index 3. finalize

Parameters:

Name Type Description Default
index

The index of the modifier in the list of modifiers for the model

required
group

The group name for the modifier

required
start

The start step for the modifier

required
end

The end step for the modifier

required
update

The update step for the modifier

required
Source code in src/llmcompressor/modifiers/modifier.py
class Modifier(ModifierInterface, HooksMixin):
    """
    A base class for all modifiers to inherit from.
    Modifiers are used to modify the training process for a model.
    Defines base attributes and methods available to all modifiers.

    Lifecycle:
    1. initialize
    2. on_event ->
        * on_start if self.start <= event.current_index
        * on_end if self.end <= event.current_index
    3. finalize

    :param index: The index of the modifier in the list of modifiers
        for the model
    :param group: The group name for the modifier
    :param start: The start step for the modifier
    :param end: The end step for the modifier
    :param update: The update step for the modifier
    """

    index: Optional[int] = None
    group: Optional[str] = None
    start: Optional[float] = None
    end: Optional[float] = None
    update: Optional[float] = None

    # lifecycle status flags, set by initialize / finalize / update_event
    initialized_: bool = False
    finalized_: bool = False
    started_: bool = False
    ended_: bool = False

    @property
    def initialized(self) -> bool:
        """
        :return: True if the modifier has been initialized
        """
        return self.initialized_

    @property
    def finalized(self) -> bool:
        """
        :return: True if the modifier has been finalized
        """
        return self.finalized_

    def initialize(self, state: State, **kwargs):
        """
        Initialize the modifier for the given model and state.

        Calls on_initialize and stores its result as the initialized flag.
        If initialization succeeded and should_start accepts a synthetic
        BATCH_START event at global_step=0, on_start is triggered as well.

        :raises RuntimeError: if the modifier has already been initialized
            or has already been finalized
        :param state: The current state of the model
        :param kwargs: Additional arguments for initializing the modifier
        """
        if self.initialized_:
            raise RuntimeError(
                "Cannot initialize a modifier that has already been initialized"
            )

        if self.finalized_:
            raise RuntimeError(
                "Cannot initialize a modifier that has already been finalized"
            )

        self.initialized_ = self.on_initialize(state=state, **kwargs)

        # never start a modifier whose initialization failed: it could not be
        # updated or finalized afterwards, and a retried initialize would
        # otherwise call on_start a second time
        if not self.initialized_:
            return

        # trigger starts using a synthetic BATCH_START event at global_step=0
        fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
        if self.should_start(fake_start_event):
            self.on_start(state, fake_start_event, **kwargs)
            self.started_ = True

    def finalize(self, state: State, **kwargs):
        """
        Finalize the modifier for the given model and state.

        :raises RuntimeError: if the modifier has already been finalized or
            has not been initialized
        :param state: The current state of the model
        :param kwargs: Additional arguments for finalizing the modifier
        """
        if self.finalized_:
            raise RuntimeError("cannot finalize a modifier twice")

        if not self.initialized_:
            raise RuntimeError("cannot finalize an uninitialized modifier")

        # TODO: all finalization should succeed
        self.finalized_ = self.on_finalize(state=state, **kwargs)

    def update_event(self, state: State, event: Event, **kwargs):
        """
        Update modifier based on the given event. The event is always
        forwarded to on_event; on_start, on_end and on_update are then
        called depending on the event type and the modifier's start/end
        settings.

        :raises RuntimeError: if the modifier has not been initialized or
            has already been finalized
        :param state: The current state of sparsification
        :param event: The event to update the modifier with
        :param kwargs: Additional arguments for updating the modifier
        """
        if not self.initialized_:
            raise RuntimeError("Cannot update an uninitialized modifier")

        if self.finalized_:
            raise RuntimeError("Cannot update a finalized modifier")

        self.on_event(state, event, **kwargs)

        # handle starting the modifier if needed (only on BATCH_START events)
        if (
            event.type_ == EventType.BATCH_START
            and not self.started_
            and self.should_start(event)
        ):
            self.on_start(state, event, **kwargs)
            self.started_ = True
            self.on_update(state, event, **kwargs)

            return

        # handle ending the modifier if needed (only on BATCH_END events)
        if (
            event.type_ == EventType.BATCH_END
            and not self.ended_
            and self.should_end(event)
        ):
            self.on_end(state, event, **kwargs)
            self.ended_ = True
            self.on_update(state, event, **kwargs)

            return

        # otherwise update only while the modifier is active
        if self.started_ and not self.ended_:
            self.on_update(state, event, **kwargs)

    def should_start(self, event: Event) -> bool:
        """
        :param event: The event to check if the modifier should start
        :return: True if a start step is set and the event's current_index
            lies in [start, end) (end unbounded when None)
        """
        if self.start is None:
            return False

        current = event.current_index

        return self.start <= current and (self.end is None or current < self.end)

    def should_end(self, event: Event) -> bool:
        """
        :param event: The event to check if the modifier should end
        :return: True if an end step is set and the event's current_index
            has reached or passed it
        """
        current = event.current_index

        return self.end is not None and current >= self.end

    @abstractmethod
    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        on_initialize is called on modifier initialization and
        must be implemented by the inheriting modifier.

        :param state: The current state of the model
        :param kwargs: Additional arguments for initializing the modifier
        :return: True if the modifier was initialized successfully,
            False otherwise
        """
        raise NotImplementedError()

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        on_finalize is called on modifier finalization and may be
        overridden by the inheriting modifier. The default implementation
        does nothing and reports success.

        :param state: The current state of the model
        :param kwargs: Additional arguments for finalizing the modifier
        :return: True if the modifier was finalized successfully,
            False otherwise
        """
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        """
        on_start is called when the modifier starts and may be overridden
        by the inheriting modifier. The default implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the start
        :param kwargs: Additional arguments for starting the modifier
        """
        pass

    def on_update(self, state: State, event: Event, **kwargs):
        """
        on_update is called when the model in question must be
        updated based on the passed-in event. It may be overridden by the
        inheriting modifier; the default implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the update
        :param kwargs: Additional arguments for updating the model
        """
        pass

    def on_end(self, state: State, event: Event, **kwargs):
        """
        on_end is called when the modifier ends and may be overridden
        by the inheriting modifier. The default implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the end
        :param kwargs: Additional arguments for ending the modifier
        """
        pass

    def on_event(self, state: State, event: Event, **kwargs):
        """
        on_event is called whenever an event is triggered. The default
        implementation does nothing.

        :param state: The current state of the model
        :param event: The event that triggered the update
        :param kwargs: Additional arguments for updating the model
        """
        pass

finalized property

Returns:

Type Description
bool

True if the modifier has been finalized

initialized property

Returns:

Type Description
bool

True if the modifier has been initialized

finalize(state, **kwargs)

Finalize the modifier for the given model and state.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for finalizing the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has not been initialized or has already been finalized

Source code in src/llmcompressor/modifiers/modifier.py
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier for the given model and state.

    Delegates to on_finalize and stores its result as the finalized flag.

    :raises RuntimeError: if the modifier has already been finalized or
        has not been initialized
    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    """
    if self.finalized_:
        raise RuntimeError("cannot finalize a modifier twice")

    if not self.initialized_:
        raise RuntimeError("cannot finalize an uninitialized modifier")

    # TODO: all finalization should succeed
    self.finalized_ = self.on_finalize(state=state, **kwargs)

initialize(state, **kwargs)

Initialize the modifier for the given model and state.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for initializing the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has already been initialized or has already been finalized

Source code in src/llmcompressor/modifiers/modifier.py
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier for the given model and state.

    Calls on_initialize and stores its result as the initialized flag.
    If initialization succeeded and should_start accepts a synthetic
    BATCH_START event at global_step=0, on_start is triggered as well.

    :raises RuntimeError: if the modifier has already been initialized
        or has already been finalized
    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    """
    if self.initialized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been initialized"
        )

    if self.finalized_:
        raise RuntimeError(
            "Cannot initialize a modifier that has already been finalized"
        )

    self.initialized_ = self.on_initialize(state=state, **kwargs)

    # never start a modifier whose initialization failed: it could not be
    # updated or finalized afterwards, and a retried initialize would
    # otherwise call on_start a second time
    if not self.initialized_:
        return

    # trigger starts using a synthetic BATCH_START event at global_step=0
    fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
    if self.should_start(fake_start_event):
        self.on_start(state, fake_start_event, **kwargs)
        self.started_ = True

on_end(state, event, **kwargs)

on_end is called when the modifier ends and may be overridden by the inheriting modifier; the default implementation does nothing.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the end

required
kwargs

Additional arguments for ending the modifier

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    on_end is called when the modifier ends and may be overridden
    by the inheriting modifier. The default implementation does nothing.

    :param state: The current state of the model
    :param event: The event that triggered the end
    :param kwargs: Additional arguments for ending the modifier
    """
    pass

on_event(state, event, **kwargs)

on_event is called whenever an event is triggered

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the update

required
kwargs

Additional arguments for updating the model

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_event(self, state: State, event: Event, **kwargs):
    """
    Hook that receives every event delivered to the modifier.

    The base implementation ignores the event entirely; inheriting
    modifiers may override it to react to events of interest.

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    return None

on_finalize(state, **kwargs)

on_finalize is called on modifier finalization and may be overridden by the inheriting modifier; the default implementation does nothing and returns True.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for finalizing the modifier

{}

Returns:

Type Description
bool

True if the modifier was finalized successfully, False otherwise

Source code in src/llmcompressor/modifiers/modifier.py
def on_finalize(self, state: State, **kwargs) -> bool:
    """
    on_finalize is called on modifier finalization and may be
    overridden by the inheriting modifier. The default implementation
    does nothing and reports success.

    :param state: The current state of the model
    :param kwargs: Additional arguments for finalizing the modifier
    :return: True if the modifier was finalized successfully,
        False otherwise
    """
    return True

on_initialize(state, **kwargs) abstractmethod

on_initialize is called on modifier initialization and must be implemented by the inheriting modifier.

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional arguments for initializing the modifier

{}

Returns:

Type Description
bool

True if the modifier was initialized successfully, False otherwise

Source code in src/llmcompressor/modifiers/modifier.py
@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Abstract initialization hook; every inheriting modifier has to supply
    its own implementation.

    :param state: The current state of the model
    :param kwargs: Additional arguments for initializing the modifier
    :return: True when initialization succeeded, False when it did not
    """
    raise NotImplementedError

on_start(state, event, **kwargs)

on_start is called when the modifier starts and may be overridden by the inheriting modifier; the default implementation does nothing.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the start

required
kwargs

Additional arguments for starting the modifier

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    on_start is called when the modifier starts and may be overridden
    by the inheriting modifier. The default implementation does nothing.

    :param state: The current state of the model
    :param event: The event that triggered the start
    :param kwargs: Additional arguments for starting the modifier
    """
    pass

on_update(state, event, **kwargs)

on_update is called when the model in question must be updated based on the passed-in event. It may be overridden by the inheriting modifier; the default implementation does nothing.

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event that triggered the update

required
kwargs

Additional arguments for updating the model

{}
Source code in src/llmcompressor/modifiers/modifier.py
def on_update(self, state: State, event: Event, **kwargs):
    """
    on_update is called when the model in question must be
    updated based on the passed-in event. It may be overridden by the
    inheriting modifier; the default implementation does nothing.

    :param state: The current state of the model
    :param event: The event that triggered the update
    :param kwargs: Additional arguments for updating the model
    """
    pass

should_end(event)

Parameters:

Name Type Description Default
event Event

The event to check if the modifier should end

required

Returns:

Type Description

True if the modifier should end based on the given event

Source code in src/llmcompressor/modifiers/modifier.py
def should_end(self, event: Event) -> bool:
    """
    Decide whether the modifier should end at the given event.

    :param event: The event to check if the modifier should end
    :return: True if an end step is set and the event's current_index
        has reached or passed it, False otherwise
    """
    current = event.current_index

    return self.end is not None and current >= self.end

should_start(event)

Parameters:

Name Type Description Default
event Event

The event to check if the modifier should start

required

Returns:

Type Description
bool

True if the modifier should start based on the given event

Source code in src/llmcompressor/modifiers/modifier.py
def should_start(self, event: Event) -> bool:
    """
    Decide whether the modifier should start at the given event.

    A modifier without a start step never starts. Otherwise it starts when
    the event's current_index lies in the half-open window [start, end);
    a missing end leaves the window unbounded above.

    :param event: The event to check if the modifier should start
    :return: True if the modifier should start based on the given event
    """
    begin = self.start
    if begin is None:
        return False

    index = event.current_index
    if not begin <= index:
        return False

    stop = self.end
    return stop is None or index < stop

update_event(state, event, **kwargs)

Update modifier based on the given event. In turn calls on_start, on_update, and on_end based on the event and modifier settings. Raises a RuntimeError if the modifier is not initialized or has already been finalized.

Parameters:

Name Type Description Default
state State

The current state of sparsification

required
event Event

The event to update the modifier with

required
kwargs

Additional arguments for updating the modifier

{}

Raises:

Type Description
RuntimeError

if the modifier has not been initialized or has already been finalized

Source code in src/llmcompressor/modifiers/modifier.py
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update modifier based on the given event. The event is always
    forwarded to on_event; on_start, on_end and on_update are then
    called depending on the event type and the modifier's start/end
    settings.

    :raises RuntimeError: if the modifier has not been initialized or
        has already been finalized
    :param state: The current state of sparsification
    :param event: The event to update the modifier with
    :param kwargs: Additional arguments for updating the modifier
    """
    if not self.initialized_:
        raise RuntimeError("Cannot update an uninitialized modifier")

    if self.finalized_:
        raise RuntimeError("Cannot update a finalized modifier")

    self.on_event(state, event, **kwargs)

    # handle starting the modifier if needed (only on BATCH_START events)
    if (
        event.type_ == EventType.BATCH_START
        and not self.started_
        and self.should_start(event)
    ):
        self.on_start(state, event, **kwargs)
        self.started_ = True
        self.on_update(state, event, **kwargs)

        return

    # handle ending the modifier if needed (only on BATCH_END events)
    if (
        event.type_ == EventType.BATCH_END
        and not self.ended_
        and self.should_end(event)
    ):
        self.on_end(state, event, **kwargs)
        self.ended_ = True
        self.on_update(state, event, **kwargs)

        return

    # otherwise update only while the modifier is active
    if self.started_ and not self.ended_:
        self.on_update(state, event, **kwargs)

ModifierFactory

A factory for loading and registering modifiers

Source code in src/llmcompressor/modifiers/factory.py
class ModifierFactory:
    """
    A factory for loading and registering modifiers.

    `create` consults three registries in order: user-registered modifiers
    (added via `register`), experimental modifiers, then main modifiers.
    The experimental and main registries are populated by `refresh`.
    """

    _MAIN_PACKAGE_PATH = "llmcompressor.modifiers"
    _EXPERIMENTAL_PACKAGE_PATH = "llmcompressor.modifiers.experimental"

    _loaded: bool = False
    _main_registry: Dict[str, Type[Modifier]] = {}
    _experimental_registry: Dict[str, Type[Modifier]] = {}
    _registered_registry: Dict[str, Type[Modifier]] = {}
    # errors recorded by load_from_package, keyed by attribute name
    _errors: Dict[str, Exception] = {}

    @staticmethod
    def refresh():
        """
        Refresh the factory by reloading the modifiers from the main and
        experimental packages, then mark the factory as loaded.
        Note: this replaces any previously loaded main and experimental
        modifiers; modifiers added through `register` are kept.
        """
        ModifierFactory._main_registry = ModifierFactory.load_from_package(
            ModifierFactory._MAIN_PACKAGE_PATH
        )
        ModifierFactory._experimental_registry = ModifierFactory.load_from_package(
            ModifierFactory._EXPERIMENTAL_PACKAGE_PATH
        )
        ModifierFactory._loaded = True

    @staticmethod
    def load_from_package(package_path: str) -> Dict[str, Type[Modifier]]:
        """
        Import every module under the given package and collect attributes
        whose names end with "Modifier" and that are Modifier subclasses.

        Matching attributes that are not Modifier classes are recorded in
        `_errors`; exceptions raised while importing a module are printed
        and that module is skipped.

        :param package_path: The path to the package to load modifiers from
        :return: The loaded modifiers, as a mapping of name to class
        """
        loaded = {}
        main_package = importlib.import_module(package_path)

        for _importer, modname, _is_pkg in pkgutil.walk_packages(
            main_package.__path__, package_path + "."
        ):
            try:
                module = importlib.import_module(modname)

                for attribute_name in dir(module):
                    if not attribute_name.endswith("Modifier"):
                        continue

                    try:
                        if attribute_name in loaded:
                            continue

                        attr = getattr(module, attribute_name)

                        if not isinstance(attr, type):
                            raise ValueError(
                                f"Attribute {attribute_name} is not a type"
                            )

                        if not issubclass(attr, Modifier):
                            raise ValueError(
                                f"Attribute {attribute_name} is not a Modifier"
                            )

                        loaded[attribute_name] = attr
                        # a valid definition supersedes an error recorded for
                        # the same name by an earlier module
                        ModifierFactory._errors.pop(attribute_name, None)
                    except Exception as err:
                        # TODO: log import error
                        ModifierFactory._errors[attribute_name] = err
            except Exception as module_err:
                # TODO: log import error
                print(module_err)

        return loaded

    @staticmethod
    def create(
        type_: str,
        allow_registered: bool,
        allow_experimental: bool,
        **kwargs,
    ) -> Modifier:
        """
        Instantiate a modifier of the given type from registered modifiers.

        :raises ValueError: If no modifier of the given type is found
        :raises Exception: the error recorded while loading `type_`, if the
            type could not be found in any allowed registry
        :param type_: The type of modifier to create
        :param allow_registered: Whether or not to allow registered modifiers
        :param allow_experimental: Whether or not to allow experimental modifiers
        :param kwargs: Additional keyword arguments to pass to the modifier
            during instantiation
        :return: The instantiated modifier
        """
        if type_ in ModifierFactory._registered_registry:
            if allow_registered:
                return ModifierFactory._registered_registry[type_](**kwargs)
            # TODO: log warning that modifier was skipped

        if type_ in ModifierFactory._experimental_registry:
            if allow_experimental:
                return ModifierFactory._experimental_registry[type_](**kwargs)
            # TODO: log warning that modifier was skipped

        if type_ in ModifierFactory._main_registry:
            return ModifierFactory._main_registry[type_](**kwargs)

        # surface a recorded load error only when no registry can provide the
        # type, so a stale error cannot shadow a valid modifier class
        if type_ in ModifierFactory._errors:
            raise ModifierFactory._errors[type_]

        raise ValueError(f"No modifier of type '{type_}' found.")

    @staticmethod
    def register(type_: str, modifier_class: Type[Modifier]):
        """
        Register a modifier class to be used by the factory.

        :raises ValueError: If the provided class does not subclass the Modifier
            base class or is not a type
        :param type_: The type of modifier to register
        :param modifier_class: The class of the modifier to register, must subclass
            the Modifier base class
        """
        # check for a type first: issubclass raises TypeError on non-types,
        # which would bypass the documented ValueError
        if not isinstance(modifier_class, type):
            raise ValueError("The provided class is not a type.")
        if not issubclass(modifier_class, Modifier):
            raise ValueError(
                "The provided class does not subclass the Modifier base class."
            )

        ModifierFactory._registered_registry[type_] = modifier_class

create(type_, allow_registered, allow_experimental, **kwargs) staticmethod

Instantiate a modifier of the given type from registered modifiers.

Parameters:

Name Type Description Default
type_ str

The type of modifier to create

required
allow_registered bool

Whether or not to allow registered modifiers

required
allow_experimental bool

Whether or not to allow experimental modifiers

required
kwargs

Additional keyword arguments to pass to the modifier during instantiation

{}

Returns:

Type Description
Modifier

The instantiated modifier

Raises:

Type Description
ValueError

If no modifier of the given type is found

Source code in src/llmcompressor/modifiers/factory.py
@staticmethod
def create(
    type_: str,
    allow_registered: bool,
    allow_experimental: bool,
    **kwargs,
) -> Modifier:
    """
    Instantiate a modifier of the given type from registered modifiers.

    Registries are consulted in order: user-registered (if allowed),
    experimental (if allowed), then main.

    :raises ValueError: If no modifier of the given type is found
    :raises Exception: the error recorded while loading `type_`, if the
        type could not be found in any allowed registry
    :param type_: The type of modifier to create
    :param allow_registered: Whether or not to allow registered modifiers
    :param allow_experimental: Whether or not to allow experimental modifiers
    :param kwargs: Additional keyword arguments to pass to the modifier
        during instantiation
    :return: The instantiated modifier
    """
    if type_ in ModifierFactory._registered_registry:
        if allow_registered:
            return ModifierFactory._registered_registry[type_](**kwargs)
        # TODO: log warning that modifier was skipped

    if type_ in ModifierFactory._experimental_registry:
        if allow_experimental:
            return ModifierFactory._experimental_registry[type_](**kwargs)
        # TODO: log warning that modifier was skipped

    if type_ in ModifierFactory._main_registry:
        return ModifierFactory._main_registry[type_](**kwargs)

    # surface a recorded load error only when no registry can provide the
    # type, so a stale error cannot shadow a valid modifier class
    if type_ in ModifierFactory._errors:
        raise ModifierFactory._errors[type_]

    raise ValueError(f"No modifier of type '{type_}' found.")

load_from_package(package_path) staticmethod

Parameters:

Name Type Description Default
package_path str

The path to the package to load modifiers from

required

Returns:

Type Description
Dict[str, Type[Modifier]]

The loaded modifiers, as a mapping of name to class

Source code in src/llmcompressor/modifiers/factory.py
@staticmethod
def load_from_package(package_path: str) -> Dict[str, Type[Modifier]]:
    """
    Import every module under the given package and collect attributes
    whose names end with "Modifier" and that are Modifier subclasses.

    Matching attributes that are not Modifier classes are recorded in
    `ModifierFactory._errors`; exceptions raised while importing a module
    are printed and that module is skipped.

    :param package_path: The path to the package to load modifiers from
    :return: The loaded modifiers, as a mapping of name to class
    """
    loaded = {}
    main_package = importlib.import_module(package_path)

    for _importer, modname, _is_pkg in pkgutil.walk_packages(
        main_package.__path__, package_path + "."
    ):
        try:
            module = importlib.import_module(modname)

            for attribute_name in dir(module):
                if not attribute_name.endswith("Modifier"):
                    continue

                try:
                    if attribute_name in loaded:
                        continue

                    attr = getattr(module, attribute_name)

                    if not isinstance(attr, type):
                        raise ValueError(
                            f"Attribute {attribute_name} is not a type"
                        )

                    if not issubclass(attr, Modifier):
                        raise ValueError(
                            f"Attribute {attribute_name} is not a Modifier"
                        )

                    loaded[attribute_name] = attr
                    # a valid definition supersedes an error recorded for
                    # the same name by an earlier module
                    ModifierFactory._errors.pop(attribute_name, None)
                except Exception as err:
                    # TODO: log import error
                    ModifierFactory._errors[attribute_name] = err
        except Exception as module_err:
            # TODO: log import error
            print(module_err)

    return loaded

refresh() staticmethod

Refresh the factory by reloading the modifiers from the main and experimental packages. Note: this replaces any previously loaded main and experimental modifiers; modifiers added through register are kept.

Source code in src/llmcompressor/modifiers/factory.py
@staticmethod
def refresh():
    """
    Refresh the factory by reloading the modifiers from the main and
    experimental packages, then mark the factory as loaded.
    Note: this replaces any previously loaded main and experimental
    modifiers; modifiers added through `register` are kept.
    """
    ModifierFactory._main_registry = ModifierFactory.load_from_package(
        ModifierFactory._MAIN_PACKAGE_PATH
    )
    ModifierFactory._experimental_registry = ModifierFactory.load_from_package(
        ModifierFactory._EXPERIMENTAL_PACKAGE_PATH
    )
    ModifierFactory._loaded = True

register(type_, modifier_class) staticmethod

Register a modifier class to be used by the factory.

Parameters:

Name Type Description Default
type_ str

The type of modifier to register

required
modifier_class Type[Modifier]

The class of the modifier to register, must subclass the Modifier base class

required

Raises:

Type Description
ValueError

If the provided class does not subclass the Modifier base class or is not a type

Source code in src/llmcompressor/modifiers/factory.py
@staticmethod
def register(type_: str, modifier_class: Type[Modifier]):
    """
    Register a modifier class to be used by the factory.

    :raises ValueError: If the provided class does not subclass the Modifier
        base class or is not a type
    :param type_: The type of modifier to register
    :param modifier_class: The class of the modifier to register, must subclass
        the Modifier base class
    """
    # check for a type first: issubclass raises TypeError on non-types,
    # which would bypass the documented ValueError
    if not isinstance(modifier_class, type):
        raise ValueError("The provided class is not a type.")
    if not issubclass(modifier_class, Modifier):
        raise ValueError(
            "The provided class does not subclass the Modifier base class."
        )

    ModifierFactory._registered_registry[type_] = modifier_class

ModifierInterface

Bases: ABC

Defines the contract that all modifiers must implement

Source code in src/llmcompressor/modifiers/interface.py
class ModifierInterface(ABC):
    """
    Abstract contract shared by every modifier.

    Implementations expose their lifecycle status through the
    `initialized` and `finalized` properties and provide the
    `initialize`, `update_event` and `finalize` lifecycle methods.
    """

    @property
    @abstractmethod
    def initialized(self) -> bool:
        """
        :return: True once the modifier has been initialized
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def finalized(self) -> bool:
        """
        :return: True once the modifier has been finalized
        """
        raise NotImplementedError

    @abstractmethod
    def initialize(self, state: State, **kwargs):
        """
        Set the modifier up for use.

        :param state: The current state of the model
        :param kwargs: Extra keyword arguments used during initialization
        """
        raise NotImplementedError

    @abstractmethod
    def finalize(self, state: State, **kwargs):
        """
        Tear the modifier down once it is no longer needed.

        :param state: The current state of the model
        :param kwargs: Extra keyword arguments used during finalization
        """
        raise NotImplementedError

    @abstractmethod
    def update_event(self, state: State, event: Event, **kwargs):
        """
        React to an event delivered to the modifier.

        :param state: The current state of the model
        :param event: The event to update the modifier with
        :param kwargs: Extra keyword arguments used during the update
        """
        raise NotImplementedError

finalized abstractmethod property

Returns:

Type Description
bool

True if the modifier has been finalized

initialized abstractmethod property

Returns:

Type Description
bool

True if the modifier has been initialized

finalize(state, **kwargs) abstractmethod

Finalize the modifier

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional keyword arguments for modifier finalization

{}
Source code in src/llmcompressor/modifiers/interface.py
@abstractmethod
def finalize(self, state: State, **kwargs):
    """
    Finalize the modifier.

    Abstract hook: concrete modifiers must override it, because this base
    body only raises ``NotImplementedError``.

    :param state: The current state of the model
    :param kwargs: Additional keyword arguments for
        modifier finalization
    """
    raise NotImplementedError

initialize(state, **kwargs) abstractmethod

Initialize the modifier

Parameters:

Name Type Description Default
state State

The current state of the model

required
kwargs

Additional keyword arguments for modifier initialization

{}
Source code in src/llmcompressor/modifiers/interface.py
@abstractmethod
def initialize(self, state: State, **kwargs):
    """
    Initialize the modifier.

    Abstract hook: concrete modifiers must override it, because this base
    body only raises ``NotImplementedError``.

    :param state: The current state of the model
    :param kwargs: Additional keyword arguments
        for modifier initialization
    """
    raise NotImplementedError

update_event(state, event, **kwargs) abstractmethod

Update the modifier based on the event

Parameters:

Name Type Description Default
state State

The current state of the model

required
event Event

The event to update the modifier with

required
kwargs

Additional keyword arguments for modifier update

{}
Source code in src/llmcompressor/modifiers/interface.py
@abstractmethod
def update_event(self, state: State, event: Event, **kwargs):
    """
    Update the modifier in response to an event.

    Abstract hook: concrete modifiers must override it, because this base
    body only raises ``NotImplementedError``.

    :param state: The current state of the model
    :param event: The event to update the modifier with
    :param kwargs: Additional keyword arguments for
        modifier update
    """
    raise NotImplementedError

StageModifiers

Bases: ModifierInterface, BaseModel

Represents a collection of modifiers that are applied together as a stage.

Parameters:

Name Type Description Default
modifiers

The modifiers to apply as a stage

required
index

The index of the stage, if applicable

required
group

The group name of the stage, if applicable

required
applied

Flag indicating whether this stage has already been applied to the model through finalization

required
Source code in src/llmcompressor/modifiers/stage.py
class StageModifiers(ModifierInterface, BaseModel):
    """
    Represents a collection of modifiers that are applied together as a stage.

    Once ``finalize`` has run, the stage is marked as applied. After that,
    ``initialize``, ``finalize`` and ``update_event`` all become no-ops.

    :param modifiers: The modifiers to apply as a stage
    :param index: The index of the stage, if applicable
    :param group: The group name of the stage, if applicable
    :param applied: Flag indicating whether this stage has already been
        applied to the model through finalization
    """

    modifiers: List["Modifier"] = Field(default_factory=list)
    index: Optional[int] = None
    group: Optional[str] = None
    applied: bool = False

    @property
    def initialized(self) -> bool:
        """
        :return: True if all of the stage modifiers have been initialized,
            False otherwise (a stage with no modifiers reports True)
        """
        return all(mod.initialized for mod in self.modifiers)

    @property
    def finalized(self) -> bool:
        """
        :return: True if all of the stage modifiers have been finalized,
            False otherwise (a stage with no modifiers reports True)
        """
        return all(mod.finalized for mod in self.modifiers)

    @property
    def unique_id(self) -> str:
        """
        :return: ID for the stage, built from the group name and index as
            ``"<group>_<index>"``; a missing group or index is rendered
            as ``"None"``
        """
        # `group` is Optional[str]. Concatenating it directly raised a
        # TypeError when it was None, even though a None `index` was
        # already tolerated via str(). Convert both fields the same way.
        return str(self.group) + "_" + str(self.index)

    def initialize(self, state: "State", **kwargs):
        """
        Initialize every stage modifier that is not yet initialized.

        Does nothing if the stage has already been applied.

        :param state: The state of current session
        :param kwargs: Additional kwargs to pass to the modifier(s)
            initialize method; an optional ``accelerator`` entry is also
            used to call ``wait_for_everyone()`` after each modifier
        """

        if self.applied:
            return

        accelerator = kwargs.get("accelerator", None)
        for modifier in self.modifiers:
            if not modifier.initialized:
                modifier.initialize(state, **kwargs)
            # The wait happens for every modifier, including ones skipped
            # above. Presumably this keeps every process calling it the
            # same number of times (TODO confirm intent).
            if accelerator:
                accelerator.wait_for_everyone()
        state.loggers.system.info(tag="stage", string="Modifiers initialized")

    def finalize(self, state: "State", **kwargs):
        """
        Finalize all the stage modifiers and mark the stage as applied.

        Does nothing if the stage has already been applied.

        :param state: The state of current session
        :param kwargs: Additional kwargs to pass to the modifier(s)
            finalize method; an optional ``accelerator`` entry is also
            used to call ``wait_for_everyone()`` after each modifier
        """

        if self.applied:
            return

        accelerator = kwargs.get("accelerator", None)
        for modifier in self.modifiers:
            modifier.finalize(state, **kwargs)
            if accelerator:
                accelerator.wait_for_everyone()

        # Once applied, every lifecycle method on this stage short-circuits.
        self.applied = True
        state.loggers.system.info(tag="stage", string="Modifiers finalized")

    def update_event(self, state: "State", event: "Event", **kwargs):
        """
        Propagate the event to all the stage modifiers, in list order.

        Does nothing if the stage has already been applied.

        :param state: The state of current session
        :param event: The event to propagate
        :param kwargs: Additional kwargs to pass to the modifier(s)
            update_event method
        """

        if self.applied:
            return

        for modifier in self.modifiers:
            modifier.update_event(state, event, **kwargs)

finalized property

Returns:

Type Description
bool

True if all of the stage modifiers have been finalized, False otherwise

initialized property

Returns:

Type Description
bool

True if all of the stage modifiers have been initialized, False otherwise

unique_id property

Returns:

Type Description
str

ID for stage containing the name and index

finalize(state, **kwargs)

Finalize all the stage modifiers and mark the stage as applied

Parameters:

Name Type Description Default
state State

The state of current session

required
kwargs

Additional kwargs to pass to the modifier(s) finalize method

{}
Source code in src/llmcompressor/modifiers/stage.py
def finalize(self, state: "State", **kwargs):
    """
    Finalize each modifier of the stage, then flag the stage as applied.

    If the stage is already applied, nothing happens. When ``kwargs``
    carries a truthy ``accelerator``, its ``wait_for_everyone()`` is
    called after every modifier's ``finalize``.

    :param state: The state of current session
    :param kwargs: Additional kwargs forwarded to each modifier's
        finalize method
    """
    if not self.applied:
        sync = kwargs.get("accelerator", None)
        for mod in self.modifiers:
            mod.finalize(state, **kwargs)
            if sync:
                sync.wait_for_everyone()

        self.applied = True
        state.loggers.system.info(tag="stage", string="Modifiers finalized")

initialize(state, **kwargs)

Initialize all the stage modifiers

Parameters:

Name Type Description Default
state State

The state of current session

required
kwargs

Additional kwargs to pass to the modifier(s) initialize method

{}
Source code in src/llmcompressor/modifiers/stage.py
def initialize(self, state: "State", **kwargs):
    """
    Initialize each stage modifier that reports itself as not initialized.

    If the stage is already applied, nothing happens. When ``kwargs``
    carries a truthy ``accelerator``, its ``wait_for_everyone()`` is
    called once per modifier, including modifiers that were skipped
    because they were already initialized.

    :param state: The state of current session
    :param kwargs: Additional kwargs forwarded to each modifier's
        initialize method
    """
    if self.applied:
        return

    sync = kwargs.get("accelerator", None)
    for mod in self.modifiers:
        needs_init = not mod.initialized
        if needs_init:
            mod.initialize(state, **kwargs)
        if sync:
            sync.wait_for_everyone()

    state.loggers.system.info(tag="stage", string="Modifiers initialized")

update_event(state, event, **kwargs)

Propagate the event to all the stage modifiers

Parameters:

Name Type Description Default
state State

The state of current session

required
event Event

The event to propagate

required
kwargs

Additional kwargs to pass to the modifier(s) update_event method

{}
Source code in src/llmcompressor/modifiers/stage.py
def update_event(self, state: "State", event: "Event", **kwargs):
    """
    Forward an event to every modifier of the stage, in list order.

    Skipped entirely once the stage has been applied.

    :param state: The state of current session
    :param event: The event to propagate
    :param kwargs: Additional kwargs forwarded to each modifier's
        update_event method
    """
    if not self.applied:
        for mod in self.modifiers:
            mod.update_event(state, event, **kwargs)