llmcompressor.modifiers.quantization.calibration

calibrate_activations(module, value, base_name)

Calibrate input or output activations by calling a module's attached observer.

Parameters:

| Name      | Type   | Description                                          | Default  |
| --------- | ------ | ---------------------------------------------------- | -------- |
| module    | Module | torch.nn.Module                                      | required |
| base_name | str    | substring used to fetch the observer, scales, and zp | required |
| value     | Tensor | torch.Tensor to be passed to the observer            | required |
Source code in src/llmcompressor/modifiers/quantization/calibration.py
def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
    """
    Calibrate input or output activations by calling a module's attached
    observer.

    :param module: torch.nn.Module
    :param base_name: substring used to fetch the observer, scales, and zp
    :param value: torch.Tensor to be passed to the observer

    """
    # If empty tensor, can't update zp/scale
    # Case for MoEs
    if value.numel() == 0:
        return

    call_observer(
        module=module,
        base_name=base_name,
        value=value,
    )
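
A minimal usage sketch (hedged: `linear` stands for any torch.nn.Linear that already has an "input" observer attached, e.g. via initialize_observer):

import torch

# observe one calibration batch of input activations
calibrate_activations(linear, value=torch.randn(4, linear.in_features), base_name="input")

# an empty tensor (e.g. an MoE expert that received no tokens) is skipped silently
calibrate_activations(linear, value=torch.empty(0, linear.in_features), base_name="input")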

calibrate_input_hook(module, args)

Hook to calibrate input activations. Will call the observers to update the scales/zp before applying input QDQ in the module's forward pass.

Source code in src/llmcompressor/modifiers/quantization/calibration.py
def calibrate_input_hook(module: Module, args: Any):
    """
    Hook to calibrate input activations.
    Will call the observers to update the scales/zp before applying
    input QDQ in the module's forward pass.
    """
    args = args[0] if isinstance(args, tuple) else args
    calibrate_activations(module, value=args, base_name="input")
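
This signature matches what torch.nn.Module.register_forward_pre_hook expects, so a hedged registration sketch (again assuming `linear` carries an input observer) is:

handle = linear.register_forward_pre_hook(calibrate_input_hook)
# ... run calibration batches through the model ...
handle.remove()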

calibrate_kv_cache_input_hook(module, args, kwargs)

Hook to update inputs to attention layers when running kv_cache quantization. Will update the passed-in kv_cache to the singleton QuantizedKVParameterCache.

Source code in src/llmcompressor/modifiers/quantization/calibration.py
def calibrate_kv_cache_input_hook(
    module: Module, args: Any, kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Hook to update inputs to attention layers when running
    kv_cache quantization. Will update the passed-in
    kv_cache to the singleton QuantizedKVParameterCache.
    """
    kv_cache = getattr(module, "kv_cache")
    kwargs["past_key_value"] = kv_cache
    kwargs["use_cache"] = False
    return args, kwargs
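
Because this hook also rewrites keyword arguments, it must be registered with with_kwargs=True. A sketch, assuming `attn` is an attention module that already holds a kv_cache (see initialize_quantized_kv_cache below):

handle = attn.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True)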

calibrate_kv_cache_output_hook(module, _args, _output)

Hook to update k_scale and v_scale parameters when running kv_cache quantization.

Source code in src/llmcompressor/modifiers/quantization/calibration.py
def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
    """
    Hook to update k_scale and v_scale parameters when running kv_cache quantization.
    """
    kv_cache = getattr(module, "kv_cache")
    k_scale = kv_cache.k_scales[module.layer_idx]
    v_scale = kv_cache.v_scales[module.layer_idx]
    update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
    update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
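
This hook uses the plain forward-hook signature, so a sketch pairs it with the input hook above:

handle = attn.register_forward_hook(calibrate_kv_cache_output_hook)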

calibrate_output_hook(module, _args, output)

Hook to calibrate output activations. Will call the observers to update the scales/zp before applying output QDQ.

Source code in src/llmcompressor/modifiers/quantization/calibration.py
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
    """
    Hook to calibrate output activations.
    Will call the observers to update the scales/zp before applying
    output QDQ.
    """
    calibrate_activations(
        module,
        value=output,
        base_name="output",
    )
    output = forward_quantize(
        module=module,
        value=output,
        base_name="output",
        args=module.quantization_scheme.output_activations,
    )
    return output
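
Like calibrate_input_hook, this matches the register_forward_hook signature; the tensor it returns replaces the module's output. A sketch:

handle = linear.register_forward_hook(calibrate_output_hook)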

call_observer(module, base_name, value=None)

Call a module's attached input/weight/output observer using a provided value. Update the module's scale and zp using the observer's return values.

Parameters:

| Name      | Type             | Description                                                                                                              | Default  |
| --------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------ | -------- |
| module    | Module           | torch.nn.Module                                                                                                           | required |
| base_name | str              | substring used to fetch the observer, scales, and zp                                                                      | required |
| value     | Optional[Tensor] | torch.Tensor to be passed to the observer for activations. If base_name is "weight", the module's weight tensor is used   | None     |
Source code in src/llmcompressor/modifiers/quantization/calibration.py
def call_observer(module: Module, base_name: str, value: Optional[torch.Tensor] = None):
    """
    Call a module's attached input/weight/output observer using a provided value.
    Update the module's scale and zp using the observer's return values.

    :param module: torch.nn.Module
    :param base_name: substring used to fetch the observer, scales, and zp
    :param value: torch.Tensor to be passed to the observer for activations. If
        base_name is "weight", then the module's weight tensor will be used
    """
    with align_module_device(module):
        if base_name == "weight":
            value = module.weight
            g_idx = getattr(module, "weight_g_idx", None)
        elif value is not None:
            g_idx = None
        else:
            raise ValueError(
                "Must provide a value to observe if not using weight observer"
            )

        observer = getattr(module, f"{base_name}_observer")
        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)

        # update scale and zero point
        update_parameter_data(module, updated_scale, f"{base_name}_scale")
        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
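
Two hedged invocation sketches, assuming `module` already carries the matching {base_name}_observer, {base_name}_scale, and {base_name}_zero_point attributes, and `activations` stands for any calibration tensor:

# weights: the value is read from module.weight internally
call_observer(module, base_name="weight")

# activations: a value must be supplied explicitly
call_observer(module, base_name="input", value=activations)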

freeze_module_quantization(module)

Deletes observers when calibration is complete.

Apply to the full model with model.apply(freeze_module_quantization).

Parameters:

| Name   | Type   | Description                       | Default  |
| ------ | ------ | --------------------------------- | -------- |
| module | Module | module to freeze quantization for | required |
Source code in src/llmcompressor/modifiers/quantization/calibration.py
def freeze_module_quantization(module: Module):
    """
    Deletes observers when calibration is complete.

    Apply to the full model with `model.apply(freeze_module_quantization)`

    :param module: module to freeze quantization for
    """
    scheme = getattr(module, "quantization_scheme", None)
    if not scheme:
        # no quantization scheme nothing to do
        return

    if module.quantization_status == QuantizationStatus.FROZEN:
        # nothing to do, already frozen
        return

    # remove observers
    for name in ("input", "weight", "output"):
        obs_name = f"{name}_observer"
        if hasattr(module, obs_name):
            delattr(module, obs_name)

    # remove quantized kv_cache
    kv_cache = getattr(module, "kv_cache", None)
    if isinstance(kv_cache, QuantizedKVParameterCache):
        delattr(module, "kv_cache")

    module.quantization_status = QuantizationStatus.FROZEN
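
A sketch of the intended call site, once all calibration data has been consumed:

# drops all observers and any QuantizedKVParameterCache, then marks modules frozen
model.apply(freeze_module_quantization)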

initialize_observer(module, base_name)

Initialize an observer and attach it to the module as a submodule. The observer name is fetched from the quantization_args and used to load the observer from the registry; the attached attribute is named using the provided base_name.

Parameters:

| Name      | Type   | Description                                            | Default  |
| --------- | ------ | ------------------------------------------------------ | -------- |
| module    | Module | torch.nn.Module that the observer is being attached to | required |
| base_name | str    | str used to name the observer attribute                | required |
Source code in src/llmcompressor/modifiers/quantization/calibration.py
def initialize_observer(
    module: Module,
    base_name: str,
):
    """
    Initialize an observer and attach it to the module as a submodule.
    The observer name is fetched from the quantization_args and used to load
    the observer from the registry; the attached attribute is named using the
    provided base_name.

    :param module: torch.nn.Module that the observer is being attached to
    :param base_name: str used to name the observer attribute

    """

    arg_name = "weights" if base_name == "weight" else f"{base_name}_activations"
    quantization_scheme = getattr(module, "quantization_scheme", None)
    if not quantization_scheme:
        # no quantization scheme nothing to do
        return

    quantization_args = getattr(quantization_scheme, arg_name, None)
    # dont need observers for dynamic
    if quantization_args is not None and not quantization_args.dynamic:
        observer = Observer.load_from_registry(
            quantization_args.observer, quantization_args=quantization_args
        )
        module.register_module(f"{base_name}_observer", observer)
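
A hedged sketch of attaching observers across a model; modules without a quantization_scheme, or whose args are dynamic, are skipped internally:

for submodule in model.modules():
    for base_name in ("weight", "input", "output"):
        initialize_observer(submodule, base_name=base_name)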

initialize_quantized_kv_cache(module)

Initialize a quantized kv_cache on a module (analogous to initializing an observer). When a config specifying kv_cache quantization is applied to a model, the kv_cache args are redefined as the output_activations targeting attention modules.

This function should be called on attention modules with output_activations.

Source code in src/llmcompressor/modifiers/quantization/calibration.py
def initialize_quantized_kv_cache(module: Module):
    """
    Initialize a quantized kv_cache on a module (analogous to initializing an
    observer). When a config specifying kv_cache quantization is applied to a
    model, the kv_cache args are redefined as the output_activations targeting
    attention modules.

    This function should be called on attention modules with output_activations.
    """
    scheme: Optional[QuantizationScheme] = getattr(module, "quantization_scheme", None)
    existing_kv_cache = getattr(module, "kv_cache", None)

    if (
        scheme is None
        or not is_kv_cache_quant_scheme(scheme)
        or isinstance(existing_kv_cache, QuantizedKVParameterCache)
    ):
        return

    quantized_kv_cache = QuantizedKVParameterCache(scheme.output_activations)
    setattr(module, "kv_cache", quantized_kv_cache)
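
Since the function no-ops on modules without a kv_cache quantization scheme, a sketch can safely apply it model-wide:

model.apply(initialize_quantized_kv_cache)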

update_weight_zp_scale(module)

Marks a layer as ready for calibration, which activates observers to update scales and zero points on each forward pass.

Apply to the full model with model.apply(update_weight_zp_scale).

Parameters:

| Name   | Type   | Description                   | Default  |
| ------ | ------ | ----------------------------- | -------- |
| module | Module | module to set for calibration | required |
Source code in src/llmcompressor/modifiers/quantization/calibration.py
def update_weight_zp_scale(module: Module):
    """
    Marks a layer as ready for calibration, which activates observers
    to update scales and zero points on each forward pass.

    Apply to the full model with `model.apply(update_weight_zp_scale)`

    :param module: module to set for calibration
    """
    if getattr_chain(module, "quantization_scheme.weights", None) is None:
        return

    if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
        logger.warning(
            "Attempting to calibrate weights of a module not in calibration mode"
        )

    call_observer(module=module, base_name="weight")
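
A hedged end-to-end ordering sketch tying these helpers together:

model.apply(update_weight_zp_scale)      # observe weights once at calibration start
# ... run calibration forward passes; activation hooks update input/output qparams ...
model.apply(freeze_module_quantization)  # remove observers when calibration is done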