Run a sequential data pipeline according to the following steps:
- The model is partitioned into subgraphs according to `sequential_targets`
- Data passes through each subgraph sequentially: each batch is run through a subgraph twice, once to trigger calibration hooks, then a second time to capture activations after quantization has occurred (also via hooks)
- The intermediate activations between subgraphs are cached and offloaded to the CPU between batches in order to save memory (a minimal sketch of this scheme follows this list)
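As referenced above, here is a minimal, self-contained sketch of the two-pass, CPU-offloaded scheme. Everything in it is a stand-in: plain `nn.Module`s play the role of subgraphs, a Python list plays the role of the pipeline's `IntermediatesCache`, and calibration hooks and the compression step are stubbed as comments.

```python
import torch

# Toy rendition of the sequential pipeline: "subgraphs" are plain modules,
# and the intermediates cache is a list holding one CPU tensor per batch
subgraphs = [torch.nn.Linear(8, 8) for _ in range(3)]
cache = [torch.randn(4, 8).cpu() for _ in range(2)]  # one entry per batch

for subgraph in subgraphs:
    # pass 1: run every batch to trigger calibration hooks (not shown here)
    for cached in cache:
        subgraph(cached)

    # ... the subgraph would be compressed/quantized at this point ...

    # pass 2: recompute the (now compressed) subgraph's outputs with hooks
    # disabled, offloading them to the CPU as inputs for the next subgraph
    with torch.no_grad():
        for batch_idx, cached in enumerate(cache):
            cache[batch_idx] = subgraph(cached).cpu()
```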
This pipeline requires that the model be traceable with respect to data from the data loader. This may be an issue for vision models paired with vision datasets, due to specialized input processing in the model.
If tracing fails, a `torch.fx.proxy.TraceError` is raised. A model can be made traceable by wrapping its untraceable functions (see `llmcompressor.transformers.tracing`); a hedged example of the wrapping technique follows.
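For illustration, `torch.fx` exposes `torch.fx.wrap`, which records a function as a single leaf call in the graph instead of tracing into its body. The helper `resize_if_needed` below is hypothetical, a stand-in for the kind of data-dependent input processing that typically raises the error; it is not part of llmcompressor.

```python
import torch
import torch.fx


def resize_if_needed(pixel_values: torch.Tensor) -> torch.Tensor:
    # Branching on a runtime shape raises torch.fx.proxy.TraceError during
    # symbolic tracing, because the value is a Proxy at trace time
    if pixel_values.shape[-1] != 224:
        pixel_values = torch.nn.functional.interpolate(
            pixel_values, size=(224, 224), mode="bilinear"
        )
    return pixel_values


# Register the function as a leaf: fx records a call to it in the graph
# instead of tracing through its data-dependent control flow
torch.fx.wrap("resize_if_needed")
```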
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | model being calibrated | required |
| `dataloader` | `DataLoader` | loads data for calibration | required |
| `dataset_args` | `DatasetArguments` | dataset arguments relevant to pipelines | required |
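A hedged invocation sketch follows. The class name `SequentialPipeline` and both import paths are assumptions inferred from the source path below, not verified public API, and the placeholders (`...`) must be filled with a real model and dataset.

```python
# Assumed names and import paths; consult the package for the actual entry points
from torch.utils.data import DataLoader

from llmcompressor.args import DatasetArguments  # assumed path
from llmcompressor.pipelines.sequential.pipeline import SequentialPipeline  # assumed path

model = ...                   # torch.nn.Module with an active compression session
dataloader = DataLoader(...)  # yields dict-style batches accepted by the model
dataset_args = DatasetArguments()

SequentialPipeline()(model, dataloader, dataset_args)
```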
Source code in src/llmcompressor/pipelines/sequential/pipeline.py
```python
@staticmethod
def __call__(
    model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments"
):
    """
    Run a sequential data pipeline according to the following steps:

    1. The model is partitioned into subgraphs according to `sequential_targets`
    2. Data passes through each subgraph sequentially. Data is passed through each
       subgraph twice, once to trigger calibration hooks, then a second time in
       order to capture activations after quantization has occurred through hooks.
    3. The intermediate activations between each subgraph are cached and offloaded
       to the cpu between each batch in order to save memory

    This pipeline requires that the model be traceable with respect to data from the
    data loader. This may be an issue for vision models with vision datasets, due
    to specialized input processing in the model.

    In the event that tracing fails, a torch.fx.proxy.TraceError will be raised. A
    model can be made traceable by wrapping the untraceable functions (see
    llmcompressor.transformers.tracing)

    :param model: model being calibrated
    :param dataloader: loads data for calibration
    :param dataset_args: dataset arguments relevant to pipelines
    """
    session = active_session()

    # infer sequential targets
    modifiers = session.get_modifiers()
    sequential_targets, ignore = get_targets_from_modifiers(modifiers, model)

    # trace subgraphs
    sample_input = next(iter(dataloader))
    subgraphs = trace_subgraphs(model, sample_input, sequential_targets, ignore)

    LifecycleCallbacks.calibration_epoch_start()

    with calibration_forward_context(model), DisableQuantization(model):
        # prepare intermediates cache
        model_device = get_execution_device(model)
        intermediates = IntermediatesCache.from_dataloader(dataloader, model_device)

        num_subgraphs = len(subgraphs)
        for subgraph_index, subgraph in enumerate(subgraphs):
            # prepare tqdm description texts
            calib_desc = f"({subgraph_index + 1}/{num_subgraphs}): Calibrating"
            prop_desc = f"({subgraph_index + 1}/{num_subgraphs}): Propagating"

            # do a preliminary pass to trigger modifier hooks
            for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
                inputs = intermediates.fetch(batch_idx, subgraph.input_names)
                subgraph.forward(model, **inputs)

            # trigger compression
            LifecycleCallbacks.sequential_epoch_end()

            # this pass does not trigger modifier hooks
            # and is only used for capturing outputs from newly compressed modules
            with HooksMixin.disable_hooks():
                for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=prop_desc):
                    inputs = intermediates.fetch(batch_idx, subgraph.input_names)
                    output = subgraph.forward(model, **inputs)

                    if subgraph_index < num_subgraphs - 1:
                        intermediates.update(batch_idx, output)
                        intermediates.delete(batch_idx, subgraph.consumed_names)

        # redundant, finish any remaining compression
        LifecycleCallbacks.calibration_epoch_end()
```
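A note on the structure above: the first inner loop exists so that modifier hooks can observe activations before compression; `sequential_epoch_end()` then compresses the current subgraph, and the second loop, run under `HooksMixin.disable_hooks()`, repopulates the cache so that later subgraphs calibrate against the outputs of already-compressed modules. Deleting `consumed_names` after each batch keeps the CPU-offloaded cache from accumulating activations that no later subgraph needs.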