Skip to content

llmcompressor.pipelines.sequential.helpers

Subgraph dataclass

Dataclass specifying an executable subgraph of a model graph

Parameters:

Name Type Description Default
graph Graph

subgraph of model graph

required
input_names Set[str]

argument names of the compiled forward function

required
consumed_names Set[str]

argument names which are not used by any subsequent subgraphs and can therefore be deleted from the intermediates cache

required
Source code in src/llmcompressor/pipelines/sequential/helpers.py
@dataclass
class Subgraph:
    """
    Executable slice of a traced model graph

    :param graph: the portion of the model graph executed by this subgraph
    :param input_names: names of the arguments accepted by the compiled forward
        function
    :param consumed_names: argument names which no later subgraph reads, and which
        can therefore be removed from the intermediates cache after this subgraph runs
    """

    graph: Graph
    input_names: Set[str]
    consumed_names: Set[str]
    # generated python code for `graph`; populated lazily on the first `forward` call
    _code: Optional[PythonCode] = None

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """
        Run the operations contained in this subgraph

        Source for `graph` is generated and executed once, on first use; the resulting
        `forward` function is reused by every later call.

        :param \\*args: positional inputs to the generated forward function
        :param \\**kwargs: keyword inputs to the generated forward function
        :return: keyword outputs of the subgraph forward function (non-consumed
            variables)
        :raises RuntimeError: wrapping any exception raised by the generated code, with
            the line-numbered generated source included in the message
        """
        code = self._code
        if code is None:
            # cache before exec, so the generated code is stored even if exec fails
            code = self._code = self.graph.python_code("self")
            exec(code.src, code.globals)

        compiled_forward = code.globals.get("forward")
        try:
            return compiled_forward(*args, **kwargs)
        except Exception as exception:
            numbered_source = add_line_numbers(code.src)
            raise RuntimeError(
                "Raised an exception during execution of the following code:\n"
                f"```\n{numbered_source}\n```\n"
                "This is likely due to a violation of shape assumptions made when "
                "tracing"
            ) from exception

forward(*args, **kwargs)

Execute the operations within the subgraph

Parameters:

Name Type Description Default
*args

argument inputs to subgraph forward function

required
**kwargs

keyword inputs to subgraph forward function

required

Returns:

Type Description
Dict[str, Any]

keyword outputs of the subgraph forward function (non-consumed variables)
Source code in src/llmcompressor/pipelines/sequential/helpers.py
def forward(self, *args, **kwargs) -> Dict[str, Any]:
    """
    Run the operations contained in this subgraph

    Source for `self.graph` is generated and executed once, on first use; the
    resulting `forward` function is reused by every later call.

    :param \\*args: positional inputs to the generated forward function
    :param \\**kwargs: keyword inputs to the generated forward function
    :return: keyword outputs of the subgraph forward function (non-consumed
        variables)
    :raises RuntimeError: wrapping any exception raised by the generated code, with
        the line-numbered generated source included in the message
    """
    code = self._code
    if code is None:
        # cache before exec, so the generated code is stored even if exec fails
        code = self._code = self.graph.python_code("self")
        exec(code.src, code.globals)

    compiled_forward = code.globals.get("forward")
    try:
        return compiled_forward(*args, **kwargs)
    except Exception as exception:
        numbered_source = add_line_numbers(code.src)
        raise RuntimeError(
            "Raised an exception during execution of the following code:\n"
            f"```\n{numbered_source}\n```\n"
            "This is likely due to a violation of shape assumptions made when "
            "tracing"
        ) from exception

find_target_nodes(graph, targets)

Find all nodes whose execution is equivalent to executing the target modules. Note that these nodes are guaranteed to be treated as leaf nodes by SequentialTracer

Parameters:

Name Type Description Default
graph GraphModule

graph containing target nodes

required
targets Set[Module]

modules whose nodes are being searched for

required

Returns:

Type Description
Set[Node]

set of all nodes which call the target modules

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def find_target_nodes(graph: GraphModule, targets: Set[Module]) -> Set[Node]:
    """
    Collect the graph nodes which invoke one of the target modules.

    Only `call_module` nodes are considered; executing such a node is equivalent to
    executing the module it calls. These nodes are guaranteed to be treated as leaf
    nodes by SequentialTracer.

    :param graph: graph containing target nodes
    :param targets: modules whose nodes are being searched for
    :return: set of all nodes which call the target modules
    """
    matches: Set[Node] = set()
    for candidate in graph.graph.nodes:
        if candidate.op != "call_module":
            continue
        if graph.get_submodule(candidate.target) in targets:
            matches.add(candidate)
    return matches

