
llmcompressor.modifiers.quantization

GPTQModifier

Bases: Modifier, QuantizationMixin

Implements the GPTQ algorithm from https://arxiv.org/abs/2210.17323. This modifier uses activations to calibrate a Hessian matrix. The Hessian is then used to determine optimal quantization values and orderings for the model weights.
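
In the reference GPTQ code, "calibrating a Hessian from activations" means keeping a running average of 2·xᵀx over the calibration inputs to a layer. The average is normalized by the number of samples seen so far, where a sample is one entry along the leading batch dimension. The sketch below reproduces that update in plain PyTorch under these assumptions. `accumulate_hessian_sketch` is a hypothetical name. The library's own `accumulate_hessian` may differ in details, such as how it reshapes inputs for non-Linear modules.

```python
import torch


def accumulate_hessian_sketch(inp, H, num_samples):
    """Fold one batch of Linear-layer inputs into a running Hessian estimate.

    Hypothetical helper. After all batches, H == 2 / N * X^T X, where X stacks every
    token row seen and N is the total size of the leading (batch) dimension.
    """
    if inp.dim() == 2:
        inp = inp.unsqueeze(0)  # treat a 2-D input as a single sample
    batch = inp.shape[0]
    x = inp.reshape(-1, inp.shape[-1]).to(torch.float32)  # (tokens, in_features)

    # Shrink the previous average so the new samples can be added with equal weight
    H = H * (num_samples / (num_samples + batch))
    num_samples += batch

    # Add the new contribution 2 / N * x^T x
    x = (2.0 / num_samples) ** 0.5 * x
    H = H + x.t() @ x
    return H, num_samples
```

Feeding batches one at a time, for example from a forward hook, gives the same result as computing 2/N·XᵀX over all inputs at once. However, only one (in_features × in_features) matrix per layer is kept in memory.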

Sample yaml:

    test_stage:
      obcq_modifiers:
        GPTQModifier:
          block_size: 128
          dampening_frac: 0.001
          offload_hessians: False
          config_groups:
            group_0:
              targets:
                - "Linear"
              input_activations: null
              output_activations: null
              weights:
                num_bits: 8
                type: "int"
                symmetric: true
                strategy: "tensor"
                group_size: 128
                actorder: False

Lifecycle:

- on_initialize
    - apply config to model
- on_start
    - add activation calibration hooks
    - add gptq weight calibration hooks
- on_sequential_epoch_end
    - quantize_weight
- on_finalize
    - remove_hooks()
    - model.apply(freeze_module_quantization)
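
The lifecycle list above names `on_sequential_epoch_end`. In the source below, that step is dispatched through `on_event`. The toy stand-in here mirrors that dispatch:

- calibration starts lazily on the first CALIBRATION_EPOCH_START;
- calibrated modules are compressed at the end of every sequential epoch;
- any remaining modules are compressed at CALIBRATION_EPOCH_END, which also ends calibration.

`ToyEvent` and `ToyGPTQLifecycle` are hypothetical names. The sketch only records which step would run.

```python
from enum import Enum, auto


class ToyEvent(Enum):
    CALIBRATION_EPOCH_START = auto()
    SEQUENTIAL_EPOCH_END = auto()
    CALIBRATION_EPOCH_END = auto()


class ToyGPTQLifecycle:
    """Hypothetical stand-in that logs the steps GPTQModifier.on_event would take."""

    def __init__(self):
        self.started = False
        self.ended = False
        self.log = []

    def on_start(self):
        self.started = True
        self.log.append("start: add calibration and GPTQ hooks")

    def compress_modules(self):
        self.log.append("compress calibrated modules")

    def on_end(self):
        self.ended = True
        self.log.append("end: remove hooks")

    def on_event(self, event):
        if event is ToyEvent.CALIBRATION_EPOCH_START and not self.started:
            self.on_start()
        if event is ToyEvent.SEQUENTIAL_EPOCH_END:
            self.compress_modules()
        if event is ToyEvent.CALIBRATION_EPOCH_END:
            self.compress_modules()  # compress anything not yet handled
            if not self.ended:
                self.on_end()
```

For a model calibrated in two sequential stages, the log shows three compression passes: one after each stage and a final pass at the end of calibration. The final pass is a no-op in the real modifier when nothing remains to compress.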

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sequential_targets | Union[str, List[str], None] | List of layer names to compress during GPTQ, or '__ALL__' to compress every layer in the model. | None |
| block_size | int | Number of columns to compress in one pass. | 128 |
| dampening_frac | Optional[float] | Amount of dampening to apply to H, as a fraction of the diagonal norm. | 0.01 |
| offload_hessians | bool | Set to True to decrease memory usage at the cost of increased runtime. | False |
| config_groups | | Dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. | inherited from QuantizationMixin |
| targets | | List of layer names to quantize if a scheme is provided. Defaults to Linear layers. | inherited from QuantizationMixin |
| ignore | | Optional list of module class names or submodule names to not quantize, even if they match a target in config_groups. Defaults to an empty list. | inherited from QuantizationMixin |
| scheme | | A single quantization scheme to apply to the model. This is a dictionary that supports all keys from QuantizationScheme except targets, which is set to the modifier-level targets parameter. It can also be a dictionary of the form `preset_scheme_name: targets`, for example `W8A8: ['Linear']` for 8-bit weights and activations. | inherited from QuantizationMixin |
| kv_cache_scheme | | Optional QuantizationArgs specifying the quantization of the kv cache. If None, the kv cache is not quantized. When kv cache quantization is applied to a transformers AutoModelForCausalLM, kv_cache_scheme is converted into a QuantizationScheme that (1) targets the k_proj and v_proj modules of the model, whose outputs are the keys and values that might be cached, and (2) quantizes the outputs of those modules, so that keys and values are compressed before they are stored in the cache. The model is explicitly assumed to contain modules with k_proj and v_proj in their names. If it does not and kv_cache_scheme is not None, kv cache quantization will fail. | inherited from QuantizationMixin |
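
The following sketch shows how `block_size` and `dampening_frac` enter the computation. It is a simplified version of the column-by-column GPTQ update from the paper, assuming symmetric per-tensor integer quantization with a fixed, caller-supplied scale. It is not the library's `quantize_weight`. `gptq_quantize_sketch` is a hypothetical name, and the sketch omits activation ordering, group-wise scales, and loss reporting. The damping term follows the reference GPTQ code, which adds `dampening_frac` times the mean of H's diagonal to that diagonal.

```python
import torch


def gptq_quantize_sketch(W, H, scale, num_bits=8, block_size=128, dampening_frac=0.01):
    """Hypothetical simplified GPTQ for W (rows, cols) with Hessian H (cols, cols)."""
    W = W.clone().float()
    H = H.clone().float()
    num_cols = W.shape[1]
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1

    # Dampen the diagonal by dampening_frac * mean(diag(H)) for numerical stability
    H += dampening_frac * torch.mean(torch.diag(H)) * torch.eye(num_cols, device=H.device)

    # Upper Cholesky factor of H^-1, as in the reference GPTQ implementation
    Hinv = torch.linalg.cholesky(torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True)

    Q = torch.zeros_like(W)
    for i1 in range(0, num_cols, block_size):
        i2 = min(i1 + block_size, num_cols)
        W1 = W[:, i1:i2].clone()
        Err1 = torch.zeros_like(W1)
        Hinv1 = Hinv[i1:i2, i1:i2]
        for j in range(i2 - i1):
            w = W1[:, j].clone()
            d = Hinv1[j, j]
            # Round-to-nearest on a fixed symmetric integer grid
            q = torch.clamp(torch.round(w / scale), qmin, qmax) * scale
            Q[:, i1 + j] = q
            err = (w - q) / d
            # Compensate this column's error in the remaining columns of the block
            W1[:, j:] -= err.unsqueeze(1) @ Hinv1[j, j:].unsqueeze(0)
            Err1[:, j] = err
        # Lazy batch update: push the whole block's error onto all later columns
        W[:, i2:] -= Err1 @ Hinv[i1:i2, i2:]
    return Q
```

`block_size` only changes how often the lazy update to later columns happens. Given the same H, it does not change which values are chosen. When H is diagonal, no error is propagated between columns, so the result reduces to plain round-to-nearest. Correlations in H are what let GPTQ compensate for quantization error in later columns.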
Source code in src/llmcompressor/modifiers/quantization/gptq/base.py
class GPTQModifier(Modifier, QuantizationMixin):
    """
    Implements the GPTQ algorithm from https://arxiv.org/abs/2210.17323. This modifier
    uses activations to calibrate a Hessian matrix, which is then used to determine
    optimal quantization values and orderings for the model weights.

    | Sample yaml:
    | test_stage:
    |    obcq_modifiers:
    |      GPTQModifier:
    |          block_size: 128
    |          dampening_frac: 0.001
    |          offload_hessians: False
    |          config_groups:
    |            group_0:
    |                targets:
    |                  - "Linear"
    |                input_activations: null
    |                output_activations: null
    |                weights:
    |                    num_bits: 8
    |                    type: "int"
    |                    symmetric: true
    |                    strategy: "tensor"
    |                    group_size: 128
    |                    actorder: False

    Lifecycle:
        - on_initialize
            - apply config to model
        - on_start
            - add activation calibration hooks
            - add gptq weight calibration hooks
        - on_sequential_epoch_end
            - quantize_weight
        - on_finalize
            - remove_hooks()
            - model.apply(freeze_module_quantization)

    :param sequential_targets: list of layer names to compress during GPTQ, or
        '__ALL__' to compress every layer in the model
    :param block_size: Used to determine number of columns to compress in one pass
    :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
        diagonal norm
    :param offload_hessians: Set to True for decreased memory usage but increased
        runtime.

    :param config_groups: dictionary specifying quantization schemes to apply to target
        modules. Modules not matching a scheme target will NOT be quantized.
    :param targets: list of layer names to quantize if a scheme is provided. Defaults
        to Linear layers
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param scheme: a single quantization scheme to apply to the model. This is a
        dictionary that supports all keys from QuantizationScheme except targets, which
        will be set to the targets parameter set at the modifier level. Can also be set
        to a dictionary of the format `preset_scheme_name: targets` for example:
        `W8A8: ['Linear']` for weight and activation 8-bit.
    :param kv_cache_scheme: optional QuantizationArgs, that specify the
        quantization of the kv cache. If None, kv cache is not quantized.
        When applying kv cache quantization to transformer AutoModelForCausalLM,
        the kv_cache_scheme gets converted into a QuantizationScheme that:
            - targets the `k_proj` and `v_proj` modules of the model. The outputs
              of those modules are the keys and values that might be cached
            - quantizes the outputs of the aforementioned layers, so that
              keys and values are compressed before storing them in the cache
        There is an explicit assumption that the model contains modules with
        `k_proj` and `v_proj` in their names. If this is not the case
        and kv_cache_scheme != None, the quantization of kv cache will fail
    """

    # gptq modifier arguments
    sequential_update: bool = True  # DEPRECATED
    sequential_targets: Union[str, List[str], None] = None
    block_size: int = 128
    dampening_frac: Optional[float] = 0.01
    offload_hessians: bool = False

    # private variables
    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
    _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
    _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)

    @field_validator("sequential_update", mode="before")
    def validate_sequential_update(cls, value: bool) -> bool:
        if not value:
            warnings.warn(
                "`sequential_update=False` is no longer supported, setting "
                "sequential_update=True",
                DeprecationWarning,
            )

        return True

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Apply the quantization config to the model (if one is specified) and record
        the module names used when logging compression

        :param state: session state storing input model and calibration data
        """
        # apply config to model and prepare calibration hooks
        if QuantizationMixin.has_config(self):
            QuantizationMixin.initialize_quantization(self, state.model)

        # prepare module names
        self._module_names = {m: name for name, m in state.model.named_modules()}

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        # register quantization calibration hooks
        # assume quantization has been initialized by this modifier or one before it
        QuantizationMixin.start_calibration(self, state.model)
        # Unlike qmod, do not quantize as we calibrate
        # This choice does not seem to have a meaningful impact on accuracy
        state.model.apply(disable_quantization)

        # register gptq hooks
        added_hook = False
        for module in state.model.modules():
            if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                # HACK: previously, embeddings were not quantized because they were not
                # accessible by the layer compressor. For now, we manually ignore it,
                # but in the FUTURE this should be ignored by the user
                if not isinstance(module, torch.nn.Embedding):
                    self.register_hook(module, self.calibrate_module, "forward")
                    added_hook = True

        if not added_hook:
            raise ValueError(
                "GPTQModifier requires a weight quantization config be specified by "
                "this modifier or a modifier preceding it"
            )

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            self.compress_modules()

        if event.type_ == EventType.CALIBRATION_EPOCH_END:
            self.compress_modules()

            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        """
        Finish calibrating by removing observers and calibration hooks
        """
        self.ended_ = True
        QuantizationMixin.end_calibration(self, state.model)
        self.remove_hooks()  # remove gptq hooks

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        End calibration if it has not already ended, check that every calibrated
        module was compressed, and release the stored Hessians and sample counts

        :param state: session state storing input model and calibration data
        """
        if not self.ended_:
            self.on_end(state, None)

        if len(self._num_samples) > 0:
            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

        self._hessians = dict()
        self._num_samples = dict()

        return True

    def calibrate_module(
        self,
        module: torch.nn.Module,
        args: Tuple[torch.Tensor, ...],
        _output: torch.Tensor,
    ):
        """
        Calibration hook used to accumulate the hessian of the input to the module

        :param module: module being calibrated
        :param args: inputs to the module, the first element of which is the
            canonical input
        :param _output: uncompressed module output, unused
        """
        # Assume that first argument is the input
        inp = args[0]

        # Initialize hessian if not present
        if module not in self._num_samples:
            init_device = (
                "cpu" if self.offload_hessians else get_execution_device(module)
            )
            self._hessians[module] = make_empty_hessian(module, device=init_device)
            self._num_samples[module] = 0

        # Accumulate hessian with input with optional offloading
        with self._maybe_onload_hessian(module):
            self._hessians[module], self._num_samples[module] = accumulate_hessian(
                inp,
                module,
                self._hessians[module],
                self._num_samples[module],
            )

    def compress_modules(self):
        """
        Quantize modules which have been calibrated
        """
        for module in list(self._num_samples.keys()):
            name = self._module_names[module]
            num_samples = self._num_samples[module]
            quant_args = getattr_chain(module, "quantization_scheme.weights")

            logger.info(f"Quantizing {name} using {num_samples} samples")
            with torch.no_grad(), align_module_device(
                module
            ), self._maybe_onload_hessian(module), CompressionLogger(
                module
            ) as comp_logger:
                loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
                    module=module,
                    quant_args=quant_args,
                    hessians_dict=self._hessians,
                    blocksize=self.block_size,
                    percdamp=self.dampening_frac,
                )
                comp_logger.set_loss(loss)

            update_offload_parameter(module, "weight", quantized_weight)
            update_offload_parameter(module, "weight_scale", scale)
            update_offload_parameter(module, "weight_zero_point", zero_point)
            if g_idx is not None:
                update_offload_parameter(module, "weight_g_idx", g_idx)

            # self._hessians[module] already deleted by quantize_weight
            del self._num_samples[module]

    @contextlib.contextmanager
    def _maybe_onload_hessian(self, module: torch.nn.Module):
        if self.offload_hessians:
            device = get_execution_device(module)
            self._hessians[module] = self._hessians[module].to(device=device)

        yield

        if self.offload_hessians:
            if module in self._hessians:  # may have been deleted in context
                self._hessians[module] = self._hessians[module].to(device="cpu")
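
When offload_hessians is True, _maybe_onload_hessian keeps each Hessian on the CPU. It moves the Hessian to the module's execution device only for the duration of the with block. The pattern can be sketched without the modifier. `HessianStore` and `maybe_onload` are hypothetical names, and `execution_device` stands in for `get_execution_device(module)`.

```python
import contextlib

import torch


class HessianStore:
    """Hypothetical stand-in for the modifier's Hessian bookkeeping."""

    def __init__(self, offload: bool, execution_device: str = "cpu"):
        self.offload = offload
        self.execution_device = execution_device
        self.hessians = {}

    @contextlib.contextmanager
    def maybe_onload(self, key):
        # Move the Hessian to the execution device only while it is being used
        if self.offload:
            self.hessians[key] = self.hessians[key].to(device=self.execution_device)
        try:
            yield
        finally:
            # The entry may have been deleted inside the context (e.g. after compression)
            if self.offload and key in self.hessians:
                self.hessians[key] = self.hessians[key].to(device="cpu")
```

On a GPU machine you would pass `execution_device="cuda"`. Unlike the source, the sketch wraps the yield in try/finally, so the Hessian is moved back to the CPU even if the body raises.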

calibrate_module(module, args, _output)

Calibration hook used to accumulate the hessian of the input to the module

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module | Module | Module being calibrated. | required |
| args | Tuple[Tensor, ...] | Inputs to the module, the first element of which is the canonical input. | required |
| _output | Tensor | Uncompressed module output (unused). | required |

compress_modules()

Quantize modules which have been calibrated

on_end(state, event, **kwargs)

Finish calibrating by removing observers and calibration hooks

on_finalize(state, **kwargs)

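End calibration if it has not already ended, check that every calibrated module was compressed, and release the stored Hessians and sample counts.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Session state storing the input model and calibration data. | required |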

on_initialize(state, **kwargs)

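Apply the quantization config to the model (if one is specified) and record the module names used when logging compression.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Session state storing the input model and calibration data. | required |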

Observer

Bases: Module, RegistryMixin

Base Observer class to be subclassed for specific implementation. Subclasses should override calculate_qparams to return a scale, zero_point pair

Source code in src/llmcompressor/observers/base.py
class Observer(Module, RegistryMixin):
    """
    Base Observer class to be subclassed for specific implementation.
    Subclasses should override `calculate_qparams` to return a scale, zero_point
    pair
    """

    def __init__(self, quantization_args: QuantizationArgs):
        super().__init__()
        self.quantization_args: QuantizationArgs = quantization_args
        self._scale = None
        self._zero_point = None
        self._num_observed_tokens = None

    @torch.no_grad()
    def forward(
        self, observed: Tensor, g_idx: Optional[Tensor] = None
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        Record the number of observed tokens, then compute quantization
        parameters by calling get_qparams
        :param observed: observed tensor from which to calculate
            quantization parameters
        :param g_idx: optional mapping from column index to group index
        :return: tuple of scale and zero point based on last observed value
        """
        self.record_observed_tokens(observed)
        return self.get_qparams(observed=observed, g_idx=g_idx)

    def calculate_qparams(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
        tensor_id: Optional[Any] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned scale and zero point will be shaped (1,) along the
            reduced dimensions
        :param tensor_id: optional identifier passed by get_qparams_along_dim to
            distinguish slices of the observed tensor (e.g. one per group)
        :return: tuple of scale and zero point derived from the observed tensor
        """
        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

    def post_calculate_qparams(self) -> None:
        """
        Hook for subclasses to run observer-specific logic after calculate_qparams
        """

    def get_qparams(
        self,
        observed: Optional[Tensor] = None,
        g_idx: Optional[Tensor] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        Convenience function to wrap overwritten calculate_qparams
        adds support to make observed tensor optional and support for tracking latest
        calculated scale and zero point

        :param observed: optional observed tensor to calculate quantization parameters
            from
        :param g_idx: optional mapping from column index to group index
        :return: tuple of scale and zero point based on last observed value
        """
        if observed is not None:
            group_size = self.quantization_args.group_size

            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                # re-calculate scale and zero point, update the stored value
                self._scale, self._zero_point = self.calculate_qparams(observed)

            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
                rows = observed.shape[0]
                columns = observed.shape[1]
                num_groups = int(ceil(columns / group_size))
                self._scale = torch.empty(
                    (rows, num_groups), dtype=observed.dtype, device=observed.device
                )
                zp_dtype = self.quantization_args.pytorch_dtype()
                self._zero_point = torch.empty(
                    (rows, num_groups), dtype=zp_dtype, device=observed.device
                )

                # support column-order (default) quantization as well as other orderings
                # such as activation ordering. Below checks if g_idx has initialized
                is_column_order = g_idx is None or -1 in g_idx
                if is_column_order:
                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
                else:
                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
                    group_sizes = group_sizes[torch.argsort(group_indices)]

                    perm = torch.argsort(g_idx)
                    observed = safe_permute(observed, perm, dim=1)

                # TODO: experiment with vectorizing for loop for performance
                end = 0
                for group_index, group_count in enumerate(group_sizes):
                    start = end
                    end = start + group_count
                    scale, zero_point = self.get_qparams_along_dim(
                        observed[:, start:end],
                        0,
                        tensor_id=group_index,
                    )

                    self._scale[:, group_index] = scale.squeeze(1)
                    self._zero_point[:, group_index] = zero_point.squeeze(1)

            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                # assume observed is transposed, because it's the output, hence use dim 0
                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                # keep dims 0 and 1, assuming observed.shape = [batch, token, hidden],
                # so one scale and zero point is computed per (batch, token)
                self._scale, self._zero_point = self.get_qparams_along_dim(
                    observed,
                    dim={0, 1},
                )

        return self._scale, self._zero_point

    def get_qparams_along_dim(
        self,
        observed,
        dim: Union[int, Iterable[int]],
        tensor_id: Optional[Any] = None,
    ):
        if isinstance(dim, int):
            dim = [dim]
        dim = set(dim)

        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
        return self.calculate_qparams(
            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
        )

    def record_observed_tokens(self, batch_tensor: Tensor):
        """
        Counts the number of tokens observed during the
        forward passes. The count is aggregated in the
        _num_observed_tokens attribute of the class.

        Note: The batch_tensor is expected to have two dimensions
            (batch_size * sequence_length, num_features). This is the
            general shape expected by the forward pass of the expert
            layers in a MoE model. If the input tensor does not have
            two dimensions, it is not counted and _num_observed_tokens
            is left unchanged (None if nothing has been counted yet).
        """
        if not isinstance(batch_tensor, Tensor):
            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")

        if batch_tensor.ndim != 2:
            logger.debug(
                "The input tensor is expected to have two dimensions "
                "(batch_size * sequence_length, num_features). "
                f"The input tensor has {batch_tensor.ndim} dimensions."
            )
            return

        if self._num_observed_tokens is None:
            # initialize the count
            self._num_observed_tokens = 0

        # batch_tensor (batch_size * sequence_length, num_features)
        # observed_tokens (batch_size * sequence_length)
        observed_tokens, _ = batch_tensor.shape
        self._num_observed_tokens += observed_tokens

    def reset(self):
        """
        Reset the state of the observer
        """
        self._num_observed_tokens = None
        self._scale = None
        self._zero_point = None

calculate_qparams(observed, reduce_dims=None)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observed | Tensor | Observed tensor to calculate quantization parameters for. | required |
| reduce_dims | Optional[Tuple[int]] | Optional tuple of dimensions to reduce along; the returned scale and zero point will be shaped (1,) along the reduced dimensions. | None |

Returns:

| Type | Description |
|---|---|
| Tuple[FloatTensor, IntTensor] | Tuple of scale and zero point derived from the observed tensor. |
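
The base class leaves calculate_qparams abstract. The sketch below illustrates what a subclass might compute: an asymmetric min-max scale and zero point, with size-1 dimensions wherever reduce_dims reduces. The usage mirrors how get_qparams_along_dim turns the dimensions to keep into reduce_dims. `minmax_qparams_sketch` is hypothetical and is not the library's min-max observer, which may differ (for example in symmetric handling, moving averages, or use of tensor_id).

```python
import torch


def minmax_qparams_sketch(observed, num_bits=8, reduce_dims=None):
    """Hypothetical asymmetric min-max qparams for unsigned num_bits quantization."""
    if reduce_dims is None:
        min_val, max_val = observed.min(), observed.max()
    else:
        # keepdim=True leaves size-1 dimensions along the reduced axes
        min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
        max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)
    # Include zero in the range so it is exactly representable
    min_val = torch.clamp(min_val, max=0.0)
    max_val = torch.clamp(max_val, min=0.0)
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    scale = torch.clamp(scale, min=torch.finfo(observed.dtype).eps)
    zero_point = torch.clamp(torch.round(qmin - min_val / scale), qmin, qmax).to(torch.int32)
    return scale, zero_point


# Per-channel usage: keep dim 0, reduce every other dim (as get_qparams_along_dim does)
weight = torch.randn(4, 6)
keep = {0}
reduce_dims = tuple(i for i in range(weight.ndim) if i not in keep)
scale, zero_point = minmax_qparams_sketch(weight, reduce_dims=reduce_dims)
```

Here `reduce_dims` is `(1,)`, so `scale` and `zero_point` both have shape `(4, 1)`: one pair per output channel. They broadcast directly against the `(4, 6)` weight.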

forward(observed, g_idx=None)

Record the number of observed tokens, then compute quantization parameters by calling get_qparams.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observed | Tensor | Observed tensor from which to calculate quantization parameters. | required |
| g_idx | Optional[Tensor] | Optional mapping from column index to group index. | None |

Returns:

| Type | Description |
|---|---|
| Tuple[FloatTensor, IntTensor] | Tuple of scale and zero point based on the last observed value. |

get_qparams(observed=None, g_idx=None)

Convenience wrapper around the subclass's calculate_qparams. It makes the observed tensor optional and tracks the most recently calculated scale and zero point.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observed | Optional[Tensor] | Optional observed tensor to calculate quantization parameters from. | None |
| g_idx | Optional[Tensor] | Optional mapping from column index to group index. | None |

Returns:

| Type | Description |
|---|---|
| Tuple[FloatTensor, IntTensor] | Tuple of scale and zero point based on the last observed value. |
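
For the GROUP strategy, get_qparams computes one scale per row and group. An optional g_idx permutes columns for activation ordering, and a g_idx containing -1 is treated as uninitialized, which means column order. The sketch below follows that bookkeeping but substitutes symmetric abs-max scales for the subclass's calculate_qparams and omits zero points. `group_absmax_scales_sketch` is hypothetical. It also assumes that, when g_idx is given, every group index from 0 to num_groups - 1 appears in it.

```python
import math

import torch


def group_absmax_scales_sketch(weight, group_size, g_idx=None, num_bits=8):
    """Hypothetical per-row, per-group symmetric scales, shaped (rows, num_groups)."""
    rows, columns = weight.shape
    num_groups = math.ceil(columns / group_size)
    qmax = 2 ** (num_bits - 1) - 1

    if g_idx is None or bool((g_idx == -1).any()):
        # Column order: consecutive runs of group_size columns (the last may be shorter)
        group_sizes = [group_size] * num_groups
        observed = weight
    else:
        # Activation order: permute columns so each group's columns are contiguous
        group_indices, counts = torch.unique(g_idx, return_counts=True)
        group_sizes = counts[torch.argsort(group_indices)].tolist()
        observed = weight[:, torch.argsort(g_idx)]

    scales = torch.empty(rows, num_groups, dtype=weight.dtype)
    end = 0
    for group_index, count in enumerate(group_sizes):
        start, end = end, end + count
        # One abs-max scale per row for this group's columns
        scales[:, group_index] = observed[:, start:end].abs().amax(dim=1) / qmax
    return scales
```

With 10 columns and `group_size=4`, column order yields groups of 4, 4, and 2 columns. Slicing past the last column simply truncates the final group.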

post_calculate_qparams()

Hook for subclasses to run observer-specific logic after calculate_qparams.

record_observed_tokens(batch_tensor)

Counts the number of tokens observed during the forward passes. The count is aggregated in the _num_observed_tokens attribute of the class.

Note: batch_tensor is expected to have two dimensions, (batch_size * sequence_length, num_features), which is the general shape expected by the forward pass of the expert layers in a MoE model. If the input tensor does not have two dimensions, a debug message is logged and the tensor is not counted: _num_observed_tokens is left unchanged, so it stays None if no two-dimensional batch has been recorded since the last reset.
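The counting rule can be sketched in plain Python. This is a minimal illustration, assuming a hypothetical `TokenCounter` class that works on shape tuples instead of tensors: only 2-D inputs are counted, and the count stays None until the first valid batch arrives.

```python
from typing import Optional, Tuple


class TokenCounter:
    """Hypothetical stand-in for the observer's token bookkeeping."""

    def __init__(self) -> None:
        self.num_observed_tokens: Optional[int] = None

    def record(self, shape: Tuple[int, ...]) -> None:
        # Inputs must already be flattened to (batch_size * sequence_length, num_features)
        if len(shape) != 2:
            return  # not counted; the real method only logs a debug message
        if self.num_observed_tokens is None:
            self.num_observed_tokens = 0  # initialize the count on first use
        observed_tokens, _ = shape
        self.num_observed_tokens += observed_tokens


counter = TokenCounter()
counter.record((4, 16, 64))   # 3-D [batch, seq, hidden]: ignored
counter.record((4 * 16, 64))  # the same batch flattened: 64 tokens
counter.record((32, 64))      # 32 more tokens
print(counter.num_observed_tokens)  # 96
```

A 3-D activation therefore has to be reshaped to `(batch * seq, hidden)` before it contributes to the count.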

Source code in src/llmcompressor/observers/base.py
def record_observed_tokens(self, batch_tensor: Tensor):
    """
    Counts the number of tokens observed during the
    forward passes. The count is aggregated in the
    _num_observed_tokens attribute of the class.

    Note: The batch_tensor is expected to have two dimensions
        (batch_size * sequence_length, num_features). This is the
        general shape expected by the forward pass of the expert
        layers in a MoE model. If the input tensor does not have
        two dimensions, it is not counted and the
        _num_observed_tokens attribute is left unchanged.
    """
    if not isinstance(batch_tensor, Tensor):
        raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")

    if batch_tensor.ndim != 2:
        logger.debug(
            "The input tensor is expected to have two dimensions "
            "(batch_size * sequence_length, num_features). "
            f"The input tensor has {batch_tensor.ndim} dimensions."
        )
        return

    if self._num_observed_tokens is None:
        # initialize the count
        self._num_observed_tokens = 0

    # batch_tensor (batch_size * sequence_length, num_features)
    # observed_tokens (batch_size * sequence_length)
    observed_tokens, _ = batch_tensor.shape
    self._num_observed_tokens += observed_tokens

reset()

Reset the state of the observer

Source code in src/llmcompressor/observers/base.py
def reset(self):
    """
    Reset the state of the observer
    """
    self._num_observed_tokens = None
    self._scale = None
    self._zero_point = None

QuantizationMixin

Bases: HooksMixin

Mixin which enables a Modifier to act as a quantization config, attaching observers, calibration hooks, and compression wrappers to modules

Lifecycle:

- on_initialize: QuantizationMixin.initialize_quantization
    - Attach schemes to modules
    - Attach observers to modules
    - Disable quantization until calibration starts
- on_start: QuantizationMixin.start_calibration
    - Attach calibration hooks
    - Apply calibration status
    - Enable quantization during calibration
- on_end: QuantizationMixin.end_calibration
    - Remove calibration hooks
    - Apply freeze status
    - Keep quantization enabled for future steps
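
The three lifecycle phases can be read as state transitions on each targeted module. The following is a toy model in plain Python, not the llmcompressor API: the `ToyModule` class, its fields, and the status strings are hypothetical and only track whether quantization is enabled, whether observers and calibration hooks are attached, and the current status.

```python
from dataclasses import dataclass


@dataclass
class ToyModule:
    status: str = "none"
    quantization_enabled: bool = True
    has_observers: bool = False
    num_hooks: int = 0


def initialize_quantization(m: ToyModule) -> None:
    m.status = "initialized"        # scheme attached
    m.has_observers = True          # observers attached
    m.quantization_enabled = False  # disabled until calibration starts


def start_calibration(m: ToyModule) -> None:
    m.num_hooks += 1                # calibration hook registered
    m.status = "calibration"
    m.quantization_enabled = True   # quantize while calibrating


def end_calibration(m: ToyModule) -> None:
    m.num_hooks = 0                 # calibration hooks removed
    m.status = "frozen"
    m.has_observers = False         # freezing removes observers
    m.quantization_enabled = True   # stays enabled afterwards


module = ToyModule()
for phase in (initialize_quantization, start_calibration, end_calibration):
    phase(module)
    print(phase.__name__, module)
```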

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config_groups | `Optional[Dict[str, QuantizationScheme]]` | Dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. | `None` |
| targets | `Union[str, List[str]]` | List of layer names to quantize if a scheme is provided. A single string is wrapped in a list. | `["Linear"]` |
| ignore | `List[str]` | Optional list of module class names or submodule names to not quantize even if they match a target in config_groups. | `[]` |
| scheme | `Optional[Union[str, Dict[str, Any]]]` | A single quantization scheme to apply to the model. This is a dictionary that supports all keys from QuantizationScheme except targets, which is set from the modifier-level targets parameter. It can also be a preset scheme name, or a dictionary of the form `preset_scheme_name: targets`, for example `W8A8: ['Linear']` for 8-bit weights and activations. | `None` |
| kv_cache_scheme | `Optional[QuantizationArgs]` | Optional QuantizationArgs that specify the quantization of the KV cache. If None, the KV cache is not quantized. When KV cache quantization is applied to a transformers AutoModelForCausalLM, kv_cache_scheme is converted into a QuantizationScheme that targets the `k_proj` and `v_proj` modules (whose outputs are the keys and values that may be cached) and quantizes their outputs, so that keys and values are compressed before they are stored in the cache. This assumes the model contains modules with `k_proj` and `v_proj` in their names; if it does not and kv_cache_scheme is not None, KV cache quantization will fail. | `None` |

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
class QuantizationMixin(HooksMixin):
    """
    Mixin which enables a Modifier to act as a quantization config, attaching
    observers, calibration hooks, and compression wrappers to modules

    Lifecycle:
        - on_initialize: QuantizationMixin.initialize_quantization
            - Attach schemes to modules
            - Attach observers to modules
            - Disable quantization until calibration starts/finishes
        - on_start: QuantizationMixin.start_calibration
            - Attach calibration hooks
            - Apply calibration status
            - Enable quantization during calibration
        - on_end: QuantizationMixin.end_calibration
            - Remove calibration hooks
            - Apply freeze status
            - Keep quantization enabled for future steps

    :param config_groups: dictionary specifying quantization schemes to apply to target
        modules. Modules not matching a scheme target will NOT be quantized.
    :param targets: list of layer names to quantize if a scheme is provided. Defaults
        to Linear layers
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param scheme: a single quantization scheme to apply to the model. This is a
        dictionary that supports all keys from QuantizationScheme except targets, which
        will be set to the targets parameter set at the modifier level. Can also be set
        to a dictionary of the format `preset_scheme_name: targets` for example:
        `W8A8: ['Linear']` for weight and activation 8-bit.
    :param kv_cache_scheme: optional QuantizationArgs that specify the
        quantization of the kv cache. If None, kv cache is not quantized.
        When applying kv cache quantization to transformer AutoModelForCausalLM,
        the kv_cache_scheme gets converted into a QuantizationScheme that:
            - targets the `k_proj` and `v_proj` modules of the model. The outputs
              of those modules are the keys and values that might be cached
            - quantizes the outputs of the aforementioned layers, so that
              keys and values are compressed before storing them in the cache
        There is an explicit assumption that the model contains modules with
        `k_proj` and `v_proj` in their names. If this is not the case
        and kv_cache_scheme != None, the quantization of kv cache will fail
    """

    config_groups: Optional[Dict[str, QuantizationScheme]] = None
    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
    ignore: List[str] = Field(default_factory=list)
    scheme: Optional[Union[str, Dict[str, Any]]] = None
    kv_cache_scheme: Optional[QuantizationArgs] = None

    _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)

    @field_validator("targets", mode="before")
    def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
        if isinstance(value, str):
            return [value]

        return value

    @field_validator("scheme", mode="before")
    def validate_scheme(
        cls, value: Optional[Union[str, Dict[str, Any]]]
    ) -> Optional[Union[str, Dict[str, Any]]]:
        if isinstance(value, str) and not is_preset_scheme(value):
            raise ValueError(
                "`scheme` must either be a preset scheme name or a dictionary "
                "of preset scheme names"
            )

        if isinstance(value, dict):
            for scheme_name in value.keys():
                cls.validate_scheme(scheme_name)

            for key, target in value.items():
                value[key] = cls.validate_targets(target)

        return value

    def initialize_quantization(self, model: torch.nn.Module):
        """
        Attach quantization schemes and observers to modules in the model according to
        the quantization config specified on this modifier

        :param model: model to attach schemes and observers to
        """
        reset_quantization_status(model)  # reset any previously applied qconfigs

        # apply scheme and status to model
        config = self.resolve_quantization_config()
        apply_quantization_config(model, config)

        # apply observers, disable quantization until calibration
        model.apply(self._initialize_observers)
        model.apply(disable_quantization)

    def start_calibration(self, model: torch.nn.Module):
        """
        Register activation calibration hooks (including kv_cache quantization) and
        enable quantization as we calibrate

        :param model: model to prepare for calibration
        """
        self._calibration_hooks = self._initialize_hooks(model)
        model.apply(apply_calibration_status)
        model.apply(enable_quantization)  # quantize at the same time as calibrate

    def end_calibration(self, model: torch.nn.Module):
        """
        Remove calibration hooks and set the model status to frozen. Keep quantization
        enabled for future operations

        :param model: model to end calibration for
        """
        self.remove_hooks(self._calibration_hooks)
        model.apply(freeze_module_quantization)  # remove observers
        model.apply(enable_quantization)  # keep quantization enabled

    def has_config(self) -> bool:
        """
        Determine if the user has specified a quantization config on this modifier
        """
        return not (
            self.config_groups is None
            and self.targets == ["Linear"]
            and self.ignore == []
            and self.scheme is None
            and self.kv_cache_scheme is None
        )

    def resolve_quantization_config(self) -> QuantizationConfig:
        """
        Returns the quantization config specified by this modifier
        """
        scheme = self.scheme
        targets = self.targets
        config_groups = self.config_groups
        kv_cache_scheme = self.kv_cache_scheme
        ignore = self.ignore

        if scheme is not None and config_groups is not None:
            raise ValueError("Please specify either `scheme` or `config_groups`")

        if scheme is not None:
            # `scheme` and `config_groups` are mutually exclusive (checked above)

            if isinstance(scheme, str) and is_preset_scheme(scheme):
                # attach targets to scheme
                scheme = {scheme: targets}

            config_groups = {}
            for idx, (key, scheme_targets) in enumerate(scheme.items()):
                # use a separate name so the `scheme` dict is not overwritten
                # while later iterations still read from it
                if is_preset_scheme(key):
                    group_scheme = preset_name_to_scheme(key, scheme_targets)
                else:
                    group_scheme = QuantizationScheme.model_validate(
                        {"targets": scheme_targets, **scheme}
                    )

                group_name = f"group_{idx}"
                config_groups[group_name] = group_scheme

        if config_groups is None or len(config_groups) == 0:
            default_quant_scheme = QuantizationScheme(targets=targets)
            config_groups = {"group_0": default_quant_scheme}

        return QuantizationConfig(
            config_groups=config_groups,
            kv_cache_scheme=kv_cache_scheme,
            quantization_status=QuantizationStatus.INITIALIZED,
            ignore=ignore,
        )

    def _initialize_observers(self, module: torch.nn.Module):
        if not hasattr(module, "quantization_scheme"):
            return

        scheme: QuantizationScheme = module.quantization_scheme
        input = scheme.input_activations and not scheme.input_activations.dynamic
        weight = scheme.weights is not None
        output = scheme.output_activations and not scheme.output_activations.dynamic
        is_attention = is_attention_module(module)

        # input activations
        if input:
            initialize_observer(module, base_name="input")

        # weight observers (used by `update_weight_zp_scale` or child modifier)
        if weight:
            initialize_observer(module, base_name="weight")

        # kv_cache activations. Within `apply_quantization_config`, the config is
        # modified to use attention output quantization if a kv_cache_scheme exists
        if is_attention and output:
            initialize_quantized_kv_cache(module)

        # output activations
        elif output:
            initialize_observer(module, base_name="output")

    def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
        hooks = set()
        for module in model.modules():
            if not hasattr(module, "quantization_scheme"):
                continue

            scheme: QuantizationScheme = module.quantization_scheme
            input = scheme.input_activations and not scheme.input_activations.dynamic
            output = scheme.output_activations and not scheme.output_activations.dynamic
            is_attention = is_attention_module(module)

            # input activations
            if input:
                hooks.add(
                    self.register_hook(module, calibrate_input_hook, "forward_pre")
                )

            # kv_cache activations. Within `apply_quantization_config`, the config is
            # modified to use attention output quantization if a kv_cache_scheme exists
            if is_attention and output:
                hooks.add(
                    self.register_hook(
                        module,
                        calibrate_kv_cache_input_hook,
                        "forward_pre",
                        with_kwargs=True,
                    )
                )
                hooks.add(
                    self.register_hook(
                        module, calibrate_kv_cache_output_hook, "forward"
                    )
                )

            # output activations
            elif output:
                hooks.add(self.register_hook(module, calibrate_output_hook, "forward"))

        return hooks

end_calibration(model)

Remove calibration hooks and set the model status to frozen. Keep quantization enabled for future operations

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | `Module` | model to end calibration for | required |

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def end_calibration(self, model: torch.nn.Module):
    """
    Remove calibration hooks and set the model status to frozen. Keep quantization
    enabled for future operations

    :param model: model to end calibration for
    """
    self.remove_hooks(self._calibration_hooks)
    model.apply(freeze_module_quantization)  # remove observers
    model.apply(enable_quantization)  # keep quantization enabled

has_config()

Determine if the user has specified a quantization config on this modifier
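`has_config` returns False only when every quantization field still holds its default value. A minimal sketch, assuming a hypothetical `QuantFields` dataclass that mirrors the five fields declared on `QuantizationMixin` and copies the same check:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class QuantFields:
    config_groups: Optional[Dict[str, Any]] = None
    targets: List[str] = field(default_factory=lambda: ["Linear"])
    ignore: List[str] = field(default_factory=list)
    scheme: Optional[Any] = None
    kv_cache_scheme: Optional[Any] = None

    def has_config(self) -> bool:
        # Any deviation from the defaults counts as a user-specified config
        return not (
            self.config_groups is None
            and self.targets == ["Linear"]
            and self.ignore == []
            and self.scheme is None
            and self.kv_cache_scheme is None
        )


print(QuantFields().has_config())                    # False
print(QuantFields(scheme="W8A8").has_config())       # True
print(QuantFields(ignore=["lm_head"]).has_config())  # True
```

In the last case only `ignore` is set, so `resolve_quantization_config` would fall back to a single default `QuantizationScheme(targets=targets)` group.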

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def has_config(self) -> bool:
    """
    Determine if the user has specified a quantization config on this modifier
    """
    return not (
        self.config_groups is None
        and self.targets == ["Linear"]
        and self.ignore == []
        and self.scheme is None
        and self.kv_cache_scheme is None
    )

initialize_quantization(model)

Attach quantization schemes and observers to modules in the model according to the quantization config specified on this modifier

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | `Module` | model to attach schemes and observers to | required |

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def initialize_quantization(self, model: torch.nn.Module):
    """
    Attach quantization schemes and observers to modules in the model according to
    the quantization config specified on this modifier

    :param model: model to attach schemes and observers to
    """
    reset_quantization_status(model)  # reset any previously applied qconfigs

    # apply scheme and status to model
    config = self.resolve_quantization_config()
    apply_quantization_config(model, config)

    # apply observers, disable quantization until calibration
    model.apply(self._initialize_observers)
    model.apply(disable_quantization)

resolve_quantization_config()

Returns the quantization config specified by this modifier
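The resolution rules can be traced with plain dictionaries. This sketch assumes a hypothetical `PRESETS` table standing in for `is_preset_scheme`/`preset_name_to_scheme` and represents schemes as dicts. It follows the same control flow (mutual exclusion of `scheme` and `config_groups`, expansion of a bare preset name using the modifier-level targets, one `group_{idx}` per preset, default fallback), and keeps each resolved group in its own variable so the input dictionary is never overwritten mid-loop.

```python
from typing import Any, Dict, List, Optional, Union

# Hypothetical stand-in for the preset scheme registry
PRESETS = {
    "W8A8": {"weights": {"num_bits": 8}, "input_activations": {"num_bits": 8}},
    "W4A16": {"weights": {"num_bits": 4}, "input_activations": None},
}


def resolve_config_groups(
    scheme: Optional[Union[str, Dict[str, List[str]]]],
    config_groups: Optional[Dict[str, Any]],
    targets: List[str],
) -> Dict[str, Any]:
    if scheme is not None and config_groups is not None:
        raise ValueError("Please specify either `scheme` or `config_groups`")

    if scheme is not None:
        if isinstance(scheme, str):
            scheme = {scheme: targets}  # a bare preset uses modifier-level targets
        config_groups = {}
        for idx, (name, scheme_targets) in enumerate(scheme.items()):
            group = dict(PRESETS[name], targets=scheme_targets)
            config_groups[f"group_{idx}"] = group

    if not config_groups:
        config_groups = {"group_0": {"targets": targets}}  # default scheme
    return config_groups


print(resolve_config_groups({"W8A8": ["q_proj"], "W4A16": ["mlp"]}, None, ["Linear"]))
```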

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def resolve_quantization_config(self) -> QuantizationConfig:
    """
    Returns the quantization config specified by this modifier
    """
    scheme = self.scheme
    targets = self.targets
    config_groups = self.config_groups
    kv_cache_scheme = self.kv_cache_scheme
    ignore = self.ignore

    if scheme is not None and config_groups is not None:
        raise ValueError("Please specify either `scheme` or `config_groups`")

    if scheme is not None:
        # `scheme` and `config_groups` are mutually exclusive (checked above)

        if isinstance(scheme, str) and is_preset_scheme(scheme):
            # attach targets to scheme
            scheme = {scheme: targets}

        config_groups = {}
        for idx, (key, scheme_targets) in enumerate(scheme.items()):
            # use a separate name so the `scheme` dict is not overwritten
            # while later iterations still read from it
            if is_preset_scheme(key):
                group_scheme = preset_name_to_scheme(key, scheme_targets)
            else:
                group_scheme = QuantizationScheme.model_validate(
                    {"targets": scheme_targets, **scheme}
                )

            group_name = f"group_{idx}"
            config_groups[group_name] = group_scheme

    if config_groups is None or len(config_groups) == 0:
        default_quant_scheme = QuantizationScheme(targets=targets)
        config_groups = {"group_0": default_quant_scheme}

    return QuantizationConfig(
        config_groups=config_groups,
        kv_cache_scheme=kv_cache_scheme,
        quantization_status=QuantizationStatus.INITIALIZED,
        ignore=ignore,
    )

start_calibration(model)

Register activation calibration hooks (including kv_cache quantization) and enable quantization as we calibrate
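During calibration, each registered hook passes the tensors seen in the forward pass to an observer, which updates the quantization parameters. As an illustration only, the sketch below assumes a running min-max observer with symmetric 8-bit scaling and a hypothetical `calibrate_input` function playing the role of a forward-pre hook; the observers used by llmcompressor are loaded from a registry by name and may track different statistics.

```python
from typing import Iterable, Optional


class RunningMinMaxObserver:
    """Hypothetical observer: tracks the running min/max of every value seen."""

    def __init__(self, num_bits: int = 8) -> None:
        self.qmax = 2 ** (num_bits - 1) - 1  # 127 for 8 bits
        self.min_val: Optional[float] = None
        self.max_val: Optional[float] = None

    def update(self, values: Iterable[float]) -> float:
        values = list(values)
        lo, hi = min(values), max(values)
        self.min_val = lo if self.min_val is None else min(self.min_val, lo)
        self.max_val = hi if self.max_val is None else max(self.max_val, hi)
        return self.scale()

    def scale(self) -> float:
        # Symmetric scale: the largest magnitude seen so far maps to qmax
        absmax = max(abs(self.min_val), abs(self.max_val))
        return absmax / self.qmax if absmax > 0 else 1.0


def calibrate_input(observer: RunningMinMaxObserver, batch: Iterable[float]) -> float:
    """Stand-in for a forward-pre hook: observe the input, return the new scale."""
    return observer.update(batch)


obs = RunningMinMaxObserver()
calibrate_input(obs, [0.5, -1.27])
print(round(calibrate_input(obs, [2.54, 0.0]), 6))  # 0.02
```

Because the statistics accumulate across batches, the scale only grows as more calibration data is seen.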

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | `Module` | model to prepare for calibration | required |

Source code in src/llmcompressor/modifiers/quantization/quantization/mixin.py
def start_calibration(self, model: torch.nn.Module):
    """
    Register activation calibration hooks (including kv_cache quantization) and
    enable quantization as we calibrate

    :param model: model to prepare for calibration
    """
    self._calibration_hooks = self._initialize_hooks(model)
    model.apply(apply_calibration_status)
    model.apply(enable_quantization)  # quantize at the same time as calibrate

QuantizationModifier

Bases: Modifier, QuantizationMixin

Enables post-training quantization (PTQ) and quantization-aware training (QAT) for a given module or its submodules. After calibration (PTQ) or the start epoch (QAT), the forward pass of the specified module(s) will emulate quantized execution, and the modifier will remain enabled until training is completed.
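Emulating quantized execution means quantizing and immediately dequantizing (fake quantization): the forward pass stays in floating point but sees the rounding and clamping error of the integer format. A minimal sketch for asymmetric unsigned integer quantization, assuming round-to-nearest (Python's `round` breaks ties to even) and clamping to the integer range; the `fake_quantize` helper is hypothetical and does not reproduce compressed-tensors' `quantize`/`dequantize` exactly.

```python
from typing import List


def fake_quantize(x: List[float], scale: float, zero_point: int,
                  num_bits: int = 8) -> List[float]:
    """Quantize to unsigned num_bits integers, then dequantize back to float."""
    qmin, qmax = 0, 2 ** num_bits - 1
    out = []
    for v in x:
        q = round(v / scale) + zero_point     # quantize (round to nearest)
        q = max(qmin, min(qmax, q))           # clamp to the integer range
        out.append((q - zero_point) * scale)  # dequantize
    return out


# 0.26 rounds to about 0.3; 30.0 exceeds the range and clamps to about 25.5
print(fake_quantize([0.0, 0.26, 30.0], scale=0.1, zero_point=0))
```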

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| config_groups | `Optional[Dict[str, QuantizationScheme]]` | Dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. | `None` |
| targets | `Union[str, List[str]]` | List of layer names to quantize if a scheme is provided. A single string is wrapped in a list. | `["Linear"]` |
| ignore | `List[str]` | Optional list of module class names or submodule names to not quantize even if they match a target in config_groups. | `[]` |
| scheme | `Optional[Union[str, Dict[str, Any]]]` | A single quantization scheme to apply to the model. This is a dictionary that supports all keys from QuantizationScheme except targets, which is set from the modifier-level targets parameter. It can also be a preset scheme name, or a dictionary of the form `preset_scheme_name: targets`, for example `W8A8: ['Linear']` for 8-bit weights and activations. | `None` |
| kv_cache_scheme | `Optional[QuantizationArgs]` | Optional QuantizationArgs that specify the quantization of the KV cache. If None, the KV cache is not quantized. When KV cache quantization is applied to a transformers AutoModelForCausalLM, kv_cache_scheme is converted into a QuantizationScheme that targets the `k_proj` and `v_proj` modules (whose outputs are the keys and values that may be cached) and quantizes their outputs, so that keys and values are compressed before they are stored in the cache. This assumes the model contains modules with `k_proj` and `v_proj` in their names; if it does not and kv_cache_scheme is not None, KV cache quantization will fail. | `None` |

Source code in src/llmcompressor/modifiers/quantization/quantization/base.py
class QuantizationModifier(Modifier, QuantizationMixin):
    """
    Enables post-training quantization (PTQ) and quantization-aware training (QAT)
    for a given module or its submodules. After calibration (PTQ) or the start epoch
    (QAT), the forward pass of the specified module(s) will emulate quantized
    execution, and the modifier will remain enabled until training is completed.

    :param config_groups: dictionary specifying quantization schemes to apply to target
        modules. Modules not matching a scheme target will NOT be quantized.
    :param targets: list of layer names to quantize if a scheme is provided. Defaults
        to Linear layers
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param scheme: a single quantization scheme to apply to the model. This is a
        dictionary that supports all keys from QuantizationScheme except targets, which
        will be set to the targets parameter set at the modifier level. Can also be set
        to a dictionary of the format `preset_scheme_name: targets` for example:
        `W8A8: ['Linear']` for weight and activation 8-bit.
    :param kv_cache_scheme: optional QuantizationArgs that specify the
        quantization of the kv cache. If None, kv cache is not quantized.
        When applying kv cache quantization to transformer AutoModelForCausalLM,
        the kv_cache_scheme gets converted into a QuantizationScheme that:
            - targets the `k_proj` and `v_proj` modules of the model. The outputs
              of those modules are the keys and values that might be cached
            - quantizes the outputs of the aforementioned layers, so that
              keys and values are compressed before storing them in the cache
        There is an explicit assumption that the model contains modules with
        `k_proj` and `v_proj` in their names. If this is not the case
        and kv_cache_scheme != None, the quantization of kv cache will fail
    """

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Prepare to calibrate activations and weights

        According to the quantization config, a quantization scheme is attached to each
        targeted module. The module's forward call is also overwritten to perform
        quantization to inputs, weights, and outputs.

        Then, according to the module's quantization scheme, observers and calibration
        hooks are added. These hooks are disabled until the modifier starts.
        """
        if not QuantizationMixin.has_config(self):
            raise ValueError(
                "QuantizationModifier requires that quantization fields be specified"
            )
        QuantizationMixin.initialize_quantization(self, state.model)

        return True

    def on_start(self, state: State, event: Event, **kwargs):
        """
        Begin calibrating activations and weights. Calibrate weights only once on start
        """
        self.started_ = True
        QuantizationMixin.start_calibration(self, state.model)

        modules = list(state.model.modules())
        for module in tqdm.tqdm(modules, desc="Calibrating weights"):
            update_weight_zp_scale(module)

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        if event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def on_end(self, state: State, event: Event, **kwargs):
        """
        Finish calibrating by removing observers and calibration hooks
        """
        self.ended_ = True
        QuantizationMixin.end_calibration(
            self, state.model
        )  # keep quantization enabled

    def on_finalize(self, state: State, **kwargs) -> bool:
        if not self.ended_:
            self.on_end(state, None)

        return True

on_end(state, event, **kwargs)

Finish calibrating by removing observers and calibration hooks

Source code in src/llmcompressor/modifiers/quantization/quantization/base.py
def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by removing observers and calibration hooks
    """
    self.ended_ = True
    QuantizationMixin.end_calibration(
        self, state.model
    )  # keep quantization enabled

on_initialize(state, **kwargs)

Prepare to calibrate activations and weights

According to the quantization config, a quantization scheme is attached to each targeted module. The module's forward call is also overwritten to perform quantization to inputs, weights, and outputs.

Then, according to the module's quantization scheme, observers and calibration hooks are added. These hooks are disabled until the modifier starts.

Source code in src/llmcompressor/modifiers/quantization/quantization/base.py
def on_initialize(self, state: State, **kwargs) -> bool:
    """
    Prepare to calibrate activations and weights

    According to the quantization config, a quantization scheme is attached to each
    targeted module. The module's forward call is also overwritten to perform
    quantization to inputs, weights, and outputs.

    Then, according to the module's quantization scheme, observers and calibration
    hooks are added. These hooks are disabled until the modifier starts.
    """
    if not QuantizationMixin.has_config(self):
        raise ValueError(
            "QuantizationModifier requires that quantization fields be specified"
        )
    QuantizationMixin.initialize_quantization(self, state.model)

    return True

on_start(state, event, **kwargs)

Begin calibrating activations and weights. Calibrate weights only once on start
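Weights are calibrated only once because `on_event` calls `on_start` only while the `started_` flag is unset, and `on_end` only while `ended_` is unset (`on_finalize` also calls `on_end` if it has not run yet). A minimal sketch of that gating, with a hypothetical `GatedModifier` class and plain strings standing in for `EventType.CALIBRATION_EPOCH_START` and `EventType.CALIBRATION_EPOCH_END`:

```python
from typing import List


class GatedModifier:
    """Hypothetical model of QuantizationModifier's start/end gating."""

    def __init__(self) -> None:
        self.started_ = False
        self.ended_ = False
        self.log: List[str] = []

    def on_start(self) -> None:
        self.started_ = True
        self.log.append("calibrate weights")  # runs once per modifier

    def on_end(self) -> None:
        self.ended_ = True
        self.log.append("freeze")

    def on_event(self, event_type: str) -> None:
        if event_type == "CALIBRATION_EPOCH_START" and not self.started_:
            self.on_start()
        if event_type == "CALIBRATION_EPOCH_END" and not self.ended_:
            self.on_end()


mod = GatedModifier()
for event in ["CALIBRATION_EPOCH_START", "CALIBRATION_EPOCH_START",
              "CALIBRATION_EPOCH_END", "CALIBRATION_EPOCH_END"]:
    mod.on_event(event)
print(mod.log)  # ['calibrate weights', 'freeze']
```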

Source code in src/llmcompressor/modifiers/quantization/quantization/base.py
def on_start(self, state: State, event: Event, **kwargs):
    """
    Begin calibrating activations and weights. Calibrate weights only once on start
    """
    self.started_ = True
    QuantizationMixin.start_calibration(self, state.model)

    modules = list(state.model.modules())
    for module in tqdm.tqdm(modules, desc="Calibrating weights"):
        update_weight_zp_scale(module)

QuantizedKVParameterCache

Bases: DynamicCache

Quantized KV cache used in the forward call, based on Hugging Face's DynamicCache. The quantization strategy (tensor, group, or channel) is taken from the strategy of the QuantizationArgs. The class is a singleton, so the same cache is reused in every forward call of self_attn. Each forward call invokes .update(), which calls ._quantize() and ._dequantize(). Each tensor has shape [batch_size, num_heads, seq_len - residual_length, head_dim].

Triggered by adding kv_cache_scheme in the recipe.
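The singleton behavior comes from overriding `__new__` together with an `_initialized` guard in `__init__`. A stripped-down sketch of the same pattern; the `SingletonCache` name and its `args` field are hypothetical, and `reset` is written as a classmethod here, whereas the cache defines it as an instance method that assigns the same class attributes.

```python
class SingletonCache:
    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        # Always hand back the one shared instance
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, args: str) -> None:
        # __init__ still runs on every construction, so guard re-initialization
        if not self._initialized:
            self.args = args
            self._initialized = True

    @classmethod
    def reset(cls) -> None:
        # Forget the instance so the next construction builds a fresh one
        cls._instance = None
        cls._initialized = False


a = SingletonCache("int8")
b = SingletonCache("fp8")
print(a is b, b.args)  # True int8
SingletonCache.reset()
c = SingletonCache("fp8")
print(c is a, c.args)  # False fp8
```

Note that the second constructor argument is silently ignored until `reset` is called, which is why calibration code must reset the cache before reusing it with different arguments.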

Example:

```python3
recipe = '''
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            kv_cache_scheme:
                num_bits: 8
                type: float
                strategy: tensor
                dynamic: false
                symmetric: true
'''
```
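The cache keeps per-layer observers, scales, and zero points in lists indexed by `layer_idx`. Because some layers may be ignored in the configuration, these lists are padded so the list index still matches the layer index. The helper `_pad_and_append_at_idx_` is not shown in this section; the function below is a hypothetical reimplementation that assumes padding with None and overwriting an existing slot.

```python
from typing import Any, List


def pad_and_set_at_idx(lst: List[Any], idx: int, value: Any) -> List[Any]:
    """Hypothetical helper: make lst[idx] == value, padding with None if needed."""
    if len(lst) <= idx:
        lst.extend([None] * (idx + 1 - len(lst)))
    lst[idx] = value
    return lst


k_scales: List[Any] = []
pad_and_set_at_idx(k_scales, 0, 0.5)
pad_and_set_at_idx(k_scales, 2, 0.25)  # layer 1 ignored in the config
print(k_scales)  # [0.5, None, 0.25]
pad_and_set_at_idx(k_scales, 0, 0.4)   # a later forward pass overwrites layer 0
print(k_scales)  # [0.4, None, 0.25]
```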

Source code in src/llmcompressor/modifiers/quantization/cache.py
class QuantizedKVParameterCache(DynamicCache):
    """
    Quantized KV cache used in the forward call, based on HF's DynamicCache.
    The quantization strategy (tensor, group, or channel) is taken from the
    QuantizationArgs strategy. The class is a singleton, so the same cache is
    reused in every forward call of self_attn. Each forward call invokes
    .update(), which in turn calls ._quantize() and ._dequantize().
    The size of each tensor is
    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.

    Triggered by adding kv_cache_scheme in the recipe.

    Example:

    ```python3
    recipe = '''
    quant_stage:
        quant_modifiers:
            QuantizationModifier:
                kv_cache_scheme:
                    num_bits: 8
                    type: float
                    strategy: tensor
                    dynamic: false
                    symmetric: true
    '''
    ```
    """

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        """Singleton"""
        if cls._instance is None:
            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, quantization_args: QuantizationArgs):
        if not self._initialized:
            super().__init__()

            self.quantization_args = quantization_args

            self.k_observers: List[Observer] = []
            self.v_observers: List[Observer] = []

            # each index corresponds to layer_idx of the attention layer
            self.k_scales: List[Tensor] = []
            self.v_scales: List[Tensor] = []

            self.k_zps: List[Tensor] = []
            self.v_zps: List[Tensor] = []

            self._initialized = True

    def update(
        self,
        key_states: Tensor,
        value_states: Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Compute the k_scale and v_scale and return the fake-quantized
        (quantized, then dequantized) key_states and value_states
        """

        if len(self.k_observers) <= layer_idx:
            k_observer_name = self.quantization_args.observer
            k_observer = Observer.load_from_registry(
                k_observer_name, quantization_args=self.quantization_args
            )
            v_observer_name = self.quantization_args.observer
            v_observer = Observer.load_from_registry(
                v_observer_name, quantization_args=self.quantization_args
            )

            # NOTE: User may ignore some layers in configuration,
            # meaning len(self.k_observers) <= layer_idx-1
            # Must account for that case by padding list so that
            # index of lists corresponds to layer_idx
            _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
            _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

        q_key_states = self._quantize(
            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
        )
        q_value_states = self._quantize(
            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
        )

        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
        qdq_value_states = self._dequantize(
            q_value_states, KVCacheScaleType.VALUE, layer_idx
        )

        keys_to_return, values_to_return = qdq_key_states, qdq_value_states

        return keys_to_return, values_to_return

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states.
        A layer index can be optionally passed.
        """
        if len(self.key_cache) <= layer_idx:
            return 0
        # since we cannot get the seq_length of each layer directly and
        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
        # this is a hack to get the actual seq_length for the given layer_idx
        # this part of code otherwise fails when used to
        # verify attn_weight shape in some models
        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

    def reset_states(self):
        """reset the kv states (used in calibration)"""
        self.key_cache: List[Tensor] = []
        self.value_cache: List[Tensor] = []
        # Used in `generate` to keep tally of how many tokens the cache has seen
        self._seen_tokens = 0
        self._quantized_key_cache: List[Tensor] = []
        self._quantized_value_cache: List[Tensor] = []

    def reset(self):
        """
        Reset the instantiation, create new instance on init
        """
        QuantizedKVParameterCache._instance = None
        QuantizedKVParameterCache._initialized = False

    def _quantize(self, tensor, kv_type, layer_idx):
        """Quantizes a key/value using a defined quantization method."""
        from compressed_tensors.quantization.lifecycle.forward import quantize

        if kv_type == KVCacheScaleType.KEY:  # key type
            observer = self.k_observers[layer_idx]
            scales = self.k_scales
            zps = self.k_zps
        else:
            assert kv_type == KVCacheScaleType.VALUE
            observer = self.v_observers[layer_idx]
            scales = self.v_scales
            zps = self.v_zps

        scale, zp = observer(tensor)
        _pad_and_append_at_idx_(scales, layer_idx, scale)
        _pad_and_append_at_idx_(zps, layer_idx, zp)

        q_tensor = quantize(
            x=tensor,
            scale=scale,
            zero_point=zp,
            args=self.quantization_args,
        )
        return q_tensor

    def _dequantize(self, qtensor, kv_type, layer_idx):
        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
        from compressed_tensors.quantization.lifecycle.forward import dequantize

        if kv_type == KVCacheScaleType.KEY:
            scale = self.k_scales[layer_idx]
            zp = self.k_zps[layer_idx]
        else:
            assert kv_type == KVCacheScaleType.VALUE
            scale = self.v_scales[layer_idx]
            zp = self.v_zps[layer_idx]

        qdq_tensor = dequantize(
            x_q=qtensor,
            scale=scale,
            zero_point=zp,
            args=self.quantization_args,
        )
        return qdq_tensor
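Taken together, `_quantize` and `_dequantize` form a fake-quantization round trip. The layer's observer produces a scale and zero point, the tensor is quantized, and it is immediately dequantized. The values returned by `update` therefore stay in floating point but carry the rounding error of the chosen scheme. The sketch below shows that round trip in pure Python. It assumes a per-tensor, asymmetric 8-bit scheme with a simple min-max observer. The helper names are hypothetical, and this is not the `compressed_tensors` implementation, which also supports other strategies, bit widths, and dtypes.

```python
def minmax_qparams(values, num_bits=8):
    """Hypothetical min-max observer: (scale, zero_point) covering [min, max]."""
    qmin, qmax = 0, 2**num_bits - 1
    lo = min(min(values), 0.0)  # include 0.0 so it is exactly representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = round(-lo / scale) + qmin
    return scale, zero_point


def to_codes(values, scale, zero_point, num_bits=8):
    """Quantize: map floats to clamped integer codes."""
    qmin, qmax = 0, 2**num_bits - 1
    return [min(max(round(v / scale) + zero_point, qmin), qmax) for v in values]


def from_codes(codes, scale, zero_point):
    """Dequantize: map integer codes back to floats."""
    return [(q - zero_point) * scale for q in codes]


def fake_quantize(values, num_bits=8):
    """Quantize then dequantize: floats in, floats out, with rounding error."""
    scale, zero_point = minmax_qparams(values, num_bits)
    codes = to_codes(values, scale, zero_point, num_bits)
    return from_codes(codes, scale, zero_point), scale


keys = [-1.0, -0.5, 0.0, 0.25, 1.0]
fq_keys, scale = fake_quantize(keys)
max_err = max(abs(a - b) for a, b in zip(keys, fq_keys))
```

Each reconstructed value lies within about one quantization step (`scale`) of the original. Zero maps exactly onto the zero point, so it survives the round trip unchanged.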

__new__(*args, **kwargs)

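Singleton constructor: creates the shared instance on the first call and returns that same instance on every subsequent call.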

Source code in src/llmcompressor/modifiers/quantization/cache.py
def __new__(cls, *args, **kwargs):
    """Singleton"""
    if cls._instance is None:
        cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
    return cls._instance
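Because `__new__` hands back the stored instance, every construction yields the same object. In Python's object model, however, `__init__` still runs on each call, so any state set there is overwritten unless it is guarded. The standalone sketch below illustrates both effects. `SingletonSketch` is a hypothetical class that demonstrates the pattern and is not the `QuantizedKVParameterCache` code.

```python
class SingletonSketch:
    """Hypothetical stand-in illustrating a __new__-based singleton."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Create the instance only once; later calls return the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, label):
        # __init__ still runs on every SingletonSketch(...) call,
        # so this attribute is overwritten each time.
        self.label = label


a = SingletonSketch("first")
b = SingletonSketch("second")
```

Here `a is b` holds, and `a.label` ends up as `"second"`.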

get_seq_length(layer_idx=0)

Returns the sequence length of the cached states. A layer index can optionally be passed (it defaults to 0); if no states are cached for that layer, 0 is returned.

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """
    Returns the sequence length of the cached states.
    A layer index can be optionally passed.
    """
    if len(self.key_cache) <= layer_idx:
        return 0
    # since we cannot get the seq_length of each layer directly and
    # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
    # this is a hack to get the actual seq_length for the given layer_idx
    # this part of code otherwise fails when used to
    # verify attn_weight shape in some models
    return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
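The offset between layer 0 and later layers is easy to misread, so the same rule can be restated as a pure function. `seq_length` and its parameter names are hypothetical and are not part of llmcompressor. The function only mirrors the branches of `get_seq_length` shown above.

```python
def seq_length(num_cached_layers: int, seen_tokens: int, layer_idx: int = 0) -> int:
    """Hypothetical standalone mirror of the get_seq_length branches."""
    if num_cached_layers <= layer_idx:
        return 0  # nothing cached for this layer yet
    # Layer 0 reports the running token tally; every other layer reports one less.
    return seen_tokens if layer_idx == 0 else seen_tokens - 1


# With 4 cached layers and 10 seen tokens:
#   layer 0 -> 10, layer 2 -> 9, layer 4 (not cached) -> 0
```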

reset()

Reset the singleton: clear the shared instance and its initialized flag so that the next instantiation creates a new instance.

def reset(self):
    """
    Reset the instantiation, create new instance on init
    """
    QuantizedKVParameterCache._instance = None
    QuantizedKVParameterCache._initialized = False
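`reset()` clears both `_instance` and `_initialized`. The `_initialized` flag suggests that `__init__` skips re-initialization once the flag is set. The sketch below assumes that guard, because the constructor body is not shown in this section. The class name, `config`, and `k_scales` are hypothetical. Only `_instance` and `_initialized` come from the source.

```python
class ResettableSingleton:
    """Hypothetical stand-in: __new__ singleton plus an assumed __init__ guard."""

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config):
        # Assumed guard: only the first construction initializes state.
        if not self._initialized:
            self.config = config
            self.k_scales = []
            self._initialized = True

    def reset(self):
        # Drop the shared instance so the next call builds a fresh one.
        ResettableSingleton._instance = None
        ResettableSingleton._initialized = False


a = ResettableSingleton("int8")
b = ResettableSingleton("fp8")   # same object; the guard keeps config == "int8"
a.reset()
c = ResettableSingleton("fp8")   # fresh object after reset()
```

Without `reset()`, a second configuration would be silently ignored. After `reset()`, the next construction starts from scratch.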

reset_states()

Reset the KV states: clear the key/value caches, the quantized key/value caches, and the seen-token tally (used during calibration).

def reset_states(self):
    """reset the kv states (used in calibration)"""
    self.key_cache: List[Tensor] = []
    self.value_cache: List[Tensor] = []
    # Used in `generate` to keep tally of how many tokens the cache has seen
    self._seen_tokens = 0
    self._quantized_key_cache: List[Tensor] = []
    self._quantized_value_cache: List[Tensor] = []
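`reset_states()` only clears the per-sequence caches and the token tally. It does not touch `k_observers`, `v_observers`, `k_scales`, `v_scales`, `k_zps`, or `v_zps`, which `update()` maintains. By contrast, `reset()` discards the whole singleton. The hypothetical mock below shows the difference. `KVStateMock` and `record` are illustrative names, and calling `reset_states()` between calibration batches is an assumed usage.

```python
class KVStateMock:
    """Hypothetical mock contrasting per-sequence state with calibration results."""

    def __init__(self):
        self.k_scales = []  # calibration results; reset_states() leaves these alone
        self.reset_states()

    def reset_states(self):
        # Same attributes that QuantizedKVParameterCache.reset_states() clears
        self.key_cache = []
        self.value_cache = []
        self._seen_tokens = 0
        self._quantized_key_cache = []
        self._quantized_value_cache = []

    def record(self, num_tokens, k_scale):
        # Hypothetical stand-in for one update() call on layer 0
        self._seen_tokens += num_tokens
        self.key_cache = [num_tokens]
        self.k_scales = [k_scale]


cache = KVStateMock()
cache.record(num_tokens=16, k_scale=0.05)
cache.reset_states()  # e.g. before the next calibration batch (assumed usage)
```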

update(key_states, value_states, layer_idx, cache_kwargs=None)

Compute the k_scale and v_scale for the given layer and return the fake-quantized (quantized, then dequantized) key_states and value_states.

def update(
    self,
    key_states: Tensor,
    value_states: Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Get the k_scale and v_scale and output the
     fakequant-ed key_states and value_states
    """

    if len(self.k_observers) <= layer_idx:
        k_observer_name = self.quantization_args.observer
        k_observer = Observer.load_from_registry(
            k_observer_name, quantization_args=self.quantization_args
        )
        v_observer_name = self.quantization_args.observer
        v_observer = Observer.load_from_registry(
            v_observer_name, quantization_args=self.quantization_args
        )

        # NOTE: User may ignore some layers in configuration,
        # meaning len(self.k_observers) <= layer_idx-1
        # Must account for that case by padding list so that
        # index of lists corresponds to layer_idx
        _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
        _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

    q_key_states = self._quantize(
        key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
    )
    q_value_states = self._quantize(
        value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
    )

    qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
    qdq_value_states = self._dequantize(
        q_value_states, KVCacheScaleType.VALUE, layer_idx
    )

    keys_to_return, values_to_return = qdq_key_states, qdq_value_states

    return keys_to_return, values_to_return
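`update` and `_quantize` store observers, scales, and zero points through `_pad_and_append_at_idx_`, whose body is not shown in this section. Because users may ignore some layers, these lists must stay indexable by `layer_idx` even when earlier layers never call `update`. Also, `_quantize` calls the helper on every pass, and `_dequantize` then reads `scales[layer_idx]`, so a repeat call for the same layer should overwrite rather than append. The sketch below is a hypothetical helper with those properties. Its name is illustrative, and the real helper may differ.

```python
from typing import Any, List, Optional


def pad_and_set_at_idx(lst: List[Optional[Any]], idx: int, value: Any) -> List[Optional[Any]]:
    """Hypothetical helper: grow lst with None up to idx, then store value there."""
    missing = idx - len(lst) + 1
    if missing > 0:
        lst.extend([None] * missing)  # placeholders for skipped (ignored) layers
    lst[idx] = value  # overwrite on repeat calls for the same layer
    return lst


observers: List[Optional[str]] = []
pad_and_set_at_idx(observers, 2, "obs_layer2")  # layers 0 and 1 not seen yet
pad_and_set_at_idx(observers, 0, "obs_layer0")
pad_and_set_at_idx(observers, 2, "obs_layer2_v2")  # repeat call overwrites
```

After these calls, `observers` is `["obs_layer0", None, "obs_layer2_v2"]`. Each list index still matches its layer index.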