Skip to content

llmcompressor.observers

MinMaxObserver

Bases: Observer

Implements a quantization observer that calculates scale and zero point based on the minimum and maximum values of the tensor being observed. If averaging_constant is specified, then the scales are updated using a moving average

Source code in src/llmcompressor/observers/min_max.py
@Observer.register("minmax")
class MinMaxObserver(Observer):
    """
    Implements a quantization observer that calculates scale and zero point based on the
    minimum and maximum values of the tensor being observed. If averaging_constant is
    specified, then the scales are updated using a moving average
    """

    def __init__(
        self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
    ):
        super().__init__(quantization_args=quantization_args)

        # running (moving-average) min / max, keyed by tensor_id
        self.min_val = {}
        self.max_val = {}
        self.averaging_constant = averaging_constant

    def calculate_qparams(
        self,
        observed: torch.Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
        tensor_id: Optional[Any] = None,
    ) -> Tuple[torch.FloatTensor, torch.IntTensor]:
        """
        Updates the observed min and max using a moving average smoothed by the
        averaging_constant. Set the averaging_constant to 1.0 to disable averaging.

        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned scale and zero point will be shaped (1,) along the
            reduced dimensions
        :param tensor_id: Optional id if different ranges of observed tensors are
            passed, useful for sharding tensors by group_size. Only None is mapped
            to the "default" slot
        :return: tuple of scale and zero point derived from the observed tensor
        """
        # compare against None explicitly: a falsy id such as group index 0
        # must keep its own running statistics instead of sharing "default"
        if tensor_id is None:
            tensor_id = "default"

        if not reduce_dims:
            min_val, max_val = torch.aminmax(observed)
        else:
            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

        # early stopping, save some computation and memory
        # (no running statistics are stored on this path)
        if self.averaging_constant == 1.0:
            return calculate_qparams(min_val, max_val, self.quantization_args)

        running_min_val = self.min_val.get(tensor_id, None)
        running_max_val = self.max_val.get(tensor_id, None)

        if running_min_val is None or running_max_val is None:
            # first observation for this tensor_id seeds the running values
            updated_min_val = min_val
            updated_max_val = max_val
        else:
            updated_min_val = running_min_val + self.averaging_constant * (
                min_val - running_min_val
            )
            updated_max_val = running_max_val + self.averaging_constant * (
                max_val - running_max_val
            )

        self.min_val[tensor_id] = updated_min_val
        self.max_val[tensor_id] = updated_max_val

        return calculate_qparams(
            updated_min_val, updated_max_val, self.quantization_args
        )

    def get_qparams_along_dim(
        self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None
    ):
        """
        Calculate quantization parameters along the specified dimension(s)

        :param observed: observed tensor to calculate quantization parameters for
        :param dim: dimension to keep; an iterable of dimensions is accepted as
            well, since the base class passes a set for the token strategy
        :param tensor_id: optional id forwarded to calculate_qparams
        :return: result of calculate_qparams reduced over all other dimensions
        """
        keep_dims = {dim} if isinstance(dim, int) else set(dim)
        reduce_dims = tuple(
            idx for idx in range(observed.ndim) if idx not in keep_dims
        )
        return self.calculate_qparams(
            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
        )

    def reset(self):
        """
        Reset the state of the observer, including min and maximum values
        """
        super().reset()
        self.min_val = {}
        self.max_val = {}

calculate_qparams(observed, reduce_dims=None, tensor_id=None)

Updates the observed min and max using a moving average smoothed by the averaging_constant. Set the averaging_constant to 1.0 to disable averaging.

Parameters:

Name Type Description Default
observed Tensor

observed tensor to calculate quantization parameters for

required
reduce_dims Optional[Tuple[int]]

optional tuple of dimensions to reduce along, returned scale and zero point will be shaped (1,) along the reduced dimensions

None
tensor_id Optional[Any]

Optional id if different ranges of observed tensors are passed, useful for sharding tensors by group_size