get_sequential_ancestors(model, targets)

Find modules which are call graph ancestors of the given sequential targets

Parameters:

Name Type Description Default
model Module

model containing sequential targets

required
targets Set[Module]

sequential targets to find ancestors of

required

Returns:

Type Description
Set[Module]

call graph ancestors of sequential targets

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def get_sequential_ancestors(model: Module, targets: Set[Module]) -> Set[Module]:
    """
    Find modules which are call graph ancestors of the given sequential targets

    A module is an ancestor if one of its (transitive) children is a target. Targets
    themselves are not included unless they are also ancestors of another target.

    :param model: model containing sequential targets
    :param targets: sequential targets to find ancestors of
    :return: call graph ancestors of sequential targets
    """
    found: Set[Module] = set()

    def _visit(node: Module) -> bool:
        # True when `node` is a target or an already-recorded ancestor
        if node in targets or node in found:
            return True

        # visit every child (no short-circuiting) so that ancestors in later
        # branches are still recorded
        child_hits = [_visit(child) for child in node.children()]
        if not any(child_hits):
            return False

        found.add(node)
        return True

    _visit(model)
    return found

get_targets_from_modifiers(modifiers, model)

Infer sequential targets and ignore list from modifiers list

Parameters:

Name Type Description Default
model PreTrainedModel

model being calibrated

required
modifiers List[Modifier]

list of modifiers being applied during calibration

required

Returns:

Type Description
Tuple[List[str], List[str]]

list of sequential targets and list of modules to ignore for tracing

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def get_targets_from_modifiers(
    modifiers: List[Modifier], model: PreTrainedModel
) -> Tuple[List[str], List[str]]:
    """
    Infer the sequential targets and tracing ignore list from a list of modifiers

    The first modifier that is an instance of `SEQUENTIAL_MODIFIERS` is used. If its
    `sequential_targets` is None, targets fall back to `get_no_split_params(model)`;
    a single string is wrapped in a list.

    :param modifiers: list of modifiers being applied during calibration
    :param model: model being calibrated
    :return: list of sequential targets and list of modules to ignore for tracing
    :raises ValueError: if none of `modifiers` is a sequential modifier
    """
    # imported here to avoid a circular import
    from llmcompressor.pipelines.registry import SEQUENTIAL_MODIFIERS

    candidates = [m for m in modifiers if isinstance(m, SEQUENTIAL_MODIFIERS)]

    if not candidates:
        modifier_types = [type(m) for m in modifiers]
        raise ValueError(
            f"Cannot infer sequential targets from list of {modifier_types}"
        )
    if len(candidates) > 1:
        candidate_types = [type(m) for m in candidates]
        logger.warning(
            "Cannot infer sequential targets from multiple sequential modifiers "
            f"({candidate_types}). Defaulting to {candidate_types[0]}"
        )

    chosen = candidates[0]
    targets = chosen.sequential_targets
    if targets is None:
        targets = get_no_split_params(model)
    elif isinstance(targets, str):
        targets = [targets]

    return targets, chosen.ignore

get_tracer(model, sequential_targets, ignore)

Get a tracer specialized for the given model. The resulting tracer will not trace inside of sequential targets, nor any modules which are not call graph ancestors of sequential targets

Tracing within sequential targets is unnecessary, and tracing within offloaded modules may result in meta tensors being added to the model graph

Parameters:

Name Type Description Default
model Module

model being traced

required
sequential_targets Set[Module]

modules which are sequential targets

required
ignore Set[Module]

modules to ignore during tracing, in the future will specify functions and methods to skip during tracing

