Skip to content

llmcompressor.pytorch.utils

Generic code used as utilities and helpers for PyTorch

ModuleSparsificationInfo

Helper class for providing information related to torch Module parameters and the amount of sparsification applied. Includes information for pruning and quantization

Parameters:

Name Type Description Default
module Module

torch Module to analyze

required
state_dict Optional[Dict[str, Tensor]]

optional state_dict to analyze in place of the torch model. This is used when analyzing an FSDP model, where the full weights may not be accessible

None
Source code in src/llmcompressor/pytorch/utils/sparsification.py
class ModuleSparsificationInfo:
    """
    Helper class for providing information related to torch Module parameters
    and the amount of sparsification applied. Includes information for pruning
    and quantization

    All statistics are computed lazily on property access; nothing is cached,
    so every access walks the parameters again.

    :param module: torch Module to analyze
    :param state_dict: optional state_dict to analyze in place of the torch model. This
    is used when analyzing an FSDP model, where the full weights may not be accessible
    """

    def __init__(
        self, module: Module, state_dict: Optional[Dict[str, torch.Tensor]] = None
    ):
        self.module = module

        if state_dict is not None:
            # when analyzing an FSDP model, the state_dict does not differentiate
            # between trainable and non-trainable parameters
            # (e.g. it can contain buffers) this means that the
            # self.trainable_parameters may be overestimated
            self.trainable_params = state_dict
        elif hasattr(module, "_hf_hook"):
            # modules carrying an `_hf_hook` attribute are read through the
            # offloaded-model helper instead of named_parameters()
            self.trainable_params = get_state_dict_offloaded_model(module)
        else:
            self.trainable_params = {
                k: v for k, v in self.module.named_parameters() if v.requires_grad
            }

    def __str__(self):
        """
        :return: JSON string summarizing the parameter, sparsity and
            quantization counts of the analyzed module
        """
        # compute each count once; params_sparse scans every parameter, so
        # going through the percent properties would repeat the whole scan
        total = self.params_total
        sparse = self.params_sparse
        quantized = self.params_quantized

        return json.dumps(
            {
                "params_summary": {
                    "total": total,
                    "sparse": sparse,
                    "sparsity_percent": self._percent(sparse, total),
                    "quantized": quantized,
                    "quantized_percent": self._percent(quantized, total),
                },
                # `params_info` is not defined by this class; report it only
                # when a subclass or caller has attached it, rather than
                # failing with AttributeError
                "params_info": getattr(self, "params_info", None),
            }
        )

    @staticmethod
    def _percent(count: int, total: int) -> float:
        """
        :param count: number of parameters in the subset of interest
        :param total: total number of parameters
        :return: count as a percentage of total, or 0.0 when total is 0
        """
        if total == 0:
            return 0.0
        return count / float(total) * 100

    @property
    def params_total(self) -> int:
        """
        :return: total number of trainable parameters in the model
        """
        return sum(torch.numel(param) for param in self.trainable_params.values())

    @property
    def params_sparse(self) -> int:
        """
        :return: total number of sparse (0) trainable parameters in the model
        """
        return sum(
            round(tensor_sparsity(param).item() * torch.numel(param))
            for param in tqdm(
                self.trainable_params.values(), desc="Calculating model sparsity"
            )
        )

    @property
    def params_sparse_percent(self) -> float:
        """
        :return: percent of sparsified parameters in the entire model,
            0.0 when the model has no trainable parameters
        """
        total = self.params_total
        if total == 0:
            # nothing to scan and nothing to divide by
            return 0.0
        return self._percent(self.params_sparse, total)

    @property
    def params_quantized(self) -> int:
        """
        :return: number of parameters across quantized layers
        """
        num_params = 0
        for _name, layer in get_quantized_layers(self.module):
            if getattr(layer, "weight", None) is not None:
                num_params += torch.numel(layer.weight)
            if getattr(layer, "bias", None) is not None:
                num_params += torch.numel(layer.bias)

        return num_params

    @property
    def params_quantized_percent(self) -> float:
        """
        :return: percentage of parameters that have been quantized,
            0.0 when the model has no trainable parameters
        """
        total = self.params_total
        if total == 0:
            return 0.0
        return self._percent(self.params_quantized, total)