None

Returns:

Type Description
Tuple[FloatTensor, IntTensor]

tuple of scale and zero point derived from the observed tensor

Source code in src/llmcompressor/observers/min_max.py
def calculate_qparams(
    self,
    observed: torch.Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
) -> Tuple[torch.FloatTensor, torch.IntTensor]:
    """
    Updates the observed min and max using a moving average smoothed by the
    averaging_constant. Set the averaging_constant to 1.0 to disable averaging.

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: Optional id if different ranges of observed tensors are
        passed, useful for sharding tensors by group_size. Only None is mapped
        to the "default" slot
    :return: tuple of scale and zero point derived from the observed tensor
    """
    # compare against None explicitly: a falsy id such as group index 0
    # must keep its own running statistics instead of sharing "default"
    if tensor_id is None:
        tensor_id = "default"

    if not reduce_dims:
        min_val, max_val = torch.aminmax(observed)
    else:
        min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
        max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

    # early stopping, save some computation and memory
    # (no running statistics are stored on this path)
    if self.averaging_constant == 1.0:
        return calculate_qparams(min_val, max_val, self.quantization_args)

    running_min_val = self.min_val.get(tensor_id, None)
    running_max_val = self.max_val.get(tensor_id, None)

    if running_min_val is None or running_max_val is None:
        # first observation for this tensor_id seeds the running values
        updated_min_val = min_val
        updated_max_val = max_val
    else:
        updated_min_val = running_min_val + self.averaging_constant * (
            min_val - running_min_val
        )
        updated_max_val = running_max_val + self.averaging_constant * (
            max_val - running_max_val
        )

    self.min_val[tensor_id] = updated_min_val
    self.max_val[tensor_id] = updated_max_val

    return calculate_qparams(
        updated_min_val, updated_max_val, self.quantization_args
    )

get_qparams_along_dim(observed, dim, tensor_id=None)

Calculate quantization parameters along the specified dimension

Source code in src/llmcompressor/observers/min_max.py
def get_qparams_along_dim(
    self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None
):
    """
    Calculate quantization parameters along the specified dimension(s)

    :param observed: observed tensor to calculate quantization parameters for
    :param dim: dimension to keep; an iterable of dimensions is accepted as
        well, since the base class passes a set for the token strategy
    :param tensor_id: optional id forwarded to calculate_qparams
    :return: result of calculate_qparams reduced over all other dimensions
    """
    keep_dims = {dim} if isinstance(dim, int) else set(dim)
    reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in keep_dims)
    return self.calculate_qparams(
        observed, reduce_dims=reduce_dims, tensor_id=tensor_id
    )

reset()

Reset the state of the observer, including min and maximum values

Source code in src/llmcompressor/observers/min_max.py
def reset(self):
    """
    Clear the observer state: the base-class state via super().reset() and the
    running minimum and maximum values
    """
    super().reset()
    self.min_val, self.max_val = {}, {}

MovingAverageMSEObserver

Bases: Observer

Implements a dynamic quantization observer that sets the scale and zero point based on a moving average of the mse-clipped min and max observed values