required
Source code in src/llmcompressor/pipelines/sequential/helpers.py
def get_tracer(
    model: Module, sequential_targets: Set[Module], ignore: Set[Module]
) -> HFTracer:
    """
    Get a tracer specialized for the given model. The resulting tracer will not trace
    inside of sequential targets, nor any modules which are not call graph ancestors of
    sequential targets

    Tracing within sequential targets is unnecessary, and tracing within offloaded
    modules may result in meta tensors being added to the model graph

    :param model: model being traced
    :param sequential_targets: modules which are sequential targets
    :param ignore: modules to ignore during tracing, in the future will specify
        functions and methods to skip during tracing
    :return: a `SequentialTracer` (HFTracer subclass) configured for `model`
    """
    sequential_ancestors = get_sequential_ancestors(model, sequential_targets)
    offloaded_modules = {m for m in model.modules() if has_offloaded_params(m)}

    # check unlikely case that ancestors have direct params which are offloaded
    offloaded_ancestors = offloaded_modules & sequential_ancestors
    if offloaded_ancestors:
        names = {module.__class__.__name__ for module in offloaded_ancestors}
        # note the trailing space after "targets," so the concatenated
        # message reads correctly
        logger.warning(
            "The following modules are call graph ancestors of sequential targets, "
            f"but also contain offloaded modules: {names}.\n"
            "These modules will not be traced, and any sequential target children will "
            "be executed jointly, which may lead to OOM errors"
        )

    class SequentialTracer(HFTracer):
        def create_arg(self, a: Any) -> Argument:
            # special extension allows models which depend on config values to be traced
            if isinstance(a, PretrainedConfig):
                kwargs = {k: self.create_arg(v) for k, v in a.to_dict().items()}
                return self.create_node("call_function", a.__class__, (), kwargs)

            else:
                return super().create_arg(a)

        def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool:
            # do not trace into non-ancestors, offloaded modules, or ignored modules
            return (
                module not in sequential_ancestors
                or module in offloaded_modules
                or module in ignore
            )

        def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph:
            if isinstance(root, Module):
                # due to a bug in Tracer.create_args_for_root (_patch_function),
                # we must unwrap function wrappers prior to tracing, for example
                # the `deprecate_kwarg` by transformers which wraps forward
                unwrapped_forward = inspect.unwrap(type(root).forward)

                # we override the class method because the
                # class method is the one being traced
                with patch_attr(type(root), "forward", unwrapped_forward):
                    return super().trace(root, *args, **kwargs)

            else:
                return super().trace(root, *args, **kwargs)

    return SequentialTracer()

graph_is_well_formed(graph)

A graph is well formed if and only if nodeA in nodeB.users <=> nodeB in nodeA.all_input_nodes

Parameters:

Name Type Description Default
graph Graph

graph being checked

required

Returns:

Type Description
bool

True if the graph is well formed, False otherwise

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def graph_is_well_formed(graph: Graph) -> bool:
    """
    Check that the producer/consumer links of every node in the graph agree

    A graph is well formed if and only if, for every pair of nodes,
    `nodeA in nodeB.users <=> nodeB in nodeA.all_input_nodes`, and no node lists the
    same user or the same input node more than once.

    :param graph: graph being checked
    :return: True if the graph is well formed, False otherwise
    """

    def _node_ok(node) -> bool:
        users = list(node.users)
        inputs = list(node.all_input_nodes)
        # every consumer must list this node among its inputs
        if any(node not in consumer.all_input_nodes for consumer in users):
            return False
        # every producer must list this node among its users
        if any(node not in producer.users for producer in inputs):
            return False
        # neither side may contain duplicates
        return len(users) == len(set(users)) and len(inputs) == len(set(inputs))

    return all(_node_ok(node) for node in graph.nodes)

match_modules(model, target_names)

Find modules whose names match the patterns given by target_names

Parameters:

Name Type Description Default
model Module

model containing submodules to find

required
target_names List[str]

target patterns to find

required

Returns:

Type Description
Set[Module]

all submodules matching target_names

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
    """
    Find modules whose names match the patterns given by `target_names`

    Every `(name, module)` pair from `model.named_modules()` is tested with
    `find_name_or_class_matches`; modules with a truthy result are returned.

    :param model: model containing submodules to find
    :param target_names: target patterns to find
    :return: all submodules matching `target_names`
    """
    matched: Set[Module] = set()
    for module_name, submodule in model.named_modules():
        if find_name_or_class_matches(module_name, submodule, target_names):
            matched.add(submodule)
    return matched

partition_graph(model, partitions)

Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping of output node names to their computed values. Note that the consumed_names attribute of each Subgraph remains empty, to be later populated by trace_consumed_names

Parameters:

Name Type Description Default
model Module

model which owns the produced Subgraphs

required
partitions List[List[Node]]

list of partitions, where each partition is a list of nodes belonging to that partition

required

Returns:

Type Description
List[Subgraph]