params_quantized property

Returns:

Type Description
int

number of parameters across quantized layers

params_quantized_percent property

Returns:

Type Description
float

percentage of parameters that have been quantized

params_sparse property

Returns:

Type Description
int

total number of sparse (0) trainable parameters in the model

params_sparse_percent property

Returns:

Type Description
float

percent of sparsified parameters in the entire model

params_total property

Returns:

Type Description
int

total number of trainable parameters in the model

get_linear_layers(module)

Parameters:

Name Type Description Default
module Module

the module to grab all linear layers for

required

Returns:

Type Description
Dict[str, Module]

a dict mapping module names to all linear layers in the module

Source code in src/llmcompressor/pytorch/utils/helpers.py
def get_linear_layers(module: Module) -> Dict[str, Module]:
    """
    Collect every Linear submodule of a module, keyed by its name

    :param module: the module to grab all linear layers for
    :return: a dict mapping each name yielded by ``module.named_modules()``
        to its module, restricted to instances of Linear
    """
    linear_layers = {}
    for layer_name, layer in module.named_modules():
        if not isinstance(layer, Linear):
            continue
        linear_layers[layer_name] = layer

    return linear_layers

get_quantized_layers(module)

Parameters:

Name Type Description Default
module Module

the module to get the quantized layers from

required

Returns:

Type Description
List[Tuple[str, Module]]

a list containing the names and modules of the quantized layers (Embedding, Linear, Conv2d, Conv3d)

Source code in src/llmcompressor/pytorch/utils/helpers.py
def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]:
    """
    Find the submodules that carry a weight quantization scheme

    A submodule is selected when it has a ``quantization_scheme`` attribute
    whose ``weights`` entry is not None and it also has a ``weight`` attribute.

    :param module: the module to get the quantized layers from
    :return: a list containing the names and modules of the quantized layers,
        in ``module.named_modules()`` order
    """

    def _has_quantized_weight(layer) -> bool:
        if not hasattr(layer, "quantization_scheme"):
            return False
        scheme_for_weights = getattr(layer.quantization_scheme, "weights", None)
        return scheme_for_weights is not None and hasattr(layer, "weight")

    return [
        (layer_name, layer)
        for layer_name, layer in module.named_modules()
        if _has_quantized_weight(layer)
    ]

set_deterministic_seeds(seed=0)

Manually seeds the numpy, random, and torch packages. Also sets torch.backends.cudnn.deterministic to True

Parameters:

Name Type Description Default
seed int

the manual seed to use. Default is 0

0
Source code in src/llmcompressor/pytorch/utils/helpers.py
def set_deterministic_seeds(seed: int = 0):
    """
    Seed the numpy, random, and torch random number generators with one value
    and switch torch.backends.cudnn.deterministic on

    :param seed: the manual seed to use. Default is 0
    """
    # same seed goes to each library's global generator, in this order
    for seed_generator in (numpy.random.seed, random.seed, torch.manual_seed):
        seed_generator(seed)

    torch.backends.cudnn.deterministic = True

tensor_sparsity(tens, dim=None)

Parameters:

Name Type Description Default
tens Tensor

the tensor to calculate the sparsity for

required
dim Union[None, int, List[int], Tuple[int, ...]]

the dimension(s) to split the calculations over; ex, can split over batch, channels, or combos

None

Returns:

Type Description
Tensor

the sparsity of the input tens, ie the fraction of numbers that are zero