Source code in src/llmcompressor/observers/mse.py
@Observer.register("mse")
class MovingAverageMSEObserver(Observer):
    """
    Implements a dynamic quantization observer that sets the scale and
    zero point based on a moving average of the mse-clipped min and max observed values
    """

    def __init__(
        self,
        quantization_args: QuantizationArgs,
        averaging_constant: float = 0.01,
        grid: float = 100.0,
        maxshrink: float = 0.80,
        norm: float = 2.4,
    ):
        super().__init__(quantization_args=quantization_args)

        # running (moving-average) min / max, keyed by tensor_id
        self.min_val = {}
        self.max_val = {}
        self.averaging_constant = averaging_constant
        # grid search settings: int(maxshrink * grid) candidates are tried,
        # candidate i scales the observed range by (1 - i / grid)
        self.grid = grid
        self.maxshrink = maxshrink
        # exponent applied to the absolute quantization error
        self.norm = norm

    def calculate_mse_min_max(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
    ):
        """
        Computes the mse-clipped min and max values of the observed tensor by
        optimizing for quantization error

        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned values will be shaped (1,) along the reduced dimensions
        :return: tuple of min and max values derived from the observed tensor
        """
        from compressed_tensors.quantization.lifecycle import fake_quantize

        if not reduce_dims:
            absolute_min_val, absolute_max_val = torch.aminmax(observed)
        else:
            absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
            absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)

        # start from the largest representable error so the first candidate wins
        best = torch.full_like(
            absolute_min_val, torch.finfo(absolute_min_val.dtype).max
        )
        min_val = torch.ones_like(absolute_min_val)
        max_val = torch.zeros_like(absolute_max_val)
        for i in range(int(self.maxshrink * self.grid)):
            # shrink the full observed range by a factor p for this candidate
            p = 1 - i / self.grid
            shrinked_min_val = p * absolute_min_val
            shrinked_max_val = p * absolute_max_val

            candidate_scales, candidate_zero_points = calculate_qparams(
                shrinked_min_val, shrinked_max_val, self.quantization_args
            )
            q = fake_quantize(
                observed,
                candidate_scales,
                candidate_zero_points,
                self.quantization_args,
            )

            # error = sum(|q - observed| ** norm), computed in place on q
            q -= observed
            q.abs_()
            q.pow_(self.norm)
            if not reduce_dims:
                err = torch.sum(q)
            else:
                err = torch.sum(q, reduce_dims, keepdims=True)

            # keep the candidate range wherever it lowers the error
            tmp = err < best
            if torch.any(tmp):
                best[tmp] = err[tmp]
                min_val[tmp] = shrinked_min_val[tmp]
                max_val[tmp] = shrinked_max_val[tmp]
        return min_val, max_val

    def calculate_qparams(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
        tensor_id: Optional[Any] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        Updates the mse-clipped min and max values of the observed tensor using
        a moving average smoothed by the averaging_constant

        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned scale and zero point will be shaped (1,) along the
            reduced dimensions
        :param tensor_id: Optional id if different ranges of observed tensors are
            passed, useful for sharding tensors by group_size. Only None is mapped
            to the "default" slot
        :return: tuple of scale and zero point derived from the observed tensor
        """
        # resolve the storage key BEFORE looking up the running values, so the
        # lookup and the store below use the same key; otherwise the running
        # values for the default id are never found and no averaging happens
        if tensor_id is None:
            tensor_id = "default"

        min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)

        running_min_val = self.min_val.get(tensor_id, None)
        running_max_val = self.max_val.get(tensor_id, None)

        if running_min_val is None or running_max_val is None:
            # first observation for this tensor_id seeds the running values
            updated_min_val = min_val
            updated_max_val = max_val
        else:
            updated_min_val = running_min_val + self.averaging_constant * (
                min_val - running_min_val
            )
            updated_max_val = running_max_val + self.averaging_constant * (
                max_val - running_max_val
            )

        self.min_val[tensor_id] = updated_min_val
        self.max_val[tensor_id] = updated_max_val

        return calculate_qparams(
            updated_min_val, updated_max_val, self.quantization_args
        )

    def get_qparams_along_dim(
        self, observed, dim: int, tensor_id: Optional[Any] = None
    ):
        """
        Calculate quantization parameters along the specified dimension(s)

        :param observed: observed tensor to calculate quantization parameters for
        :param dim: dimension to keep; an iterable of dimensions is accepted as
            well, since the base class passes a set for the token strategy
        :param tensor_id: optional id forwarded to calculate_qparams
        :return: result of calculate_qparams reduced over all other dimensions
        """
        keep_dims = {dim} if isinstance(dim, int) else set(dim)
        reduce_dims = tuple(
            idx for idx in range(observed.ndim) if idx not in keep_dims
        )
        return self.calculate_qparams(
            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
        )

    def reset(self):
        """
        Reset the state of the observer, including min and maximum values
        """
        super().reset()
        self.min_val = {}
        self.max_val = {}

calculate_mse_min_max(observed, reduce_dims=None)

Computes the mse-clipped min and max values of the observed tensor by optimizing for quantization error

Parameters:

Name Type Description Default
observed Tensor

observed tensor to calculate quantization parameters for

required
reduce_dims Optional[Tuple[int]]

optional tuple of dimensions to reduce along, returned values will be shaped (1,) along the reduced dimensions

None

Returns:

Type Description

tuple of min and max values derived from the observed tensor

Source code in src/llmcompressor/observers/mse.py
def calculate_mse_min_max(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
):
    """
    Grid-searches shrunken versions of the observed min/max range and returns
    the range with the lowest quantization error

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned values will be shaped (1,) along the reduced dimensions
    :return: tuple of min and max values derived from the observed tensor
    """
    from compressed_tensors.quantization.lifecycle import fake_quantize

    if reduce_dims:
        full_min = torch.amin(observed, dim=reduce_dims, keepdims=True)
        full_max = torch.amax(observed, dim=reduce_dims, keepdims=True)
    else:
        full_min, full_max = torch.aminmax(observed)

    # lowest error seen so far; starts at the dtype maximum
    lowest_err = torch.full_like(full_min, torch.finfo(full_min.dtype).max)
    min_val = torch.ones_like(full_min)
    max_val = torch.zeros_like(full_max)

    num_candidates = int(self.maxshrink * self.grid)
    for step in range(num_candidates):
        shrink = 1 - step / self.grid
        candidate_min = shrink * full_min
        candidate_max = shrink * full_max

        scales, zero_points = calculate_qparams(
            candidate_min, candidate_max, self.quantization_args
        )
        q = fake_quantize(observed, scales, zero_points, self.quantization_args)

        # in-place: q <- |q - observed| ** norm
        q -= observed
        q.abs_()
        q.pow_(self.norm)

        if reduce_dims:
            err = torch.sum(q, reduce_dims, keepdims=True)
        else:
            err = torch.sum(q)

        improved = err < lowest_err
        if torch.any(improved):
            lowest_err[improved] = err[improved]
            min_val[improved] = candidate_min[improved]
            max_val[improved] = candidate_max[improved]

    return min_val, max_val

calculate_qparams(observed, reduce_dims=None, tensor_id=None)

Updates the mse-clipped min and max values of the observed tensor using a moving average smoothed by the averaging_constant

Parameters:

Name Type Description Default
observed Tensor

observed tensor to calculate quantization parameters for

required
reduce_dims Optional[Tuple[int]]

optional tuple of dimensions to reduce along, returned scale and zero point will be shaped (1,) along the reduced dimensions

None
tensor_id Optional[Any]

Optional id if different ranges of observed tensors are passed, useful for sharding tensors by group_size

None

Returns:

Type Description
Tuple[FloatTensor, IntTensor]

tuple of scale and zero point derived from the observed tensor

Source code in src/llmcompressor/observers/mse.py
def calculate_qparams(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Updates the mse-clipped min and max values of the observed tensor using
    a moving average smoothed by the averaging_constant

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: Optional id if different ranges of observed tensors are
        passed, useful for sharding tensors by group_size. Only None is mapped
        to the "default" slot
    :return: tuple of scale and zero point derived from the observed tensor
    """
    # resolve the storage key BEFORE looking up the running values, so the
    # lookup and the store below use the same key; otherwise the running
    # values for the default id are never found and no averaging happens
    if tensor_id is None:
        tensor_id = "default"

    min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)

    running_min_val = self.min_val.get(tensor_id, None)
    running_max_val = self.max_val.get(tensor_id, None)

    if running_min_val is None or running_max_val is None:
        # first observation for this tensor_id seeds the running values
        updated_min_val = min_val
        updated_max_val = max_val
    else:
        updated_min_val = running_min_val + self.averaging_constant * (
            min_val - running_min_val
        )
        updated_max_val = running_max_val + self.averaging_constant * (
            max_val - running_max_val
        )

    self.min_val[tensor_id] = updated_min_val
    self.max_val[tensor_id] = updated_max_val

    return calculate_qparams(
        updated_min_val, updated_max_val, self.quantization_args
    )

