llmcompressor.pipelines

BasicPipeline

Bases: CalibrationPipeline

Source code in src/llmcompressor/pipelines/basic/pipeline.py
@CalibrationPipeline.register("basic")
class BasicPipeline(CalibrationPipeline):
    @staticmethod
    def __call__(
        model: torch.nn.Module,
        dataloader: DataLoader,
        dataset_args: Union["DatasetArguments", None],
    ):
        """
        Run a basic data pipeline.

        Batches are fetched from the data loader and used to perform forward passes
        through the model. This pipeline is typically used for basic model calibration
        and, unlike the sequential pipelines, does not propagate compression error
        when calibrating model compression.

        :param model: model being calibrated
        :param dataloader: loads data for calibration
        :param dataset_args: dataset arguments relevant to pipelines
        """
        model_device = get_execution_device(model)

        LifecycleCallbacks.calibration_epoch_start()

        with calibration_forward_context(model):
            for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
                batch = apply_pad_mask_to_batch(batch)
                batch = tensors_to_device(batch, model_device)
                model(**batch)

        LifecycleCallbacks.calibration_epoch_end()

__call__(model, dataloader, dataset_args) staticmethod

Run a basic data pipeline.

Batches are fetched from the data loader and used to perform forward passes through the model. This pipeline is typically used for basic model calibration and, unlike the sequential pipelines, does not propagate compression error when calibrating model compression.

Parameters:

    model (Module): model being calibrated (required)
    dataloader (DataLoader): loads data for calibration (required)
    dataset_args (Union[DatasetArguments, None]): dataset arguments relevant to pipelines (required)
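
For illustration, here is a minimal usage sketch. It assumes a prepared model, dataloader, and dataset arguments; the pipeline is resolved by the name it registers under ("basic").

# Hedged usage sketch (model, dataloader, and dataset_args prepared elsewhere)
pipeline = CalibrationPipeline.load_from_registry("basic")
pipeline(model, dataloader, dataset_args)  # runs calibration forward passes
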
Source code in src/llmcompressor/pipelines/basic/pipeline.py
@staticmethod
def __call__(
    model: torch.nn.Module,
    dataloader: DataLoader,
    dataset_args: Union["DatasetArguments", None],
):
    """
    Run a basic data pipeline.

    Batches are fetched from the data loader and used to perform forward passes
    through the model. This pipeline is typically used for basic model calibration
    and, unlike the sequential pipelines, does not propagate compression error
    when calibrating model compression.

    :param model: model being calibrated
    :param dataloader: loads data for calibration
    :param dataset_args: dataset arguments relevant to pipelines
    """
    model_device = get_execution_device(model)

    LifecycleCallbacks.calibration_epoch_start()

    with calibration_forward_context(model):
        for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
            batch = apply_pad_mask_to_batch(batch)
            batch = tensors_to_device(batch, model_device)
            model(**batch)

    LifecycleCallbacks.calibration_epoch_end()

CalibrationPipeline

Bases: ABC, RegistryMixin

Source code in src/llmcompressor/pipelines/registry.py
class CalibrationPipeline(ABC, RegistryMixin):
    @staticmethod
    @abstractmethod
    def __call__(
        model: torch.nn.Module,
        dataloader: DataLoader,
        dataset_args: "DatasetArguments",
    ):
        raise NotImplementedError()

    @classmethod
    def from_modifiers(
        cls, modifiers: List[Modifier], user: Optional[str] = None
    ) -> "CalibrationPipeline":
        """
        Infer which calibration pipeline to use based on the available modifiers and
        any user specifications

        :param modifiers: modifiers to apply to model
        :param user: pipeline name passed by user
        :return: CalibrationPipeline instance to be called with data (if not datafree)
        """
        user = standardize_lookup_name(user) if user else None
        inferred = standardize_lookup_name(cls._validate_infer_pipeline(modifiers))
        independent = standardize_lookup_name("independent")

        if user == independent:
            inferred = independent

        if user is not None and user != inferred:
            logger.warning(
                f"Calibration pipeline is set to `{user}`, but it is recommended to "
                f"use `{inferred}`"
            )

        pipeline = user or inferred
        return cls.load_from_registry(pipeline)

    @staticmethod
    def _validate_infer_pipeline(modifiers: List[Modifier]) -> str:
        if any(isinstance(modifier, AWQModifier) for modifier in modifiers):
            if len(modifiers) > 1:
                logger.warning(
                    "AWQ does not currently support sharing a data pipeline with other "
                    "modifiers. Inferring `independent` calibration pipeline"
                )
                return "independent"
            return "datafree"

        if any(isinstance(modifier, SEQUENTIAL_MODIFIERS) for modifier in modifiers):
            return "sequential"

        active_qmods = _get_active_quant_modifiers(modifiers)
        if len(active_qmods) > 1:
            raise ValueError(
                f"Recipe contains more than one active quantization config "
                f"({active_qmods}). These configs may be conflicting, Please modify "
                "your recipe to use at most one quantization config"
            )

        if len(active_qmods) == 1:
            quant_modifier = active_qmods[0]
            config = quant_modifier.resolve_quantization_config()
            if config.requires_calibration_data():
                return "basic"
            else:
                return "datafree"

        if any(isinstance(modifier, SmoothQuantModifier) for modifier in modifiers):
            return "basic"

        return "datafree"

from_modifiers(modifiers, user=None) classmethod

Infer which calibration pipeline to use based on the available modifiers and any user specifications.

Parameters:

    modifiers (List[Modifier]): modifiers to apply to model (required)
    user (Optional[str]): pipeline name passed by user (default: None)

Returns:

    CalibrationPipeline: CalibrationPipeline instance to be called with data (if not datafree)
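
A hedged usage sketch of pipeline inference, following the logic above; `modifiers` is assumed to be a prepared list of Modifier instances:

# Infer the pipeline from the recipe's modifiers; a user-supplied name takes
# precedence, with a warning logged when it differs from the inferred one
pipeline = CalibrationPipeline.from_modifiers(modifiers)
pipeline(model, dataloader, dataset_args)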

Source code in src/llmcompressor/pipelines/registry.py
@classmethod
def from_modifiers(
    cls, modifiers: List[Modifier], user: Optional[str] = None
) -> "CalibrationPipeline":
    """
    Infer which calibration pipeline to use based on the available modifiers and
    any user specifications

    :param modifiers: modifiers to apply to model
    :param user: pipeline name passed by user
    :return: CalibrationPipeline instance to be called with data (if not datafree)
    """
    user = standardize_lookup_name(user) if user else None
    inferred = standardize_lookup_name(cls._validate_infer_pipeline(modifiers))
    independent = standardize_lookup_name("independent")

    if user == independent:
        inferred = independent

    if user is not None and user != inferred:
        logger.warning(
            f"Calibration pipeline is set to `{user}`, but it is recommended to "
            f"use `{inferred}`"
        )

    pipeline = user or inferred
    return cls.load_from_registry(pipeline)

DataFreePipeline

Bases: CalibrationPipeline

Source code in src/llmcompressor/pipelines/data_free/pipeline.py
@CalibrationPipeline.register("datafree")
class DataFreePipeline(CalibrationPipeline):
    @staticmethod
    def __call__(
        model: torch.nn.Module,
        dataloader: Optional[DataLoader],
        dataset_args: "DatasetArguments",
    ):
        """
        A pipeline for data-free calibration

        :param model: model being calibrated
        :param dataloader: loads data for calibration
        :param dataset_args: dataset arguments relevant to pipelines
        """
        LifecycleCallbacks.calibration_epoch_start()
        LifecycleCallbacks.calibration_epoch_end()

__call__(model, dataloader, dataset_args) staticmethod

A pipeline for data-free calibration.

Parameters:

    model (Module): model being calibrated (required)
    dataloader (Optional[DataLoader]): loads data for calibration (required)
    dataset_args (DatasetArguments): dataset arguments relevant to pipelines (required)
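
Because this pipeline never touches the data, the dataloader may be None. A minimal sketch, assuming model and dataset_args are prepared elsewhere:

# Hedged usage sketch: only fires the calibration epoch start/end callbacks
pipeline = CalibrationPipeline.load_from_registry("datafree")
pipeline(model, None, dataset_args)
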
Source code in src/llmcompressor/pipelines/data_free/pipeline.py
@staticmethod
def __call__(
    model: torch.nn.Module,
    dataloader: Optional[DataLoader],
    dataset_args: "DatasetArguments",
):
    """
    A pipeline for data-free calibration

    :param model: model being calibrated
    :param dataloader: loads data for calibration
    :param dataset_args: dataset arguments relevant to pipelines
    """
    LifecycleCallbacks.calibration_epoch_start()
    LifecycleCallbacks.calibration_epoch_end()

IndependentPipeline

Bases: CalibrationPipeline

Source code in src/llmcompressor/pipelines/independent/pipeline.py
@CalibrationPipeline.register("independent")
class IndependentPipeline(CalibrationPipeline):
    @staticmethod
    def __call__(
        model: torch.nn.Module,
        dataloader: DataLoader,
        dataset_args: "DatasetArguments",
    ):
        """
        Data pipeline where each modifier is assigned its own calibration epoch and data
        pipeline

        :param model: model being calibrated
        :param dataloader: loads data for calibration
        :param dataset_args: dataset arguments relevant to pipelines
        """
        _logger = logger.patch(lambda r: r.update(function="IndependentPipeline"))

        session = active_session()
        modifiers = session.get_modifiers()
        with patch_attr(session.lifecycle, "modifiers", None):
            for index, modifier in enumerate(modifiers):
                mod_type = str(type(modifier).__name__)
                session.lifecycle.modifiers = [
                    StageModifiers(modifiers=[modifier], group=mod_type, index=index)
                ]

                pipeline = CalibrationPipeline.from_modifiers([modifier])
                pipeline_name = pipeline.__class__.__name__
                _logger.info(f"Inferred `{pipeline_name}` for `{mod_type}`")

                pipeline(model, dataloader, dataset_args)

__call__(model, dataloader, dataset_args) staticmethod

Data pipeline where each modifier is assigned its own calibration epoch and data pipeline.

Parameters:

    model (Module): model being calibrated (required)
    dataloader (DataLoader): loads data for calibration (required)
    dataset_args (DatasetArguments): dataset arguments relevant to pipelines (required)
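
A hedged sketch of requesting this pipeline explicitly; per from_modifiers above, a user-supplied "independent" takes precedence over the inferred pipeline:

# Give each modifier its own calibration epoch and data pipeline
pipeline = CalibrationPipeline.from_modifiers(modifiers, user="independent")
pipeline(model, dataloader, dataset_args)
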
Source code in src/llmcompressor/pipelines/independent/pipeline.py
@staticmethod
def __call__(
    model: torch.nn.Module,
    dataloader: DataLoader,
    dataset_args: "DatasetArguments",
):
    """
    Data pipeline where each modifier is assigned its own calibration epoch and data
    pipeline

    :param model: model being calibrated
    :param dataloader: loads data for calibration
    :param dataset_args: dataset arguments relevant to pipelines
    """
    _logger = logger.patch(lambda r: r.update(function="IndependentPipeline"))

    session = active_session()
    modifiers = session.get_modifiers()
    with patch_attr(session.lifecycle, "modifiers", None):
        for index, modifier in enumerate(modifiers):
            mod_type = str(type(modifier).__name__)
            session.lifecycle.modifiers = [
                StageModifiers(modifiers=[modifier], group=mod_type, index=index)
            ]

            pipeline = CalibrationPipeline.from_modifiers([modifier])
            pipeline_name = pipeline.__class__.__name__
            _logger.info(f"Inferred `{pipeline_name}` for `{mod_type}`")

            pipeline(model, dataloader, dataset_args)

LayerSequentialPipeline

Bases: CalibrationPipeline

Source code in src/llmcompressor/pipelines/layer_sequential/pipeline.py
@CalibrationPipeline.register("layer_sequential")
class LayerSequentialPipeline(CalibrationPipeline):
    @staticmethod
    def __call__(
        model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments"
    ):
        """
        Run a layer-wise sequential data pipeline according to the following steps:

        1. Layers are identified according to `sequential_targets`
        2. A hook is attached to the first layer. This hook raises an exception which is
            then caught and used to capture the input arguments to the first layer
        3. The inputs to the first layer are used to calibrate the first layer, and the
            outputs of each layer are used as inputs to calibrate the next layer

        This pipeline requires that the model have distinct layers defined in its
        architecture and that the outputs of the previous layer are exactly the inputs
        to the next layer. This is violated by encoder-decoder architectures, among
        others.

        If your model architecture violates these assumptions, consider using the
        sequential pipeline (see llmcompressor.pipelines.sequential). Architectures
        which are known to fail these assumptions include GPT-J and most vision models.

        :param model: model being calibrated
        :param dataloader: loads data for calibration
        :param dataset_args: dataset arguments relevant to pipelines
        """
        session = active_session()

        # find layers
        modifiers = session.get_modifiers()
        sequential_targets, _ = get_targets_from_modifiers(modifiers, model)
        layers = match_modules(model, sequential_targets)

        LifecycleCallbacks.calibration_epoch_start()

        with calibration_forward_context(model), DisableQuantization(model):
            # prepare intermediates cache
            intermediates: IntermediatesCache = capture_first_layer_intermediates(
                model, layers[0], dataloader
            )

            num_layers = len(layers)
            for layer_index, layer in enumerate(layers):
                # prepare tqdm description texts
                calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating"
                prop_desc = f"({layer_index + 1}/{num_layers}): Propagating"

                # do a preliminary pass to trigger modifier hooks
                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                    inputs = intermediates.fetch(batch_idx)
                    layer(**inputs)

                # trigger compression
                LifecycleCallbacks.sequential_epoch_end()

                # this pass does not trigger modifier hooks
                # and is only used for capturing outputs from newly compressed modules
                with HooksMixin.disable_hooks():
                    for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
                        inputs = intermediates.fetch(batch_idx)
                        output = layer(**inputs)

                        if layer_index < num_layers - 1:
                            next_layer = layers[layer_index + 1]
                            output = to_next_layer_kwargs(output, next_layer)
                            output = maybe_inject_pos_embeddings(
                                output, next_layer, inputs
                            )

                            intermediates.delete(batch_idx)
                            intermediates.update(batch_idx, output)

            # redundant; finish any remaining compression
            LifecycleCallbacks.calibration_epoch_end()

__call__(model, dataloader, dataset_args) staticmethod

Run a layer-wise sequential data pipeline according to the following steps:

  1. Layers are identified according to sequential_targets
  2. A hook is attached to the first layer. This hook raises an exception which is then caught and used to capture the input arguments to the first layer (see the sketch below)
  3. The inputs to the first layer are used to calibrate the first layer, and the outputs of each layer are used as inputs to calibrate the next layer
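
The early-exit trick in step 2 can be illustrated with a small, self-contained sketch. This is not the library's implementation (the pipeline calls capture_first_layer_intermediates, as shown in the source below); it only demonstrates the pattern of raising from a pre-hook to capture the first layer's inputs:

import torch

class EarlyStop(Exception):
    """Carries the captured layer inputs out of the forward pass."""
    def __init__(self, kwargs):
        self.kwargs = kwargs

def capture_layer_inputs(model: torch.nn.Module, layer: torch.nn.Module, batch: dict) -> dict:
    def hook(module, args, kwargs):
        raise EarlyStop(kwargs)  # abort the forward pass, keeping the inputs

    handle = layer.register_forward_pre_hook(hook, with_kwargs=True)
    try:
        model(**batch)
    except EarlyStop as stop:
        return stop.kwargs
    finally:
        handle.remove()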

This pipeline requires that the model have distinct layers defined in its architecture and that the outputs of the previous layer are exactly the inputs to the next layer. This is violated by encoder-decoder architectures, among others.

If your model architecture violates these assumptions, consider using the sequential pipeline (see llmcompressor.pipelines.sequential). Architectures known to fail these assumptions include GPT-J and most vision models.

Parameters:

    model (Module): model being calibrated (required)
    dataloader (DataLoader): loads data for calibration (required)
    dataset_args (DatasetArguments): dataset arguments relevant to pipelines (required)
Source code in src/llmcompressor/pipelines/layer_sequential/pipeline.py
@staticmethod
def __call__(
    model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments"
):
    """
    Run a layer-wise sequential data pipeline according to the following steps:

    1. Layers are identified according to `sequential_targets`
    2. A hook is attached to the first layer. This hook raises an exception which is
        then caught and used to capture the input arguments to the first layer
    3. The inputs to the first layer are used to calibrate the first layer, and the
        outputs of each layer are used as inputs to calibrate the next layer

    This pipeline requires that the model have distinct layers defined in its
    architecture and that the outputs of the previous layer are exactly the inputs
    to the next layer. This is violated by encoder-decoder architectures, among
    others.

    If your model architecture violates these assumptions, consider using the
    sequential pipeline (see llmcompressor.pipelines.sequential). Architectures
    which are known to fail these assumptions include GPT-J and most vision models.

    :param model: model being calibrated
    :param dataloader: loads data for calibration
    :param dataset_args: dataset arguments relevant to pipelines
    """
    session = active_session()

    # find layers
    modifiers = session.get_modifiers()
    sequential_targets, _ = get_targets_from_modifiers(modifiers, model)
    layers = match_modules(model, sequential_targets)

    LifecycleCallbacks.calibration_epoch_start()

    with calibration_forward_context(model), DisableQuantization(model):
        # prepare intermediates cache
        intermediates: IntermediatesCache = capture_first_layer_intermediates(
            model, layers[0], dataloader
        )

        num_layers = len(layers)
        for layer_index, layer in enumerate(layers):
            # prepare tqdm description texts
            calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating"
            prop_desc = f"({layer_index + 1}/{num_layers}): Propagating"

            # do a preliminary pass to trigger modifier hooks
            for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                inputs = intermediates.fetch(batch_idx)
                layer(**inputs)

            # trigger compression
            LifecycleCallbacks.sequential_epoch_end()

            # this pass does not trigger modifier hooks
            # and is only used for capturing outputs from newly compressed modules
            with HooksMixin.disable_hooks():
                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
                    inputs = intermediates.fetch(batch_idx)
                    output = layer(**inputs)

                    if layer_index < num_layers - 1:
                        next_layer = layers[layer_index + 1]
                        output = to_next_layer_kwargs(output, next_layer)
                        output = maybe_inject_pos_embeddings(
                            output, next_layer, inputs
                        )

                        intermediates.delete(batch_idx)
                        intermediates.update(batch_idx, output)

        # redundant; finish any remaining compression
        LifecycleCallbacks.calibration_epoch_end()

SequentialPipeline

Bases: CalibrationPipeline

Source code in src/llmcompressor/pipelines/sequential/pipeline.py
@CalibrationPipeline.register("sequential")
class SequentialPipeline(CalibrationPipeline):
    @staticmethod
    def __call__(
        model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments"
    ):
        """
        Run a sequential data pipeline according to the following steps:

        1. The model is partitioned into subgraphs according to `sequential_targets`
        2. Data passes through each subgraph sequentially. Data is passed through each
            subgraph twice: once to trigger calibration hooks, then a second time to
            capture activations after quantization has occurred
        3. The intermediate activations between each subgraph are cached and offloaded
            to the CPU between batches to save memory

        This pipeline requires that the model be traceable with respect to data from the
        data loader. This may be an issue for vision models with vision datasets, due
        to specialized input processing in the model.

        In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A
        model can be made traceable by wrapping the untraceable functions (see
        llmcompressor.transformers.tracing)

        :param model: model being calibrated
        :param dataloader: loads data for calibration
        :param dataset_args: dataset arguments relevant to pipelines
        """
        session = active_session()

        # infer sequential targets
        modifiers = session.get_modifiers()
        sequential_targets, ignore = get_targets_from_modifiers(modifiers, model)

        # trace subgraphs
        sample_input = next(iter(dataloader))
        subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)

        LifecycleCallbacks.calibration_epoch_start()

        with calibration_forward_context(model), DisableQuantization(model):
            # prepare intermediates cache
            model_device = get_execution_device(model)
            intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)

            num_subgraphs = len(subgraphs)
            for subgraph_index, subgraph in enumerate(subgraphs):
                # prepare tqdm description texts
                calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
                prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"

                # do a preliminary pass to trigger modifier hooks
                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                    inputs = intermediates.fetch(batch_idx, subgraph.input_names)
                    subgraph.forward(model, **inputs)

                # trigger compression
                LifecycleCallbacks.sequential_epoch_end()

                # this pass does not trigger modifier hooks
                # and is only used for capturing outputs from newly compressed modules
                with HooksMixin.disable_hooks():
                    for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
                        inputs = intermediates.fetch(batch_idx, subgraph.input_names)
                        output = subgraph.forward(model, **inputs)

                        if subgraph_index < num_subgraphs - 1:
                            intermediates.update(batch_idx, output)
                            intermediates.delete(batch_idx, subgraph.consumed_names)

            # redundant; finish any remaining compression
            LifecycleCallbacks.calibration_epoch_end()

__call__(model, dataloader, dataset_args) staticmethod

Run a sequential data pipeline according to the following steps:

  1. The model is partitioned into subgraphs according to sequential_targets
  2. Data passes through each subgraph sequentially. Data is passed through each subgraph twice: once to trigger calibration hooks, then a second time to capture activations after quantization has occurred
  3. The intermediate activations between each subgraph are cached and offloaded to the CPU between batches to save memory (see the sketch below)
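
The caching in step 3 can be sketched with a toy cache. The real IntermediatesCache (used in the source below) is more involved; this only illustrates the offloading idea:

import torch

class ToyIntermediatesCache:
    # stores each batch's activations on CPU, moving them back on fetch
    def __init__(self, onload_device: torch.device):
        self.onload_device = onload_device
        self.batches: dict[int, dict[str, torch.Tensor]] = {}

    def update(self, batch_idx: int, values: dict):
        # offload so accelerator memory only holds the active batch
        self.batches[batch_idx] = {k: v.to("cpu") for k, v in values.items()}

    def fetch(self, batch_idx: int) -> dict:
        # onload just in time for the next subgraph's forward pass
        return {k: v.to(self.onload_device) for k, v in self.batches[batch_idx].items()}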

This pipeline requires that the model be traceable with respect to data from the data loader. This may be an issue for vision models with vision datasets, due to specialized input processing in the model.

In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A model can be made traceable by wrapping the untraceable functions (see llmcompressor.transformers.tracing).
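
As a hedged sketch of handling the failure mode described above, a tracing error can be caught and another pipeline substituted (whether a fallback is appropriate depends on your modifiers and architecture):

import torch.fx

try:
    pipeline = CalibrationPipeline.load_from_registry("sequential")
    pipeline(model, dataloader, dataset_args)
except torch.fx.proxy.TraceError:
    # the layer-wise pipeline avoids tracing but has stricter architectural
    # assumptions (see LayerSequentialPipeline above)
    pipeline = CalibrationPipeline.load_from_registry("layer_sequential")
    pipeline(model, dataloader, dataset_args)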

Parameters:

    model (Module): model being calibrated (required)
    dataloader (DataLoader): loads data for calibration (required)
    dataset_args (DatasetArguments): dataset arguments relevant to pipelines (required)
Source code in src/llmcompressor/pipelines/sequential/pipeline.py
@staticmethod
def __call__(
    model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments"
):
    """
    Run a sequential data pipeline according to the following steps:

    1. The model is partitioned into subgraphs according to `sequential_targets`
    2. Data passes through each subgraph sequentially. Data is passed through each
        subgraph twice: once to trigger calibration hooks, then a second time to
        capture activations after quantization has occurred
    3. The intermediate activations between each subgraph are cached and offloaded
        to the CPU between batches to save memory

    This pipeline requires that the model be traceable with respect to data from the
    data loader. This may be an issue for vision models with vision datasets, due
    to specialized input processing in the model.

    In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A
    model can be made traceable by wrapping the untraceable functions (see
    llmcompressor.transformers.tracing)

    :param model: model being calibrated
    :param dataloader: loads data for calibration
    :param dataset_args: dataset arguments relevant to pipelines
    """
    session = active_session()

    # infer sequential targets
    modifiers = session.get_modifiers()
    sequential_targets, ignore = get_targets_from_modifiers(modifiers, model)

    # trace subgraphs
    sample_input = next(iter(dataloader))
    subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)

    LifecycleCallbacks.calibration_epoch_start()

    with calibration_forward_context(model), DisableQuantization(model):
        # prepare intermediates cache
        model_device = get_execution_device(model)
        intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)

        num_subgraphs = len(subgraphs)
        for subgraph_index, subgraph in enumerate(subgraphs):
            # prepare tqdm description texts
            calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
            prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"

            # do a preliminary pass to trigger modifier hooks
            for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                inputs = intermediates.fetch(batch_idx, subgraph.input_names)
                subgraph.forward(model, **inputs)

            # trigger compression
            LifecycleCallbacks.sequential_epoch_end()

            # this pass does not trigger modifier hooks
            # and is only used for capturing outputs from newly compressed modules
            with HooksMixin.disable_hooks():
                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
                    inputs = intermediates.fetch(batch_idx, subgraph.input_names)
                    output = subgraph.forward(model, **inputs)

                    if subgraph_index < num_subgraphs - 1:
                        intermediates.update(batch_idx, output)
                        intermediates.delete(batch_idx, subgraph.consumed_names)

        # redundant; finish any remaining compression
        LifecycleCallbacks.calibration_epoch_end()

get_targets_from_modifiers(modifiers, model)

Infer the sequential targets and ignore list from the modifiers list.

Parameters:

    modifiers (List[Modifier]): list of modifiers being applied during calibration (required)
    model (PreTrainedModel): model being calibrated (required)

Returns:

    Tuple[List[str], List[str]]: list of sequential targets and list of modules to ignore for tracing
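
A hedged usage sketch; the GPTQModifier constructor arguments are illustrative, and the fallback behavior follows the source below:

from llmcompressor.modifiers.quantization import GPTQModifier

# hypothetical recipe: a single sequential (GPTQ) modifier
modifier = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# with modifier.sequential_targets unset, targets fall back to the model's
# no-split modules (typically the decoder layer class)
sequential_targets, ignore = get_targets_from_modifiers([modifier], model)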

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def get_targets_from_modifiers(
    modifiers: List[Modifier], model: PreTrainedModel
) -> Tuple[List[str], List[str]]:
    """
    Infer sequential targets and ignore list from modifiers list

    :param model: model being calibrated
    :param modifiers: list of modifiers being applied during calibration
    :return: list of sequential targets and list of modules to ignore for tracing
    """
    # avoid circular import
    from llmcompressor.pipelines.registry import SEQUENTIAL_MODIFIERS

    sequential_modifiers = [
        modifier for modifier in modifiers if isinstance(modifier, SEQUENTIAL_MODIFIERS)
    ]

    if len(sequential_modifiers) >= 2:
        types = [type(modifier) for modifier in sequential_modifiers]
        logger.warning(
            "Cannot infer sequential targets from multiple sequential modifiers "
            f"({types}). Defaulting to {types[0]}"
        )
    elif len(sequential_modifiers) <= 0:
        types = [type(modifier) for modifier in modifiers]
        raise ValueError(f"Cannot infer sequential targets from list of {types}")

    modifier = sequential_modifiers[0]

    # infer sequential targets
    if modifier.sequential_targets is None:
        sequential_targets = get_no_split_params(model)
    elif isinstance(modifier.sequential_targets, str):
        sequential_targets = [modifier.sequential_targets]
    else:
        sequential_targets = modifier.sequential_targets

    return sequential_targets, modifier.ignore