CB Spyre Vision
Source: examples/offline_inference/cb_spyre_vision.py
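This example script runs multimodal offline inference through vLLM with continuous batching enabled, then cross-checks the generated text against a reference implementation running on CPU (Hugging Face Transformers or FMS).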
"""
This example shows how to run offline inference using continuous batching.
NOTE: At the moment, if you are checking parity, things may not line up
unless you compare eager against the FMS cpu model, i.e.,
$ python cb_spyre_vision.py --backend eager --compare-target fms
"""

import argparse
import os
import platform
import time

import torch
from fms.models import get_model
from fms.utils import serialization
from fms.utils.generation import generate as fms_generate
from transformers import AutoModelForVision2Seq, AutoProcessor

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="ibm-granite/granite-vision-3.3-2b")
parser.add_argument(
    "--max_model_len", "--max-model-len", type=int, default=8192
)  # a single image can consume a context of ~5k tokens
parser.add_argument("--max_num_seqs", "--max-num-seqs", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--num-prompts", "-n", type=int, default=1)
parser.add_argument(
    "--max-tokens",
    type=str,
    default="8",
    help="Comma separated list of max tokens to use for each prompt. "
    "This list is repeated until prompts are exhausted.",
)
parser.add_argument("--backend", type=str, default="sendnn", choices=["eager", "sendnn"])
parser.add_argument(
    "--compare-target",
    type=str,
    default="fms",
    choices=["transformers", "fms"],
    help="Target to compare results against on CPU.",
)
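
# Example invocation using only the flags defined above (values illustrative):
#   python cb_spyre_vision.py --backend eager --compare-target fms \
#       --num-prompts 4 --max-tokens 8,12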

def get_vllm_prompts(num_prompts):
    """Get the vLLM prompts to be processed."""
    template = "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n<image>\n{}\n<|assistant|>\n"  # noqa: E501
    images = [
        ImageAsset("cherry_blossom").pil_image,
        ImageAsset("stop_sign").pil_image,
    ]
    instructions = [
        "describe this image.",
        "what is shown in this image?",
        "what kind of flowers are these?",
    ]
    prompts = []
    for img in images:
        width, height = img.size
        for instr in instructions:
            # Shrink the images so that this example runs faster;
            # we are not using a toy model here, and big images
            # can take up tons of tokens
            new_width = int(0.1 * width)
            new_height = int(0.1 * height)
            prompts.append(
                {
                    "prompt": template.format(instr),
                    "multi_modal_data": {
                        "image": img.resize((new_width, new_height)),
                    },
                }
            )
    # Cycle the prompt list until we have at least num_prompts entries
    prompts = prompts * (num_prompts // len(prompts) + 1)
    return prompts[:num_prompts]
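
# Each request produced by get_vllm_prompts is a vLLM multimodal dict:
#   {"prompt": <formatted chat string with an <image> placeholder>,
#    "multi_modal_data": {"image": <downscaled PIL image>}}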

def compare_results(
    prompts: list[str], outputs_a: list[str], outputs_b: list[str], name_a: str, name_b: str
):
    """Compare outputs from differing engines/implementations,
    e.g., transformers & vLLM.
    """
    print(f"Comparing {name_a} results with {name_b}")
    print("===============")
    any_differ = False
    for idx, (result_a, result_b) in enumerate(zip(outputs_a, outputs_b)):
        if result_a != result_b:
            img_tok_idx = prompts[idx].index("<image>")
            gen_prompt_idx = prompts[idx].index("<|assistant|>")
            raw_prompt = prompts[idx][img_tok_idx:gen_prompt_idx].strip()
            any_differ = True
            print(f"Results for prompt {idx} differ!")
            print(f"\nPrompt (no system/gen prompt):\n {repr(raw_prompt)}")
            print(f"\n{name_a} generated text:\n {result_a}\n")
            print(f"\n{name_b} generated text:\n {result_b}\n")
            print("-----------------------------------")
    if not any_differ:
        print("\nAll results match!\n")

### Alternate implementations to compare against
def get_transformers_results(model_path, vllm_prompts):
    """Process the results for HF Transformers running on CPU."""
    model = AutoModelForVision2Seq.from_pretrained(model_path)
    return process_prompts(
        model_path,
        model,
        vllm_prompts,
        process_prompt_transformers,
    )

def process_prompt_transformers(model, max_tokens, inputs):
    """Process a single prompt using a transformers model."""
    return model.generate(**inputs, max_new_tokens=max_tokens)

def get_fms_results(model_path, vllm_prompts):
    """Process the results for FMS running on CPU."""
    # head_dim expansion required for granite vision
    serialization.extend_adapter("llava_next", "hf", ["weight_expansion_for_mismatched_head_dim"])
    config_dict = {"head_dim": 128}
    # Load, but don't compile (compare to CPU)
    model = get_model(
        "hf_pretrained",
        model_path,
        data_type=torch.bfloat16,  # Matches default in vLLM for this model
        fused_weights=False,
        override_hf_pretrained_config=True,
        text_config=config_dict,
    )
    return process_prompts(
        model_path,
        model,
        vllm_prompts,
        process_prompt_fms,
    )

def process_prompt_fms(model, max_tokens, inputs):
    """Process a single prompt using an FMS model."""
    input_ids = inputs.pop("input_ids")
    # May be better to use paged attn later on, but for now
    # we just use sdpa to avoid having to deal with padding
    # utils & position id management here
    inputs["attn_name"] = "sdpa_causal"
    return fms_generate(
        model,
        input_ids,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=False,  # Greedy decode
        extra_kwargs=inputs,
        prepare_model_inputs_hook=model.prepare_inputs_for_generation,
    )

def process_prompts(model_path, model, vllm_prompts, process_prompt):
    """Generic wrapper for running generate on either transformers or FMS."""
    processor = AutoProcessor.from_pretrained(model_path)
    num_prompts = len(vllm_prompts)
    generated_texts = []
    for i in range(num_prompts):
        # Prompts are preformatted, so don't worry about the chat template
        vllm_req = vllm_prompts[i]
        inputs = processor(
            text=vllm_req["prompt"],
            images=vllm_req["multi_modal_data"]["image"],
            return_tensors="pt",
        )
        # NOTE: Image tokens are expanded in the llava next preprocessor
        num_expanded_toks = inputs.input_ids.shape[1]
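        # NOTE: max_tokens is the module-level per-prompt list built under
        # __main__, so each CPU run gets the same token budget as vLLM.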
        target_output = process_prompt(
            model,
            max_tokens[i],
            inputs,
        )
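        # Both generate paths return the full sequence (prompt + generation),
        # so slice off the expanded prompt tokens before decoding.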
        out_toks = target_output[0][num_expanded_toks:]
        # Make sure not to include EOS, since vLLM
        # doesn't return them, but FMS might.
        generated_text = processor.decode(
            out_toks,
            skip_special_tokens=True,
        )
        generated_texts.append(generated_text)
    return generated_texts

if __name__ == "__main__":
    args = parser.parse_args()
    max_num_seqs = args.max_num_seqs  # defines the max batch size

    if platform.machine() == "arm64":
        print(
            "Detected arm64 running environment. "
            "Setting HF_HUB_OFFLINE=1; otherwise vLLM tries to download a "
            "different version of the model using the HF API, which might "
            "not work locally on arm64."
        )
        os.environ["HF_HUB_OFFLINE"] = "1"

    os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = args.backend
    os.environ["VLLM_SPYRE_USE_CB"] = "1"
    os.environ["VLLM_SPYRE_USE_CHUNKED_PREFILL"] = "1"

    prompts = get_vllm_prompts(args.num_prompts)

    # Set differing max_tokens so that the requests drop out of the batch at
    # different times
    max_tokens = [int(v) for v in args.max_tokens.split(",")]
    max_tokens = max_tokens * (args.num_prompts // len(max_tokens) + 1)
    max_tokens = max_tokens[: args.num_prompts]
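    # e.g., --max-tokens 8,12 with --num-prompts 3 yields [8, 12, 8]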

    sampling_params = [
        SamplingParams(max_tokens=m, temperature=0.0, ignore_eos=True) for m in max_tokens
    ]

    llm = LLM(
        model=args.model,
        tokenizer=args.model,
        max_model_len=args.max_model_len,
        max_num_seqs=max_num_seqs,
        tensor_parallel_size=args.tp,
    )

    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    print("=============== GENERATE")
    t0 = time.time()
    vllm_outputs = llm.generate(prompts, sampling_params)
    # Report wall-clock generation time
    print(f"generate() took {time.time() - t0:.2f}s")
    vllm_results = [x.outputs[0].text for x in vllm_outputs]  # raw texts
    raw_prompts = [prompt["prompt"] for prompt in prompts]

    compare_target_map = {
        "transformers": get_transformers_results,
        "fms": get_fms_results,
    }
    # Since we always compare the results here, we don't print the raw
    # vLLM results; the head_dim patch in FMS init tends to flood the
    # logs anyway.
    cpu_results = compare_target_map[args.compare_target](
        model_path=args.model,
        vllm_prompts=prompts,
    )
    compare_results(
        prompts=raw_prompts,
        outputs_a=cpu_results,
        outputs_b=vllm_results,
        name_a=f"{args.compare_target} [cpu]",
        name_b="vllm [spyre]",
    )