Source code in src/llmcompressor/pytorch/utils/helpers.py
def tensor_sparsity(
    tens: Tensor, dim: Union[None, int, List[int], Tuple[int, ...]] = None
) -> Tensor:
    """
    :param tens: the tensor to calculate the sparsity for
    :param dim: the dimension(s) to split the calculations over;
        ex, can split over batch, channels, or combos. Negative values index
        from the last dimension. The axes of the result follow the order
        given in dim
    :return: the sparsity of the input tens, ie the fraction of numbers that are zero
    :raises ValueError: if any entry of dim is out of range for tens
    """
    if dim is None:
        zeros = (tens.cpu() == 0).sum()
        total = tens.numel()

        return zeros.float() / float(total)

    if isinstance(dim, int):
        dim = [dim]

    num_dims = len(tens.shape)
    if max(dim) >= num_dims or min(dim) < -num_dims:
        offending = max(dim) if max(dim) >= num_dims else min(dim)
        raise ValueError(
            "Unsupported dim given of {} in {} for tensor shape {}".format(
                offending, dim, tens.shape
            )
        )

    # map negative indices onto their positive equivalents; without this a
    # negative dim never matches below and every dimension gets summed away
    dim = [d % num_dims for d in dim]

    sum_dims = [ind for ind in range(num_dims) if ind not in dim]
    zeros = (tens == 0).sum(dim=sum_dims) if sum_dims else tens == 0
    # number of elements reduced into each output entry (1.0 when nothing
    # is reduced, since the product of an empty list is 1.0)
    total = numpy.prod([tens.shape[ind] for ind in sum_dims])

    # after the reduction the kept dimensions sit in ascending order, so the
    # axis to place at output position j is the rank of dim[j] among them
    kept_dims = sorted(dim)
    permute = [kept_dims.index(d) for d in dim]

    if permute != list(range(len(permute))):
        # need to permute to get desired dimensions at the front
        zeros = zeros.permute(*permute).contiguous()

    return zeros.float() / float(total)

tensors_module_forward(tensors, module, check_feat_lab_inp=True)

Default function for calling into a model with data for a forward execution. Returns the model result. Note: if tensors is an iterable, the features to be passed into the model are considered to be at index 0, and the other indices are for labels.

Supported use cases: single tensor, iterable with first tensor taken as the features to pass into the model

Parameters:

Name Type Description Default
tensors Union[Tensor, Iterable[Tensor], Mapping[Any, Tensor]]

the data to be passed into the model, if an iterable the features to be passed into the model are considered to be at index 0 and other indices are for labels

required
module Module

the module to pass the data into

required
check_feat_lab_inp bool

True to check whether the incoming tensors look like they are made up of features and labels, i.e. a tuple or list with 2 items (typical output from a data loader); if so, the model is called with just the first element, assuming it is the features. False to not check.

True

Returns:

Type Description
Any

the result of calling into the model for a forward pass

Source code in src/llmcompressor/pytorch/utils/helpers.py
def tensors_module_forward(
    tensors: Union[Tensor, Iterable[Tensor], Mapping[Any, Tensor]],
    module: Module,
    check_feat_lab_inp: bool = True,
) -> Any:
    """
    Run a forward pass of a model on the given data and return its result.

    How the data is handed to the model depends on its type: a single tensor
    is passed as the only positional argument, a mapping is expanded into
    keyword arguments, and any other iterable is expanded into positional
    arguments.

    :param tensors: the data to be passed into the model, if an iterable the features
        to be passed into the model are considered to be at index 0 and other indices
        are for labels
    :param module: the module to pass the data into
    :param check_feat_lab_inp: True to treat a tuple or list with exactly 2 items
        as (features, labels), the typical output from a data loader, and call
        the model with only the first item. False to not check
    :return: the result of calling into the model for a forward pass
    :raises ValueError: if tensors is not a tensor, mapping or iterable
    """
    looks_like_feat_lab_pair = (
        check_feat_lab_inp
        and isinstance(tensors, (tuple, list))
        and len(tensors) == 2
    )

    if looks_like_feat_lab_pair:
        # drop the labels and dispatch on the features alone; the check is
        # disabled so a 2-item features collection is not split a second time
        features = tensors[0]
        return tensors_module_forward(features, module, check_feat_lab_inp=False)

    if isinstance(tensors, Tensor):
        return module(tensors)

    if isinstance(tensors, Mapping):
        return module(**tensors)

    if not isinstance(tensors, Iterable):
        raise ValueError(
            "unrecognized type for data given of {}".format(
                tensors.__class__.__name__
            )
        )

    return module(*tensors)

tensors_to_device(tensors, device)

Default function for putting a tensor or collection of tensors to the proper device. Returns the tensor references after being placed on the proper device.

Supported use cases: - single tensor - Dictionary of single tensors - Dictionary of iterable of tensors - Dictionary of dictionary of tensors - Iterable of single tensors - Iterable of iterable of tensors - Iterable of dictionary of tensors