list of subgraphs in order of execution

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def partition_graph(model: Module, partitions: List[List[Node]]) -> List[Subgraph]:
    """
    Convert each partition into a Subgraph. Each Subgraph returns a dictionary mapping
    of output node names to their computed values. Note that the `consumed_names`
    attribute of each Subgraph remains empty, to be later populated by
    `trace_consumed_names`

    :param model: model which owns the produced Subgraphs
    :param partitions: list of partitions, where each partition is a list of nodes
        belonging to that partition
    :return: list of subgraphs in order of execution
    """
    subgraphs = []

    for partition_nodes in partitions:
        # create a new graph for the partition
        graph = Graph(model)
        node_map = {}
        # set view of the partition, for O(1) membership tests below
        partition_set = set(partition_nodes)

        # add placeholders for values produced outside this partition.
        # dict.fromkeys deduplicates while preserving first-use order, so the
        # placeholder (and therefore generated argument) order is deterministic
        external_inputs = dict.fromkeys(
            input_node
            for node in partition_nodes
            for input_node in node.all_input_nodes
            if input_node not in partition_set
        )
        for input_node in external_inputs:
            node_map[input_node] = graph.placeholder(input_node.name)

        # add the nodes to subgraph, remapping their arguments onto this graph
        for node in partition_nodes:
            node_map[node] = graph.node_copy(node, lambda n: node_map[n])

        # add an output node to collect all subgraph outputs into a dictionary:
        # every node with at least one user outside this partition is an output
        if not graph.find_nodes(op="output"):
            output_dict = {
                node.name: node_map[node]
                for node in partition_nodes
                if any(user not in partition_set for user in node.users)
            }
            graph.output(output_dict)

        # save the subgraph for this partition
        graph.lint()
        input_names = {node.name for node in graph.nodes if node.op == "placeholder"}
        subgraphs.append(
            Subgraph(
                graph=graph,
                input_names=input_names,
                consumed_names=set(),  # populated later
            )
        )

        assert graph_is_well_formed(graph)

    return subgraphs

populate_concrete_args(model, sample_input)

Creates concrete args. Unlike the equivalent function provided by transformers.utils.fx, this function also creates default values for variadic arguments, which some models require.

Parameters:

Name Type Description Default
model Module

model being traced

required
sample_input Dict

values used to symbolically trace the model. All arguments to the model.forward function which are not in the sample_input are considered concrete args

required

Returns:

Type Description
Dict

dictionary mapping concrete argument names to their default values

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def populate_concrete_args(model: Module, sample_input: Dict) -> Dict:
    """
    Creates concrete args. Unlike the equivalent function provided by
    transformers.utils.fx, this also creates default values for variadic arguments,
    which some models require.

    :param model: model being traced
    :param sample_input: values used to symbolically trace the model. All arguments
        to the model.forward function which are not in the sample_input are considered
        concrete args
    :return: dictionary mapping concrete argument names to their default values.
        `*args` maps to `[]`, `**kwargs` maps to `{}`, and `use_cache` is always
        `False`. A parameter without a default maps to `inspect.Parameter.empty`.
    """
    sig = inspect.signature(model.forward)

    concrete_args = {}
    for parameter in sig.parameters.values():
        # arguments present in the sample input are traced symbolically
        if parameter.name in sample_input:
            continue

        # compare against the public `inspect.Parameter` kind constants rather than
        # the private `inspect._ParameterKind` enum
        if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
            value = []
        elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
            value = {}
        elif parameter.name == "use_cache":
            value = False
        else:
            value = parameter.default

        concrete_args[parameter.name] = value

    return concrete_args

topological_partition(graph, targets)

Partition the graph into partitions such that each target belongs to exactly one partition and executing each partition depends only on intermediate values produced by executing the partitions before it.

Parameters:

Name Type Description Default
graph GraphModule

graph being partitioned

required
targets Set[Module]

target modules which will be assigned to disjoint partitions

required

Returns:

Type Description
List[List[Node]]

list of partitions, where each partition is a list of nodes belonging to that partition

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def topological_partition(graph: GraphModule, targets: Set[Module]) -> List[List[Node]]:
    """
    Partition the graph into partitions such that each `target` belongs to exactly one
    partition and executing each partition depends only on intermediate values produced
    by executing the partitions before it.

    :param graph: graph being partitioned
    :param targets: target modules which will be assigned to disjoint partitions
    :return: list of partitions, where each partition is a list of nodes belonging to
        that partition
    """
    assert graph_is_well_formed(graph.graph)
    target_nodes = find_target_nodes(graph, targets)

    partitions: List[List[Node]] = [[]]
    # count only non-`get_attr` inputs; `get_attr` nodes are placed separately below
    remaining_indegrees = {
        node: len(
            [producer for producer in node.all_input_nodes if producer.op != "get_attr"]
        )
        for node in graph.graph.nodes
    }
    partition_index = 0  # index of the partition currently being filled

    # start with graph input nodes,
    # but delay the `get_attr` nodes as long as possible
    queue = deque(
        node
        for node in graph.graph.nodes
        if remaining_indegrees[node] == 0 and node.op != "get_attr"
    )
    while len(queue) > 0:
        node = queue.popleft()

        # assign to partition
        partitions[partition_index].append(node)

        # guarantee targets are assigned to disjoint partitions
        if node in target_nodes:
            partition_index += 1
            partitions.append([])

        # recurse on last indegree only in order to guarantee that
        # the node is assigned to maximal partition
        for user in node.users:
            remaining_indegrees[user] -= 1
            if remaining_indegrees[user] == 0:
                queue.append(user)

    # an ideal implementation would involve implicitly consolidating partition indices
    # so that each node is assigned to the maximum partition possible (in order to delay
    # execution as long as possible), but saving these nodes for last covers the most
    # common and costly case (get_attr)
    node_to_partition = {
        scheduled: index
        for index, partition in enumerate(partitions)
        for scheduled in partition
    }
    for node in graph.graph.find_nodes(op="get_attr"):
        user_partitions = [
            node_to_partition[user]
            for user in node.users
            if user in node_to_partition
        ]
        # place the get_attr in the earliest partition that reads it; a get_attr
        # with no users would make `min` raise, so it defaults to the first partition
        partitions[min(user_partitions, default=0)].insert(0, node)

    assert set().union(*partitions) == set(graph.graph.nodes)
    return partitions