reset()

Reset the state of the observer, including min and maximum values

Source code in src/llmcompressor/observers/mse.py
def reset(self):
    """
    Clear the observer state: the base-class state via super().reset() and the
    running minimum and maximum values
    """
    super().reset()
    self.min_val = dict()
    self.max_val = dict()

Observer

Bases: Module, RegistryMixin

Base Observer class to be subclassed for specific implementation. Subclasses should override calculate_qparams to return a scale, zero_point pair

Source code in src/llmcompressor/observers/base.py
class Observer(Module, RegistryMixin):
    """
    Base Observer class to be subclassed for specific implementation.
    Subclasses should override `calculate_qparams` to return a scale, zero_point
    pair
    """

    def __init__(self, quantization_args: QuantizationArgs):
        self.quantization_args: QuantizationArgs = quantization_args
        super().__init__()
        # latest calculated quantization parameters, None until first observation
        self._scale = None
        self._zero_point = None
        # running token count, None until a two-dimensional tensor is recorded
        self._num_observed_tokens = None

    @torch.no_grad()
    def forward(
        self, observed: Tensor, g_idx: Optional[Tensor] = None
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        maps directly to get_qparams
        :param observed: observed tensor from which to calculate
            quantization parameters
        :param g_idx: optional mapping from column index to group index
        :return: tuple of scale and zero point based on last observed value
        """
        self.record_observed_tokens(observed)
        return self.get_qparams(observed=observed, g_idx=g_idx)

    def calculate_qparams(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
        tensor_id: Optional[Any] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        :param observed: observed tensor to calculate quantization parameters for
        :param reduce_dims: optional tuple of dimensions to reduce along,
            returned scale and zero point will be shaped (1,) along the
            reduced dimensions
        :param tensor_id: optional id of the observed tensor; always passed as a
            keyword by get_qparams_along_dim, so overrides must accept it
        :return: tuple of scale and zero point derived from the observed tensor
        """
        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

    def post_calculate_qparams(self) -> None:
        """
        Run any logic specific to its observers after running calculate_qparams
        """

    def get_qparams(
        self,
        observed: Optional[Tensor] = None,
        g_idx: Optional[Tensor] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """
        Convenience function to wrap overwritten calculate_qparams
        adds support to make observed tensor optional and support for tracking latest
        calculated scale and zero point

        :param observed: optional observed tensor to calculate quantization parameters
            from
        :param g_idx: optional mapping from column index to group index
        :return: tuple of scale and zero point based on last observed value
        """
        if observed is not None:
            group_size = self.quantization_args.group_size

            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
                # re-calculate scale and zero point, update the stored value
                self._scale, self._zero_point = self.calculate_qparams(observed)

            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
                # observed is indexed as a 2-D (rows, columns) tensor here
                rows = observed.shape[0]
                columns = observed.shape[1]
                num_groups = int(ceil(columns / group_size))
                self._scale = torch.empty(
                    (rows, num_groups), dtype=observed.dtype, device=observed.device
                )
                zp_dtype = self.quantization_args.pytorch_dtype()
                self._zero_point = torch.empty(
                    (rows, num_groups), dtype=zp_dtype, device=observed.device
                )

                # support column-order (default) quantization as well as other orderings
                # such as activation ordering. Below checks if g_idx has initialized
                is_column_order = g_idx is None or -1 in g_idx
                if is_column_order:
                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
                else:
                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
                    group_sizes = group_sizes[torch.argsort(group_indices)]

                    # reorder columns so that each group is contiguous
                    perm = torch.argsort(g_idx)
                    observed = safe_permute(observed, perm, dim=1)

                # TODO: experiment with vectorizing for loop for performance
                end = 0
                for group_index, group_count in enumerate(group_sizes):
                    start = end
                    end = start + group_count
                    scale, zero_point = self.get_qparams_along_dim(
                        observed[:, start:end],
                        0,
                        tensor_id=group_index,
                    )

                    self._scale[:, group_index] = scale.squeeze(1)
                    self._zero_point[:, group_index] = zero_point.squeeze(1)

            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                # assume observed is transposed, because it's the output, hence use dim 0
                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                # keep dims 0 and 1, assuming observed.shape = [batch, token, hidden],
                # so the result is per (batch, token)
                self._scale, self._zero_point = self.get_qparams_along_dim(
                    observed,
                    dim={0, 1},
                )

        return self._scale, self._zero_point

    def get_qparams_along_dim(
        self,
        observed,
        dim: Union[int, Iterable[int]],
        tensor_id: Optional[Any] = None,
    ):
        """
        Calculate quantization parameters while keeping the given dimension(s)

        :param observed: observed tensor to calculate quantization parameters for
        :param dim: dimension, or iterable of dimensions, that is not reduced
        :param tensor_id: optional id forwarded to calculate_qparams
        :return: result of calculate_qparams reduced over all other dimensions
        """
        if isinstance(dim, int):
            dim = [dim]
        dim = set(dim)

        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
        return self.calculate_qparams(
            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
        )

    def record_observed_tokens(self, batch_tensor: Tensor):
        """
        Counts the number of tokens observed during the
        forward passes. The count is aggregated in the
        _num_observed_tokens attribute of the class.

        Note: The batch_tensor is expected to have two dimensions
            (batch_size * sequence_length, num_features). This is the
            general shape expected by the forward pass of the expert
            layers in a MOE model. If the input tensor does not have
            two dimensions, it is skipped and the _num_observed_tokens
            attribute is left unchanged.

        :param batch_tensor: tensor whose first dimension is added to the count
        :raises ValueError: if batch_tensor is not a Tensor
        """
        if not isinstance(batch_tensor, Tensor):
            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")

        if batch_tensor.ndim != 2:
            logger.debug(
                "The input tensor is expected to have two dimensions "
                "(batch_size * sequence_length, num_features). "
                f"The input tensor has {batch_tensor.ndim} dimensions."
            )
            return

        if self._num_observed_tokens is None:
            # initialize the count
            self._num_observed_tokens = 0

        # batch_tensor (batch_size * sequence_length, num_features)
        # observed_tokens (batch_size * sequence_length)
        observed_tokens, _ = batch_tensor.shape
        self._num_observed_tokens += observed_tokens

    def reset(self):
        """
        Reset the state of the observer
        """
        self._num_observed_tokens = None
        self._scale = None
        self._zero_point = None

calculate_qparams(observed, reduce_dims=None)

Parameters:

Name Type Description Default
observed Tensor

observed tensor to calculate quantization parameters for

required
reduce_dims Optional[Tuple[int]]

optional tuple of dimensions to reduce along, returned scale and zero point will be shaped (1,) along the reduced dimensions

None

Returns:

Type Description
Tuple[FloatTensor, IntTensor]

tuple of scale and zero point derived from the observed tensor

Source code in src/llmcompressor/observers/base.py
def calculate_qparams(
    self,
    observed: Tensor,
    reduce_dims: Optional[Tuple[int]] = None,
    tensor_id: Optional[Any] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Abstract hook: subclasses must override it, the base version always raises

    :param observed: observed tensor to calculate quantization parameters for
    :param reduce_dims: optional tuple of dimensions to reduce along,
        returned scale and zero point will be shaped (1,) along the
        reduced dimensions
    :param tensor_id: optional id of the observed tensor; always passed as a
        keyword by get_qparams_along_dim, so overrides must accept it
    :return: tuple of scale and zero point derived from the observed tensor
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

forward(observed, g_idx=None)

maps directly to get_qparams

Parameters:

Name Type Description Default
observed Tensor

observed tensor from which to calculate quantization parameters

required
g_idx Optional[Tensor]

optional mapping from column index to group index

None

Returns:

Type Description
Tuple[FloatTensor, IntTensor]

tuple of scale and zero point based on last observed value

Source code in src/llmcompressor/observers/base.py
@torch.no_grad()
def forward(
    self, observed: Tensor, g_idx: Optional[Tensor] = None
) -> Tuple[FloatTensor, IntTensor]:
    """
    Record the observed tensor for token counting, then delegate to get_qparams

    :param observed: observed tensor from which to calculate
        quantization parameters
    :param g_idx: optional mapping from column index to group index
    :return: tuple of scale and zero point based on last observed value
    """
    # token bookkeeping happens before the quantization parameters are updated
    self.record_observed_tokens(observed)
    qparams = self.get_qparams(observed=observed, g_idx=g_idx)
    return qparams

get_qparams(observed=None, g_idx=None)

Convenience function that wraps the overridden calculate_qparams. It adds support for making the observed tensor optional and for tracking the latest calculated scale and zero point.

Parameters:

Name Type Description Default
observed Optional[Tensor]

optional observed tensor to calculate quantization parameters from

None
g_idx Optional[Tensor]

optional mapping from column index to group index

None

Returns:

Type Description
Tuple[FloatTensor, IntTensor]

tuple of scale and zero point based on last observed value

Source code in src/llmcompressor/observers/base.py
def get_qparams(
    self,
    observed: Optional[Tensor] = None,
    g_idx: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
    """
    Convenience function that wraps the overridden calculate_qparams. It makes
    the observed tensor optional and tracks the latest calculated scale and
    zero point: when observed is None the stored values are returned untouched.

    :param observed: optional observed tensor to calculate quantization parameters
        from
    :param g_idx: optional mapping from column index to group index
    :return: tuple of scale and zero point based on last observed value
    """
    if observed is not None:
        group_size = self.quantization_args.group_size

        # NOTE: there is no else branch, so a strategy other than the four
        # handled below leaves the stored scale and zero point unchanged
        if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
            # re-calculate scale and zero point, update the stored value
            self._scale, self._zero_point = self.calculate_qparams(observed)

        elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
            # observed is indexed as a 2-D (rows, columns) tensor here; one
            # scale / zero point is stored per (row, group)
            rows = observed.shape[0]
            columns = observed.shape[1]
            num_groups = int(ceil(columns / group_size))
            self._scale = torch.empty(
                (rows, num_groups), dtype=observed.dtype, device=observed.device
            )
            zp_dtype = self.quantization_args.pytorch_dtype()
            self._zero_point = torch.empty(
                (rows, num_groups), dtype=zp_dtype, device=observed.device
            )

            # support column-order (default) quantization as well as other orderings
            # such as activation ordering. Below checks if g_idx has initialized
            is_column_order = g_idx is None or -1 in g_idx
            if is_column_order:
                # every group spans group_size consecutive columns
                group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
            else:
                # count the columns assigned to each group index
                group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
                group_sizes = group_sizes[torch.argsort(group_indices)]

                # reorder columns so that each group is contiguous
                perm = torch.argsort(g_idx)
                observed = safe_permute(observed, perm, dim=1)

            # TODO: experiment with vectorizing for loop for performance
            end = 0
            for group_index, group_count in enumerate(group_sizes):
                # walk the column slices [start, end) group by group
                start = end
                end = start + group_count
                scale, zero_point = self.get_qparams_along_dim(
                    observed[:, start:end],
                    0,
                    tensor_id=group_index,
                )

                # drop the reduced column dimension before storing per-row values
                self._scale[:, group_index] = scale.squeeze(1)
                self._zero_point[:, group_index] = zero_point.squeeze(1)

        elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
            # assume observed is transposed, because it's the output, hence use dim 0
            self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)

        elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
            # keep dims 0 and 1, assuming observed.shape = [batch, token, hidden],
            # so the result is per (batch, token)
            self._scale, self._zero_point = self.get_qparams_along_dim(
                observed,
                dim={0, 1},
            )

    return self._scale, self._zero_point

post_calculate_qparams()

Run any logic specific to its observers after running calculate_qparams

Source code in src/llmcompressor/observers/base.py
def post_calculate_qparams(self) -> None:
    """
    Hook for observer-specific logic to run after calculate_qparams; this
    implementation does nothing
    """
    return None

record_observed_tokens(batch_tensor)

Counts the number of tokens observed during the forward passes. The count is aggregated in the _num_observed_tokens attribute of the class.

Note: The batch_tensor is expected to have two dimensions (batch_size * sequence_length, num_features). This is the general shape expected by the forward pass of the expert layers in a MOE model. If the input tensor does not have two dimensions, it is skipped and the _num_observed_tokens attribute is left unchanged (it stays None if no two-dimensional tensor has been recorded yet).

Source code in src/llmcompressor/observers/base.py
def record_observed_tokens(self, batch_tensor: Tensor):
    """
    Adds the number of tokens in batch_tensor to the running total kept in
    the _num_observed_tokens attribute.

    Note: The batch_tensor is expected to have two dimensions
        (batch_size * sequence_length, num_features); its first dimension
        is taken as the token count. A tensor with any other number of
        dimensions is only reported through a debug log message and the
        _num_observed_tokens attribute is left unchanged.

    :param batch_tensor: tensor whose first dimension is added to the count
    :raises ValueError: if batch_tensor is not a Tensor
    """
    if not isinstance(batch_tensor, Tensor):
        raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")

    if batch_tensor.ndim == 2:
        # batch_tensor (batch_size * sequence_length, num_features)
        num_tokens = batch_tensor.shape[0]
        # the count starts at zero on the first recorded tensor
        running_total = self._num_observed_tokens
        if running_total is None:
            running_total = 0
        self._num_observed_tokens = running_total + num_tokens
        return

    logger.debug(
        "The input tensor is expected to have two dimensions "
        "(batch_size * sequence_length, num_features). "
        f"The input tensor has {batch_tensor.ndim} dimensions."
    )

reset()

Reset the state of the observer

Source code in src/llmcompressor/observers/base.py
def reset(self):
    """
    Clear the token count and the stored scale and zero point
    """
    for attr_name in ("_num_observed_tokens", "_scale", "_zero_point"):
        setattr(self, attr_name, None)

get_observer_token_count(module)

Parse the module and return the number of tokens observed by each module's observer.

Parameters:

Name Type Description Default
module Module

module to parse

required

Returns:

Type Description
Counter

counter with the number of tokens observed by each observer

Source code in src/llmcompressor/observers/helpers.py
def get_observer_token_count(module: torch.nn.Module) -> Counter:
    """
    Parse the module and return the number of tokens observed by
    each module's observer.

    Only submodules whose name ends with ".input_observer" are considered; the
    key is the submodule name with that suffix removed and the value is the
    observer's _num_observed_tokens attribute, stored as-is.

    :param module: module to parse
    :return: counter with the number of tokens observed by each observer
    """
    suffix = ".input_observer"
    token_counts = Counter()
    # use a distinct loop variable so the `module` argument is not shadowed
    for name, submodule in module.named_modules():
        if name.endswith(suffix):
            # strip only the trailing suffix: str.replace would also remove
            # earlier occurrences and make nested names collide
            token_counts[name[: -len(suffix)]] = submodule._num_observed_tokens
    return token_counts