llmcompressor.transformers

Tools for integrating LLM Compressor with transformers training flows

SessionManagerMixIn

Mix-In class to extend the Hugging Face Trainer class to support LLM Compressor recipes for one-shot and finetuning flows.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recipe | str | path to recipe file to apply during training | required |
| recipe_args | Optional[Union[Dict[str, Any], str]] | additional kwargs to use for evaluating recipe | None |
| dataset_args | Optional[DatasetArguments] | kwargs for configuring dataset loading | None |
| teacher | Optional[Union[Module, str]] | optional teacher model to use for distillation | None |
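
A minimal usage sketch (the trainer subclass name and the constructor objects below are illustrative, not shipped by the library): the mix-in is combined with a Hugging Face `Trainer` so that its overrides of `train`, `compute_loss`, and `save_model` wrap the standard training loop.

```python
from transformers import Trainer as HFTrainer

from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn


class RecipeTrainer(SessionManagerMixIn, HFTrainer):
    # the mix-in must come first in the MRO so that its overrides
    # (train, compute_loss, save_model, ...) take precedence
    pass


# trainer = RecipeTrainer(
#     recipe="recipe.yaml",        # recipe applied during training
#     model_args=model_args,       # ModelArguments instance (assumed to exist)
#     dataset_args=dataset_args,   # DatasetArguments instance (assumed to exist)
#     teacher=teacher_model,       # optional distillation teacher
#     model=model,
#     args=training_args,          # transformers.TrainingArguments
#     train_dataset=train_dataset,
# )
# trainer.train()
```
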
Source code in src/llmcompressor/transformers/finetune/session_mixin.py
class SessionManagerMixIn:
    """
    Mix-In class to extend the Hugging Face Trainer class to support LLM Compressor
    recipes for one-shot and finetuning flows.

    :param recipe: path to recipe file to apply during training
    :param recipe_args: additional kwargs to use for evaluating recipe
    :param dataset_args: kwargs for configuring dataset loading
    :param teacher: optional teacher model to use for distillation
    """

    def __init__(
        self,
        recipe: str,
        model_args: "ModelArguments",
        dataset_args: Optional["DatasetArguments"] = None,
        teacher: Optional[Union[Module, str]] = None,
        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
        **kwargs,
    ):
        self.recipe = recipe
        self.recipe_args = recipe_args
        self.model_args = model_args
        self.teacher = teacher

        # parse training and metadata args
        training_args = kwargs.get("args")

        self.metadata = None
        if training_args is not None:
            # trl_sft_trainer pathway. Both training_args and dataset_args
            # have `max_seq_length`, which causes a collision error. This is the
            # only shared parameter; the training arg is a `TRLSFTConfig`, which
            # inherits from Hugging Face's `TrainingArguments`
            training_args_dict = training_args.to_dict()
            if "max_seq_length" in training_args_dict:
                training_args_dict["training_args_max_seq_length"] = (
                    training_args_dict.pop("max_seq_length")
                )
                logger.warning(
                    "Detected `max_seq_length` in both dataset_args "
                    "and training_args. This is expected for TRL in distillation. "
                    "Updating metadata to `training_args_max_seq_length`"
                )

            self.metadata = self._extract_metadata(
                metadata_args=METADATA_ARGS,
                training_args_dict=training_args_dict,
                dataset_args_dict=asdict(dataset_args) if dataset_args else {},
            )

        # setup metrics and session
        self.logger_manager = LoggerManager(log_python=False)
        create_session()

        # call Trainer initialization
        super().__init__(**kwargs)
        self.accelerator.wait_for_everyone()

        # setup callbacks and loss
        self.optim_callbacks = TrainingLoopCallbacks(self)
        self.callback_handler.add_callback(self.optim_callbacks)
        self.callback_disable_fp16 = DisableHalfPrecisionCallback(self)
        self.callback_handler.add_callback(self.callback_disable_fp16)
        self.criterion = torch.nn.CrossEntropyLoss()

        model_signature = inspect.signature(self.model.forward)
        self._signature_columns = list(model_signature.parameters.keys())

        if self.teacher is not None and teacher not in ("disable", "self"):
            teacher_signature = inspect.signature(self.teacher.forward)
            self._teacher_signature_columns = list(teacher_signature.parameters.keys())
        else:
            self._teacher_signature_columns = None

        if self.is_fsdp_enabled:
            self._prepare_model_for_fsdp()

        if dataset_args is not None:
            self.min_tokens_per_module = dataset_args.min_tokens_per_module

    def initialize_session(
        self,
        epoch: float,
        checkpoint: Optional[str] = None,
        stage: Optional[str] = None,
    ):
        """
        Initialize the CompressionSession from the specified epoch, evaluate the
        recipe, and initialize the modifiers for the training session

        :param epoch: Epoch to initialize session from, usually 0 unless loading
            from a checkpoint
        :param checkpoint: Optional checkpoint to initialize from to continue training
        :param stage: Optional stage of recipe to run, or None to run all stages
        """
        session = active_session()
        if session.lifecycle.initialized_ or session.lifecycle.finalized:
            return False

        train_data = self.get_train_dataloader()

        self.accelerator.wait_for_everyone()
        with summon_full_params_context(self.model, offload_to_cpu=True):
            active_session().initialize(
                recipe=self.recipe,
                recipe_stage=stage,
                recipe_args=self.recipe_args,
                model=self.model,
                teacher_model=self.teacher,  # TODO: what about for self/disable?
                train_data=train_data,
                start=epoch,
                copy_data=False,
                attach_optim_callbacks=True,
                fsdp_active=self.is_fsdp_enabled,
                metadata=self.metadata,
            )

        self.accelerator.wait_for_everyone()
        model = get_session_model()
        self.model_wrapped = self.model = model

        if self.recipe is None:
            logger.warning(
                "No training recipe was provided, finetuning will be run "
                "without event callbacks to LLM Compressor. To supply a recipe "
                "pass a yaml file or string to the `recipe` argument."
            )

        torch.cuda.empty_cache()

    def finalize_session(self):
        """
        Wrap up training by finalizing all modifiers initialized in the current session
        """
        session = active_session()
        if not session.lifecycle.initialized_ or session.lifecycle.finalized:
            return False

        with summon_full_params_context(self.model, offload_to_cpu=True):
            # in order to update each layer we need to gather all its parameters
            active_session().finalize()
        logger.info("Finalized LLM Compressor session")
        model = get_session_model()
        self.model = model
        torch.cuda.empty_cache()

    def create_optimizer(self):
        """
        Override the optimizer to apply and update the recipe while training.
        create_optimizer must exist in the parent class and should set
        self.optimizer to the optimizer state and optionally set self.scaler
        if using amp.
        """

        self._check_super_defined("create_optimizer")
        super().create_optimizer()

        # n_gpu handled internally by dataloader
        total_batch_size = (
            self.args.per_device_train_batch_size
            * self.args.gradient_accumulation_steps
        )

        if isinstance(self.train_dataset, IterableDataset):
            logger.warning(
                "Training is being run with a streamed dataset, "
                "steps_per_epoch cannot be determined and will default to "
                "1. LLM Compressor modifiers utilizing this statistic may not "
                "behave as expected. "
            )
            self.total_steps_per_epoch = 1
        else:
            self.total_steps_per_epoch = math.ceil(
                len(self.train_dataset) / total_batch_size
            )

        active_session().initialize(
            optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
        )

        return self.optimizer

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
        """
        Create an LR scheduler to work with the applied recipes. This is a placeholder
        that just calls the super method, but would be expanded upon if we ever
        implement a LearningRateModifier.

        :param num_training_steps: the total number of training steps
        :param optimizer: pre-initialized optimizer
        """

        # TODO: we don't currently have a LR scheduler in the new modifier framework
        self._check_super_defined("create_scheduler")
        return super().create_scheduler(
            num_training_steps=num_training_steps, optimizer=optimizer
        )

    def training_step(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Overrides the Trainer's training step to trigger the batch_start callback to
        the modifiers, then calls the parent function.

        :param model: the model to compute the loss for
        :param inputs: the inputs to pass through the model for calculating the loss
        :return: output of the model
        """
        self._check_super_defined("training_step")

        callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch)
        model_outputs = super().training_step(
            model=model, inputs=inputs, num_items_in_batch=num_items_in_batch
        )

        return model_outputs

    def compute_loss(
        self,
        model: Module,
        inputs: Dict[str, Any],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
        """
        Override of compute_loss to trigger callbacks and filter columns

        :param model: the model to compute the loss for
        :param inputs: the inputs to pass through the model for calculating the loss
        :param return_outputs: True to return the outputs with the loss,
            False otherwise
        :return: the resulting loss if not return_outputs, otherwise a tuple
            containing the loss and the model's outputs
        """
        self._check_super_defined("compute_loss")

        # TODO: do we need these model signature columns?
        inputs = {k: inputs[k] for k in inputs if k in self._signature_columns}
        loss = super().compute_loss(
            model=model,
            inputs=inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )

        # take the mean across multiple GPUs
        # this is done outside the compute_loss function in the parent, replicating it
        # here for LLM Compressor logging and distillation
        loss = loss.mean()

        # Log step-wise loss and perplexity, for llama-recipes comparison
        # we want this before distillation loss so perplexity isn't thrown off
        do_log = self.state.global_step % self.args.logging_steps == 0
        if do_log:
            log = {}
            log["step_loss"] = loss.item()
            log["perplexity"] = torch.exp(loss).item()

        if active_session().lifecycle.initialized_:
            state = callbacks.loss_calculated(loss=loss)
            if state and state.loss is not None:
                loss = state.loss
                if do_log:
                    log["distill_step_loss"] = loss.item() - log["step_loss"]
            callbacks.optim_pre_step()

        if do_log:
            self.log(log)

        return loss

    def train(self, *args, stage: Optional[str] = None, **kwargs):
        """
        Run a sparsification training cycle. Runs initialization for the sparse session
        before calling super().train() and finalization of the session after.

        Logs sparsification details for the trained model.

        :param args: positional args to pass to super().train()
        :param stage: Optional stage of recipe to run, or None to run all stages
        :param kwargs: keyword args to pass to super().train()
        :return: the output from super().train()
        """

        # lifecycle
        checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
        self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)

        # do not save checkpoints as compressed
        original_save_compressed = self.model_args.save_compressed
        self.model_args.save_compressed = False

        # train with accelerator
        self.accelerator.wait_for_everyone()
        output = super().train(*args, **kwargs)
        self.accelerator.wait_for_everyone()

        # restore original setting for saving final model
        self.model_args.save_compressed = original_save_compressed

        # lifecycle
        self.finalize_session()
        self.accelerator.wait_for_everyone()

        # log model sparsity
        self.maybe_log_model_sparsification()
        self.accelerator.wait_for_everyone()

        return output

    # TODO: support all save args, not just skip_sparsity_compression_stats
    def save_model(
        self,
        output_dir: str,
        _internal_call: bool = False,
        skip_sparsity_compression_stats: Optional[bool] = False,
    ):
        """
        Override of the save_model function; expects it to exist in the parent class.
        Calls into super() to save the model and additionally saves any recipes
        that were used with the model within the model folder.

        :param output_dir: the path to save the recipes into
        :param _internal_call: True if this is an internal call from
            the trainer in super(). Called from
            self.save_model(output_dir, _internal_call=True)
            in transformers/trainer/Trainer::_save_checkpoint

        """
        if active_session() is None:
            logger.warning(
                "No active session found, skipping saving of recipes and model."
            )
            return

        # knowledge distillation requires making wrappers transparent during saving
        if isinstance(self.model, KDModelWrapper):
            self.model.prepare_for_save()  # TODO: move to finalize

        # save checkpoint
        self.save_state()
        if self.accelerator.is_main_process:
            processor = getattr(self, "processing_class", self.tokenizer)
            # TODO: need to port over all saving parameters so that all
            # checkpoints are saved in the same way
            save_checkpoint(
                output_dir,
                model=self.model,
                processor=processor,
                save_safetensors=self.args.save_safetensors,
                save_compressed=self.model_args.save_compressed,
                skip_sparsity_compression_stats=skip_sparsity_compression_stats,
            )
        self.accelerator.wait_for_everyone()

        if isinstance(self.model, KDModelWrapper):
            self.model.finish_save()

    def maybe_log_model_sparsification(self):
        """
        Log info on model sparsity and quantization if possible. Only print logs on the
        main process, and avoid logging for quantized FSDP models
        """
        with summon_full_params_context(self.model, offload_to_cpu=True):
            # offload to avoid OOM errors
            if not self.accelerator.is_main_process:
                # only calculate stats on the rank 0 GPU
                return
            if self.is_fsdp_enabled and qat_active(self.model):
                # due to state dict changes we can't log sparsity info with quantized
                # models in FSDP
                return

            self.log_model_sparsification()

    def log_model_sparsification(self):
        """
        Log the current model sparsification info including pruned and quantized states
        """
        sparsification_info = ModuleSparsificationInfo(self.model)

        logger.info(
            f"Sparsification info for {type(self.model).__name__}: "
            f"{sparsification_info.params_total} total params. "
        )
        sparsity_percent_formatted = "{:.2f}".format(
            sparsification_info.params_sparse_percent
        )
        logger.info(
            f"There are {sparsification_info.params_total} prunable "
            f"params which have {sparsity_percent_formatted}% "
            "avg sparsity."
        )

        quant_percent_formatted = "{:.2f}".format(
            sparsification_info.params_quantized_percent
        )
        logger.info(
            f"There are {sparsification_info.params_total} quantizable "
            f"params, with a quantization percentage of "
            f"{quant_percent_formatted}%."
        )

    def _prepare_model_for_fsdp(self):
        """
        Sets up FSDP ahead of time so we can run one-shot in FSDP mode
        """
        self.model.to("cpu")
        self.model = self.accelerator.prepare(self.model)
        self.accelerator.wait_for_everyone()

        if self.teacher is not None:
            self.teacher.to("cpu")
            for n, p in self.teacher.named_parameters():
                p.requires_grad = False
            self.teacher = self.accelerator.prepare(self.teacher)
            self.teacher.eval()
            self.accelerator.wait_for_everyone()

    def _extract_metadata(
        self,
        metadata_args: List[str],
        training_args_dict: Dict[str, Any],
        dataset_args_dict: Dict[str, Any],
    ) -> Dict[str, Any]:
        metadata = {}
        if not training_args_dict.keys().isdisjoint(dataset_args_dict.keys()):
            raise ValueError(
                "Found common keys in `training_args` and `data args`. "
                "This is prohibitive and may lead to undesired behavior."
            )

        args_dict = {**training_args_dict, **dataset_args_dict}

        for arg in metadata_args:
            if arg not in args_dict.keys():
                logger.warning(
                    f"Required metadata argument {arg} was not found "
                    f"in the training arguments. Setting {arg} to None."
                )
                metadata[arg] = None
            else:
                metadata[arg] = args_dict[arg]

        return metadata

    def _check_super_defined(self, func: str):
        if not hasattr(super(), func):
            raise NotImplementedError(
                f"The super class for SessionManagerMixIn must define a {func} function"
            )

    def _calculate_checkpoint_info(self, kwargs) -> Tuple[Optional[str], float]:
        """
        If resuming from checkpoint is set, get checkpoint and epoch to resume from
        """
        checkpoint = None
        epoch = 0.0

        if not kwargs or "resume_from_checkpoint" not in kwargs:
            logger.warning(
                "resume_from_checkpoint not passed into LLM Compressor Trainer.train. "
                "This will cause issues with restoring recipes when "
                "running from a checkpoint."
            )
        elif kwargs["resume_from_checkpoint"]:
            if (
                isinstance(kwargs["resume_from_checkpoint"], bool)
                and kwargs["resume_from_checkpoint"]
            ):
                checkpoint = get_last_checkpoint(self.args.output_dir)
            else:
                checkpoint = kwargs["resume_from_checkpoint"]
            epoch = TrainerState.load_from_json(
                os.path.join(checkpoint, TRAINER_STATE_NAME)
            ).epoch

        return checkpoint, epoch

compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None)

Override of compute_loss to trigger callbacks and filter columns

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | the model to compute the loss for | required |
| inputs | Dict[str, Any] | the inputs to pass through the model for calculating the loss | required |
| return_outputs | bool | True to return the outputs with the loss, False otherwise | False |

Returns:

| Type | Description |
| --- | --- |
| Union[Tensor, Tuple[Tensor, Any]] | the resulting loss if not return_outputs, otherwise a tuple containing the loss and the model's outputs |
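
For reference, a small illustration (with assumed loss values) of how the logged metrics relate to one another:

```python
import torch

# step_loss is the mean cross-entropy loss for the step
step_loss = torch.tensor(2.0)
perplexity = torch.exp(step_loss)  # ~7.39, logged as "perplexity"

# if an active recipe (e.g. a distillation modifier) replaces the loss,
# the difference is logged as "distill_step_loss"
new_loss = torch.tensor(2.5)
distill_step_loss = (new_loss - step_loss).item()  # 0.5
```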

Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def compute_loss(
    self,
    model: Module,
    inputs: Dict[str, Any],
    return_outputs: bool = False,
    num_items_in_batch: Optional[int] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
    """
    Override of compute_loss to trigger callbacks and filter columns

    :param model: the model to compute the loss for
    :param inputs: the inputs to pass through the model for calculating the loss
    :param return_outputs: True to return the outputs with the loss,
        False otherwise
    :return: the resulting loss if not return_outputs, otherwise a tuple
        containing the loss and the model's outputs
    """
    self._check_super_defined("compute_loss")

    # TODO: do we need these model signature columns?
    inputs = {k: inputs[k] for k in inputs if k in self._signature_columns}
    loss = super().compute_loss(
        model=model,
        inputs=inputs,
        return_outputs=return_outputs,
        num_items_in_batch=num_items_in_batch,
    )

    # take the mean across multiple GPUs
    # this is done outside the compute_loss function in the parent, replicating it
    # here for LLM Compressor logging and distillation
    loss = loss.mean()

    # Log step-wise loss and perplexity, for llama-recipes comparison
    # we want this before distillation loss so perplexity isn't thrown off
    do_log = self.state.global_step % self.args.logging_steps == 0
    if do_log:
        log = {}
        log["step_loss"] = loss.item()
        log["perplexity"] = torch.exp(loss).item()

    if active_session().lifecycle.initialized_:
        state = callbacks.loss_calculated(loss=loss)
        if state and state.loss is not None:
            loss = state.loss
            if do_log:
                log["distill_step_loss"] = loss.item() - log["step_loss"]
        callbacks.optim_pre_step()

    if do_log:
        self.log(log)

    return loss

create_optimizer()

Override the optimizer to apply and update the recipe while training. create_optimizer must exist in the parent class and should set self.optimizer to the optimizer state and optionally set self.scaler if using amp.
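
A quick worked example (values assumed) of the `steps_per_epoch` statistic passed to the session for a map-style dataset:

```python
import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 8
dataset_len = 10_000  # len(train_dataset)

# n_gpu is handled internally by the dataloader, so it is not a factor here
total_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 32
steps_per_epoch = math.ceil(dataset_len / total_batch_size)  # 313
```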

Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def create_optimizer(self):
    """
    Override the optimizer to apply and update the recipe while training.
    create_optimizer must exist in the parent class and should set
    self.optimizer to the optimizer state and optionally set self.scaler
    if using amp.
    """

    self._check_super_defined("create_optimizer")
    super().create_optimizer()

    # n_gpu handled internally by dataloader
    total_batch_size = (
        self.args.per_device_train_batch_size
        * self.args.gradient_accumulation_steps
    )

    if isinstance(self.train_dataset, IterableDataset):
        logger.warning(
            "Training is being run with a streamed dataset, "
            "steps_per_epoch cannot be determined and will default to "
            "1. LLM Compressor modifiers utilizing this statistic may not "
            "behave as expected. "
        )
        self.total_steps_per_epoch = 1
    else:
        self.total_steps_per_epoch = math.ceil(
            len(self.train_dataset) / total_batch_size
        )

    active_session().initialize(
        optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
    )

    return self.optimizer

create_scheduler(num_training_steps, optimizer=None)

Create an LR scheduler to work with the applied recipes. This is a placeholder that just calls the super method, but would be expanded upon if we ever implement a LearningRateModifier.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_training_steps | int | the total number of training steps | required |
| optimizer | Optimizer | pre-initialized optimizer | None |
Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def create_scheduler(
    self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
    """
    Create an LR scheduler to work with the applied recipes. This is a placeholder
    that just calls the super method, but would be expanded upon if we ever
    implement a LearningRateModifier.

    :param num_training_steps: the total number of training steps
    :param optimizer: pre-initialized optimizer
    """

    # TODO: we don't currently have a LR scheduler in the new modifier framework
    self._check_super_defined("create_scheduler")
    return super().create_scheduler(
        num_training_steps=num_training_steps, optimizer=optimizer
    )

finalize_session()

Wrap up training by finalizing all modifiers initialized in the current session

Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def finalize_session(self):
    """
    Wrap up training by finalizing all modifiers initialized in the current session
    """
    session = active_session()
    if not session.lifecycle.initialized_ or session.lifecycle.finalized:
        return False

    with summon_full_params_context(self.model, offload_to_cpu=True):
        # in order to update each layer we need to gather all its parameters
        active_session().finalize()
    logger.info("Finalized LLM Compressor session")
    model = get_session_model()
    self.model = model
    torch.cuda.empty_cache()

initialize_session(epoch, checkpoint=None, stage=None)

Initialize the CompressionSession from the specified epoch, evaluate the recipe, and initialize the modifiers for the training session

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| epoch | float | Epoch to initialize session from, usually 0 unless loading from a checkpoint | required |
| checkpoint | Optional[str] | Optional checkpoint to initialize from to continue training | None |
| stage | Optional[str] | Optional stage of recipe to run, or None to run all stages | None |
Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def initialize_session(
    self,
    epoch: float,
    checkpoint: Optional[str] = None,
    stage: Optional[str] = None,
):
    """
    Initialize the CompressionSession from the specified epoch, evaluate the
    recipe, and initialize the modifiers for the training session

    :param epoch: Epoch to initialize session from, usually 0 unless loading
        from a checkpoint
    :param checkpoint: Optional checkpoint to initialize from to continue training
    :param stage: Optional stage of recipe to run, or None to run all stages
    """
    session = active_session()
    if session.lifecycle.initialized_ or session.lifecycle.finalized:
        return False

    train_data = self.get_train_dataloader()

    self.accelerator.wait_for_everyone()
    with summon_full_params_context(self.model, offload_to_cpu=True):
        active_session().initialize(
            recipe=self.recipe,
            recipe_stage=stage,
            recipe_args=self.recipe_args,
            model=self.model,
            teacher_model=self.teacher,  # TODO: what about for self/disable?
            train_data=train_data,
            start=epoch,
            copy_data=False,
            attach_optim_callbacks=True,
            fsdp_active=self.is_fsdp_enabled,
            metadata=self.metadata,
        )

    self.accelerator.wait_for_everyone()
    model = get_session_model()
    self.model_wrapped = self.model = model

    if self.recipe is None:
        logger.warning(
            "No training recipe was provided, finetuning will be run "
            "without event callbacks to LLM Compressor. To supply a recipe "
            "pass a yaml file or string to the `recipe` argument."
        )

    torch.cuda.empty_cache()

log_model_sparsification()

Log the current model sparsification info including pruned and quantized states

Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def log_model_sparsification(self):
    """
    Log the current model sparsification info including pruned and quantized states
    """
    sparsification_info = ModuleSparsificationInfo(self.model)

    logger.info(
        f"Sparsification info for {type(self.model).__name__}: "
        f"{sparsification_info.params_total} total params. "
    )
    sparsity_percent_formatted = "{:.2f}".format(
        sparsification_info.params_sparse_percent
    )
    logger.info(
        f"There are {sparsification_info.params_total} prunable "
        f"params which have {sparsity_percent_formatted}% "
        "avg sparsity."
    )

    quant_percent_formatted = "{:.2f}".format(
        sparsification_info.params_quantized_percent
    )
    logger.info(
        f"There are {sparsification_info.params_total} quantizable "
        f"params, with a quantization percentage of "
        f"{quant_percent_formatted}%."
    )

maybe_log_model_sparsification()

Log info on model sparsity and quantization if possible. Only print logs on the main process, and avoid logging for quantized FSDP models

Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def maybe_log_model_sparsification(self):
    """
    Log info on model sparsity and quantization if possible. Only print logs on the
    main process, and avoid logging for quantized FSDP models
    """
    with summon_full_params_context(self.model, offload_to_cpu=True):
        # offload to avoid OOM errors
        if not self.accelerator.is_main_process:
            # only calculate stats on the rank 0 GPU
            return
        if self.is_fsdp_enabled and qat_active(self.model):
            # due to state dict changes we can't log sparsity info with quantized
            # models in FSDP
            return

        self.log_model_sparsification()

save_model(output_dir, _internal_call=False, skip_sparsity_compression_stats=False)

Override of the save_model function; expects it to exist in the parent class. Calls into super() to save the model and additionally saves any recipes that were used with the model within the model folder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output_dir | str | the path to save the recipes into | required |
| _internal_call | bool | True if this is an internal call from the trainer in super(). Called from self.save_model(output_dir, _internal_call=True) in transformers/trainer/Trainer::_save_checkpoint | False |
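
An illustrative call (argument values assumed, with `trainer` an instance of a class using this mix-in, as in the sketch near the top of this page), saving the final model and its recipe without recomputing sparsity statistics:

```python
trainer.save_model(
    output_dir="./finetuned-model",
    skip_sparsity_compression_stats=True,
)
```
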
Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def save_model(
    self,
    output_dir: str,
    _internal_call: bool = False,
    skip_sparsity_compression_stats: Optional[bool] = False,
):
    """
    Override of the save_model function; expects it to exist in the parent class.
    Calls into super() to save the model and additionally saves any recipes
    that were used with the model within the model folder.

    :param output_dir: the path to save the recipes into
    :param _internal_call: True if this is an internal call from
        the trainer in super(). Called from
        self.save_model(output_dir, _internal_call=True)
        in transformers/trainer/Trainer::_save_checkpoint

    """
    if active_session() is None:
        logger.warning(
            "No active session found, skipping saving of recipes and model."
        )
        return

    # knowledge distillation requires making wrappers transparent during saving
    if isinstance(self.model, KDModelWrapper):
        self.model.prepare_for_save()  # TODO: move to finalize

    # save checkpoint
    self.save_state()
    if self.accelerator.is_main_process:
        processor = getattr(self, "processing_class", self.tokenizer)
        # TODO: need to port over all saving parameters so that all
        # checkpoints are saved in the same way
        save_checkpoint(
            output_dir,
            model=self.model,
            processor=processor,
            save_safetensors=self.args.save_safetensors,
            save_compressed=self.model_args.save_compressed,
            skip_sparsity_compression_stats=skip_sparsity_compression_stats,
        )
    self.accelerator.wait_for_everyone()

    if isinstance(self.model, KDModelWrapper):
        self.model.finish_save()

train(*args, stage=None, **kwargs)

Run a sparsification training cycle. Runs initialization for the sparse session before calling super().train() and finalization of the session after.

Logs sparsification details for the trained model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| args | | positional args to pass to super().train() | () |
| stage | Optional[str] | Optional stage of recipe to run, or None to run all stages | None |
| kwargs | | keyword args to pass to super().train() | {} |

Returns:

| Type | Description |
| --- | --- |
| | the output from super().train() |
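
An illustrative call (the stage name is hypothetical and must match a stage in the recipe), resuming from the latest checkpoint in `output_dir` and running a single recipe stage:

```python
output = trainer.train(
    stage="finetuning_stage",     # run only this stage of the recipe
    resume_from_checkpoint=True,  # pick up the latest checkpoint in output_dir
)
```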

Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def train(self, *args, stage: Optional[str] = None, **kwargs):
    """
    Run a sparsification training cycle. Runs initialization for the sparse session
    before calling super().train() and finalization of the session after.

    Logs sparsification details for the trained model.

    :param args: positional args to pass to super().train()
    :param stage: Optional stage of recipe to run, or None to run all stages
    :param kwargs: keyword args to pass to super().train()
    :return: the output from super().train()
    """

    # lifecycle
    checkpoint, epoch = self._calculate_checkpoint_info(kwargs)
    self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage)

    # do not save checkpoints as compressed
    original_save_compressed = self.model_args.save_compressed
    self.model_args.save_compressed = False

    # train with accelerator
    self.accelerator.wait_for_everyone()
    output = super().train(*args, **kwargs)
    self.accelerator.wait_for_everyone()

    # restore original setting for saving final model
    self.model_args.save_compressed = original_save_compressed

    # lifecycle
    self.finalize_session()
    self.accelerator.wait_for_everyone()

    # log model sparsity
    self.maybe_log_model_sparsification()
    self.accelerator.wait_for_everyone()

    return output

training_step(model, inputs, num_items_in_batch=None)

Overrides the Trainer's training step to trigger the batch_start callback to the modifiers, then calls the parent function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | the model to compute the loss for | required |
| inputs | Dict[str, Union[Tensor, Any]] | the inputs to pass through the model for calculating the loss | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | output of the model |

Source code in src/llmcompressor/transformers/finetune/session_mixin.py
def training_step(
    self,
    model: torch.nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
    """
    Overrides the Trainer's training step to trigger the batch_start callback to
    the modifiers, then calls the parent function.

    :param model: the model to compute the loss for
    :param inputs: the inputs to pass through the model for calculating the loss
    :return: output of the model
    """
    self._check_super_defined("training_step")

    callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch)
    model_outputs = super().training_step(
        model=model, inputs=inputs, num_items_in_batch=num_items_in_batch
    )

    return model_outputs

TextGenerationDataset

Bases: RegistryMixin

Base class for text datasets. Applies the following transformations to a dataset in order to prepare the dataset to be loaded by a dataloader

  1. Load dataset from huggingface or local cache
  2. Preprocess dataset according to preprocess function or chat/dataset template
  3. Tokenize dataset using model tokenizer/processor
  4. Apply post processing such as grouping text and/or adding labels for finetuning

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset_args | DatasetArguments | configuration settings for dataset loading | required |
| split | str | split from dataset to load, for instance test or train[:5%] | required |
| processor | Processor | processor or tokenizer to use on dataset | required |
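
A minimal subclass sketch (the class name and column name below are illustrative, not from the library), showing how a dataset template maps raw rows to tokenizer kwargs before the pipeline above tokenizes them and adds labels:

```python
from llmcompressor.transformers.finetune.data.base import TextGenerationDataset


class MyTextDataset(TextGenerationDataset):
    @property
    def dataset_template(self):
        # map a raw row to tokenizer kwargs; "text" is the standard key
        return lambda example: {"text": example["document"]}


# ds = MyTextDataset(dataset_args=dataset_args, split="train[:5%]", processor=tokenizer)
# tokenized = ds(add_labels=True)  # tokenized dataset ready for a dataloader
```
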
Source code in src/llmcompressor/transformers/finetune/data/base.py
class TextGenerationDataset(RegistryMixin):
    """
    Base class for text datasets. Applies the following transformations to a dataset
    in order to prepare the dataset to be loaded by a dataloader

    1. Load dataset from huggingface or local cache
    2. Preprocess dataset according to preprocess function or chat/dataset template
    3. Tokenize dataset using model tokenizer/processor
    4. Apply post processing such as grouping text and/or adding labels for finetuning

    :param dataset_args: configuration settings for dataset loading
    :param split: split from dataset to load, for instance `test` or `train[:5%]`
    :param processor: processor or tokenizer to use on dataset
    """

    # used to mask out the prompt so prompt tokens do not contribute to training loss
    PROMPT_KEY = "prompt"

    def __init__(
        self,
        dataset_args: DatasetArguments,
        split: str,
        processor: Processor,
    ):
        self.dataset_args = dataset_args
        self.split = split
        self.processor = processor

        # get tokenizer
        self.tokenizer = getattr(self.processor, "tokenizer", self.processor)

        if self.tokenizer is not None:
            # fill in pad token
            if not self.tokenizer.pad_token:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # configure sequence length
            max_seq_length = dataset_args.max_seq_length
            if dataset_args.max_seq_length > self.tokenizer.model_max_length:
                logger.warning(
                    f"The max_seq_length passed ({max_seq_length}) is larger than "
                    f"maximum length for model ({self.tokenizer.model_max_length}). "
                    f"Using max_seq_length={self.tokenizer.model_max_length}."
                )
            self.max_seq_length = min(
                dataset_args.max_seq_length, self.tokenizer.model_max_length
            )

            # configure padding
            self.padding = (
                False
                if self.dataset_args.concatenate_data
                else "max_length"
                if self.dataset_args.pad_to_max_length
                else False
            )

        else:
            self.max_seq_length = None
            self.padding = False

    def __call__(self, add_labels: bool = True) -> DatasetType:
        dataset = self.dataset_args.dataset

        if isinstance(dataset, str):
            # load dataset: load from huggingface or disk
            dataset = self.load_dataset()
        logger.debug(f"Raw dataset: {get_columns(dataset)}")

        if self.preprocess is not None:
            # preprocess: apply template or preprocessing function
            dataset = self.map(
                dataset,
                self.preprocess,
                batched=False,
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Preprocessing",
            )
            logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")

        # rename and remove columns match processor kwargs
        dataset = self.rename_columns(dataset)
        logger.debug(f"Dataset after column renaming: {get_columns(dataset)}")

        # use processor.model_input_names to determine if the ds is already tokenized
        model_input_names = getattr(self.processor, "model_input_names", ["input_ids"])
        if not any(col_name in model_input_names for col_name in get_columns(dataset)):
            # tokenize/ process
            dataset = self.filter_tokenizer_args(dataset)
            logger.debug(f"Tokenizer args after filtering: {get_columns(dataset)}")
            dataset = self.map(
                dataset,
                self.tokenize,
                batched=False,  # batching is not well supported for vision processors
                keep_in_memory=True,  # bug occurs when not batched and not in memory,
                # subsequent ds.map calls are always batched,
                # regardless of `batched` argument
                remove_columns=get_columns(dataset),  # assumes that input names
                # and output names are disjoint
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Tokenizing",
            )
            logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")

        if self.dataset_args.concatenate_data:
            # postprocess: group text
            dataset = self.map(
                dataset,
                self.group_text,
                batched=True,
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Concatenating data",
            )
            logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")

        if add_labels:
            # postprocess: add labels
            dataset = self.map(
                dataset,
                self.add_labels,
                batched=False,  # not compatible with batching, need row lengths
                num_proc=self.dataset_args.preprocessing_num_workers,
                load_from_cache_file=not self.dataset_args.overwrite_cache,
                desc="Adding labels",
            )
            logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")

        elif self.PROMPT_KEY in get_columns(dataset):
            dataset = dataset.remove_columns(self.PROMPT_KEY)
            logger.debug("Removed prompt key")

        logger.debug(f"Model kwargs after postprocessing: {get_columns(dataset)}")
        return dataset

    def load_dataset(self):
        """
        Load the raw dataset from Hugging Face, using cached copy if available

        :return: the requested dataset
        """
        if self.dataset_args.dataset_path is not None:
            if self.dataset_args.dvc_data_repository is not None:
                self.dataset_args.raw_kwargs["storage_options"] = {
                    "url": self.dataset_args.dvc_data_repository
                }
                self.dataset_args.raw_kwargs["data_files"] = (
                    self.dataset_args.dataset_path
                )
            else:
                self.dataset_args.raw_kwargs["data_files"] = (
                    get_custom_datasets_from_path(
                        self.dataset_args.dataset_path,
                        self.dataset_args.dataset
                        if hasattr(self.dataset_args, "dataset")
                        else self.dataset_args.dataset_name,
                    )
                )

        logger.debug(f"Loading dataset {self.dataset_args.dataset}")
        return get_raw_dataset(
            self.dataset_args,
            None,
            split=self.split,
            streaming=self.dataset_args.streaming,
            **self.dataset_args.raw_kwargs,
        )

    @cached_property
    def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
        """
        The function must return keys which correspond to processor/tokenizer kwargs,
        optionally including PROMPT_KEY
        """
        preprocessing_func = self.dataset_args.preprocessing_func

        if callable(preprocessing_func):
            return preprocessing_func

        if isinstance(preprocessing_func, str):
            if ":" in preprocessing_func:
                # load func_name from "/path/to/file.py:func_name"
                return import_from_path(preprocessing_func)
            else:
                # load from the registry
                return PreprocessingFunctionRegistry.get_value_from_registry(
                    name=preprocessing_func
                )

        return self.dataset_template

    @property
    def dataset_template(self) -> Union[Callable[[Any], Any], None]:
        return None

    def rename_columns(self, dataset: DatasetType) -> DatasetType:
        # rename columns to match processor/tokenizer kwargs
        column_names = get_columns(dataset)
        if self.dataset_args.text_column in column_names and "text" not in column_names:
            logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`")
            dataset = dataset.rename_column(self.dataset_args.text_column, "text")

        return dataset

    def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
        # assumes that inputs are not passed via self.processor.__call__ args and kwargs
        signature = inspect.signature(self.processor.__call__)
        tokenizer_args = set(
            key
            for key, param in signature.parameters.items()
            if param.kind not in (Kind.VAR_POSITIONAL, Kind.VAR_KEYWORD)
        )
        logger.debug(
            f"Found processor args `{tokenizer_args}`. Removing all other columns"
        )

        column_names = get_columns(dataset)
        return dataset.remove_columns(
            list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY]))
        )

    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
        # separate prompt
        prompt = data.pop(self.PROMPT_KEY, None)

        # tokenize
        data = self.processor(
            **data,
            padding=self.padding,
            max_length=self.max_seq_length,
            truncation=True,
        )

        # store unpadded prompt so we can mask out correct number of elements in labels
        if prompt is not None:
            data[self.PROMPT_KEY] = self.processor(
                text=prompt,
                max_length=self.max_seq_length,
                truncation=True,
            )["input_ids"]

        return data

    def group_text(self, data: LazyRow) -> Dict[str, Any]:
        concatenated_data = {k: sum(data[k], []) for k in data.keys()}
        total_length = len(concatenated_data[list(data.keys())[0]])
        total_length = (total_length // self.max_seq_length) * self.max_seq_length
        result = {
            k: [
                t[i : i + self.max_seq_length]
                for i in range(0, total_length, self.max_seq_length)
            ]
            for k, t in concatenated_data.items()
        }
        return result

    def add_labels(self, data: LazyRow) -> LazyRow:
        if "pixel_values" in data:
            raise NotImplementedError(
                "Label masking for vision datasets has not been implemented yet"
            )

        # if the dataset uses prompts, mask them out so they don't contribute
        # to the loss calculation
        prompt_len = 0
        if self.PROMPT_KEY in data:
            prompt_len = len(data[self.PROMPT_KEY])
        data["labels"] = data["input_ids"].copy()
        data["labels"][:prompt_len] = [LABELS_MASK_VALUE] * prompt_len

        # mask out padding in the labels as well
        padding = len(data["attention_mask"]) - sum(data["attention_mask"])
        if padding > 0:
            data["labels"][-padding:] = [LABELS_MASK_VALUE] * padding
        return data

    def map(
        self,
        dataset: Union[Dataset, IterableDataset],
        function: Callable[[Any], Any],
        **kwargs,
    ) -> Union[Dataset, IterableDataset]:
        """
        Wrapper function around Dataset.map and IterableDataset.map.

        If the dataset is streaming (in the case of IterableDataset), non-applicable
        arguments are ignored and the dataset features are resolved
        """
        if isinstance(dataset, IterableDataset):
            # remove arguments that don't apply to streaming
            kwargs.pop("num_proc", None)
            kwargs.pop("load_from_cache_file", None)
            kwargs.pop("desc", None)
            kwargs.pop("keep_in_memory", None)

        dataset = dataset.map(function, **kwargs)

        if isinstance(dataset, IterableDataset):
            dataset = dataset._resolve_features()

        return dataset

preprocess cached property

The function must return keys which correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY
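
The three accepted forms of `dataset_args.preprocessing_func`, matching the resolution order in this property (the function name, file path, and registry name below are illustrative):

```python
def to_text(example):
    return {"text": example["question"] + " " + example["answer"]}

preprocessing_func = to_text                        # 1. a callable, used directly
preprocessing_func = "/path/to/file.py:to_text"     # 2. "<file>:<func>", loaded via import_from_path
preprocessing_func = "my_registered_preprocessor"   # 3. a name looked up in PreprocessingFunctionRegistry
```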

load_dataset()

Load the raw dataset from Hugging Face, using cached copy if available

Returns:

| Type | Description |
| --- | --- |
| | the requested dataset |

Source code in src/llmcompressor/transformers/finetune/data/base.py
def load_dataset(self):
    """
    Load the raw dataset from Hugging Face, using cached copy if available

    :return: the requested dataset
    """
    if self.dataset_args.dataset_path is not None:
        if self.dataset_args.dvc_data_repository is not None:
            self.dataset_args.raw_kwargs["storage_options"] = {
                "url": self.dataset_args.dvc_data_repository
            }
            self.dataset_args.raw_kwargs["data_files"] = (
                self.dataset_args.dataset_path
            )
        else:
            self.dataset_args.raw_kwargs["data_files"] = (
                get_custom_datasets_from_path(
                    self.dataset_args.dataset_path,
                    self.dataset_args.dataset
                    if hasattr(self.dataset_args, "dataset")
                    else self.dataset_args.dataset_name,
                )
            )

    logger.debug(f"Loading dataset {self.dataset_args.dataset}")
    return get_raw_dataset(
        self.dataset_args,
        None,
        split=self.split,
        streaming=self.dataset_args.streaming,
        **self.dataset_args.raw_kwargs,
    )

map(dataset, function, **kwargs)

Wrapper function around Dataset.map and IterableDataset.map.

If the dataset is streaming (in the case of IterableDataset), non-applicable arguments are ignored and the dataset features are resolved

Source code in src/llmcompressor/transformers/finetune/data/base.py
def map(
    self,
    dataset: Union[Dataset, IterableDataset],
    function: Callable[[Any], Any],
    **kwargs,
) -> Union[Dataset, IterableDataset]:
    """
    Wrapper function around Dataset.map and IterableDataset.map.

    If the dataset is streaming (in the case of IterableDataset), non-applicable
    arguments are ignored and the dataset features are resolved
    """
    if isinstance(dataset, IterableDataset):
        # remove arguments that don't apply to streaming
        kwargs.pop("num_proc", None)
        kwargs.pop("load_from_cache_file", None)
        kwargs.pop("desc", None)
        kwargs.pop("keep_in_memory", None)

    dataset = dataset.map(function, **kwargs)

    if isinstance(dataset, IterableDataset):
        dataset = dataset._resolve_features()

    return dataset

is_model_ct_quantized_from_path(path)

Determine if model from path is quantized based on the config

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str | path to the model or HF stub | required |

Returns:

| Type | Description |
| --- | --- |
| bool | True if config contains quantization_config from the given path |

Source code in src/llmcompressor/transformers/utils/helpers.py
def is_model_ct_quantized_from_path(path: str) -> bool:
    """
    Determine if model from path is quantized based
    on the config

    :param path: path to the model or HF stub
    :return: True if config contains quantization_config from the given path

    """
    config = AutoConfig.from_pretrained(path)
    if config is not None:
        if (
            hasattr(config, "quantization_config")
            and config.quantization_config["quant_method"] == "compressed-tensors"
        ):
            return True
    return False
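
A hypothetical usage sketch (the model stub is illustrative): check whether a checkpoint already carries a compressed-tensors quantization config before applying a quantization recipe.

```python
from llmcompressor.transformers.utils.helpers import is_model_ct_quantized_from_path

stub = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16"  # assumed example stub

if is_model_ct_quantized_from_path(stub):
    print("Already quantized with compressed-tensors; skip quantization.")
else:
    print("Not quantized; safe to apply a quantization recipe.")
```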