trace_consumed_names(subgraphs)

Populate the consumed_names attribute of each Subgraph according to when each input is last used, so that the intermediates cache can be vacated to save memory

Parameters:

Name Type Description Default
subgraphs List[Subgraph]

list of subgraphs with empty consumed_names attributes

required
Source code in src/llmcompressor/pipelines/sequential/helpers.py
def trace_consumed_names(subgraphs: List[Subgraph]):
    """
    Fill in each Subgraph's `consumed_names` with the inputs for which it is the last
    reader, so those entries can be dropped from the `intermediates` cache once it
    has run, saving memory

    :param subgraphs: subgraphs in execution order, with empty `consumed_names`
        attributes (mutated in place)
    """
    # walk backwards: the first subgraph encountered (i.e. the last one to execute)
    # that reads a given input name is the one that consumes it
    already_claimed = set()
    for subgraph in reversed(subgraphs):
        newly_consumed = [
            name for name in subgraph.input_names if name not in already_claimed
        ]
        subgraph.consumed_names.update(newly_consumed)
        already_claimed.update(newly_consumed)

trace_subgraphs(model, sample_input, sequential_targets, ignore)

Trace a model to produce subgraphs, where each sequential target belongs to exactly one subgraph and where executing each subgraph in order is equivalent to executing the original model

Parameters:

Name Type Description Default
model PreTrainedModel

model being traced

required
sample_input Dict[str, Any]

inputs whose values will change during execution but whose __len__, __bool__, and __contains__ values are assumed constant across batches

required
sequential_targets List[str]

list of patterns matching sequential targets

required
ignore List[str]

modules to ignore during tracing, in the future will specify functions and methods to skip during tracing

required

Returns:

Type Description
List[Subgraph]

a list of Subgraphs in order of execution

Source code in src/llmcompressor/pipelines/sequential/helpers.py
def trace_subgraphs(
    model: PreTrainedModel,
    sample_input: Dict[str, Any],
    sequential_targets: List[str],
    ignore: List[str],
) -> List[Subgraph]:
    """
    Trace a model to produce subgraphs, where each sequential target belongs to exactly
    one subgraph and where executing each subgraph in order is equivalent to executing
    the original model

    :param model: model being traced
    :param sample_input: inputs whose values will change during execution but whose
        __len__, __bool__, and __contains__ values are assumed constant across batches
    :param sequential_targets: list of patterns matching sequential targets
    :param ignore: modules to ignore during tracing, in the future will specify
        functions and methods to skip during tracing
    :return: a list of Subgraphs in order of execution
    """
    # resolve name patterns into concrete module sets
    target_modules = match_modules(model, sequential_targets)
    ignored_modules = match_modules(model, ignore)

    tracer = get_tracer(model, target_modules, ignored_modules)
    concrete_args = populate_concrete_args(model, sample_input)

    # trace with hooks disabled, inside the calibration forward context
    with calibration_forward_context(model), HooksMixin.disable_hooks():
        traced_graph = tracer.trace(
            model,
            dummy_inputs=sample_input,
            concrete_args=concrete_args,
            # bug in trace throws an error for variadic
            # args and kwargs in function signature
            complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
        )
        graph_module = GraphModule(model, traced_graph)

    # carry model metadata over to the traced module
    graph_module.config = model.config
    graph_module.class_for_deserialization = model.__class__
    graph_module.device = model.device

    # split into subgraphs and record when each input can be freed
    partitions = topological_partition(graph_module, target_modules)
    subgraphs = partition_graph(model, partitions)
    trace_consumed_names(subgraphs)
    return subgraphs