Parameters:

Name Type Description Default
tensors Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]

the tensors or collection of tensors to put onto a device

required
device str

the string representing the device to put the tensors on, ex: 'cpu', 'cuda', 'cuda:1'

required

Returns:

Type Description
Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]

the tensors or collection of tensors after being placed on the device

Source code in src/llmcompressor/pytorch/utils/helpers.py
def tensors_to_device(
    tensors: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]], device: str
) -> Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]:
    """
    Default function for putting a tensor or collection of tensors to the proper device.
    Returns the tensor references after being placed on the proper device.

    Supported use cases:
        - single tensor
        - Dictionary of single tensors
        - Dictionary of iterable of tensors
        - Dictionary of dictionary of tensors
        - Iterable of single tensors
        - Iterable of iterable of tensors
        - Iterable of dictionary of tensors

    An OrderedDict input yields an OrderedDict, any other mapping yields a
    plain dict, a tuple yields a tuple and any other iterable yields a list.

    :param tensors: the tensors or collection of tensors to put onto a device
    :param device: the string representing the device to put the tensors on,
        ex: 'cpu', 'cuda', 'cuda:1'
    :return: the tensors or collection of tensors after being placed on the device
    :raises ValueError: if tensors, or anything nested inside it, is not a
        tensor, mapping or iterable, or is a str / bytes value
    """
    if isinstance(tensors, Tensor):
        return tensors.to(device)

    if isinstance(tensors, (str, bytes)):
        # a str is an Iterable whose items are again str, so falling through
        # to the generic Iterable branch would recurse without end
        raise ValueError(
            "unrecognized type for tensors given of {}".format(
                tensors.__class__.__name__
            )
        )

    if isinstance(tensors, OrderedDict):
        return OrderedDict(
            [(key, tensors_to_device(tens, device)) for key, tens in tensors.items()]
        )

    if isinstance(tensors, Mapping):
        return {key: tensors_to_device(tens, device) for key, tens in tensors.items()}

    if isinstance(tensors, tuple):
        return tuple(tensors_to_device(tens, device) for tens in tensors)

    if isinstance(tensors, Iterable):
        return [tensors_to_device(tens, device) for tens in tensors]

    raise ValueError(
        "unrecognized type for tensors given of {}".format(tensors.__class__.__name__)
    )

tensors_to_precision(tensors, full_precision)

Parameters:

Name Type Description Default
tensors Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]

the tensors to change the precision of

required
full_precision bool

True for full precision (float 32) and False for half (float 16)

required

Returns:

Type Description
Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]

the tensors converted to the desired precision

Source code in src/llmcompressor/pytorch/utils/helpers.py
def tensors_to_precision(
    tensors: Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]], full_precision: bool
) -> Union[Tensor, Iterable[Tensor], Dict[Any, Tensor]]:
    """
    Convert a tensor, or every tensor in a nested collection, to full or half
    precision. A mapping input yields a dict, a tuple yields a tuple and any
    other iterable yields a list.

    :param tensors: the tensors to change the precision of
    :param full_precision: True for full precision (float 32) and
        False for half (float 16)
    :return: the tensors converted to the desired precision
    :raises ValueError: if tensors, or anything nested inside it, is not a
        tensor, mapping or iterable, or is a str / bytes value
    """
    if isinstance(tensors, Tensor):
        return tensors.float() if full_precision else tensors.half()

    if isinstance(tensors, (str, bytes)):
        # a str is an Iterable whose items are again str, so falling through
        # to the generic Iterable branch would recurse without end
        raise ValueError(
            "unrecognized type for tensors given of {}".format(
                tensors.__class__.__name__
            )
        )

    if isinstance(tensors, Mapping):
        return {
            key: tensors_to_precision(tens, full_precision)
            for key, tens in tensors.items()
        }

    if isinstance(tensors, tuple):
        return tuple(tensors_to_precision(tens, full_precision) for tens in tensors)

    if isinstance(tensors, Iterable):
        return [tensors_to_precision(tens, full_precision) for tens in tensors]

    raise ValueError(
        "unrecognized type for tensors given of {}".format(tensors.__class__.__name__)
    )