switching to high quality piper tts and added label translations

This commit is contained in:
Matthias Hinrichs
2026-01-29 23:48:19 +01:00
commit d80c619df9
3934 changed files with 1451600 additions and 0 deletions
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,98 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import logging
import os
import sys
from utils import (
chain_enc_dec_with_beamsearch,
export_summarization_edinit,
export_summarization_enc_dec_past,
onnx_inference,
)
# GLOBAL ENVS
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("generate")
def print_args(args):
for arg in vars(args):
logger.info(f"{arg}: {getattr(args, arg)}")
def user_command():
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument("--max_length", type=int, default=20, help="default to 20")
parent_parser.add_argument("--min_length", type=int, default=0, help="default to 0")
parent_parser.add_argument("-o", "--output", type=str, default="onnx_models", help="default name is onnx_models.")
parent_parser.add_argument("-i", "--input_text", type=str, default=None, help="input text")
parent_parser.add_argument("-s", "--spm_path", type=str, default=None, help="tokenizer model from sentencepice")
parent_parser.add_argument("-v", "--vocab_path", type=str, help="vocab dictionary")
parent_parser.add_argument("-b", "--num_beams", type=int, default=5, help="default to 5")
parent_parser.add_argument("--repetition_penalty", type=float, default=1.0, help="default to 1.0")
parent_parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
parent_parser.add_argument("--early_stopping", type=bool, default=False, help="default to False")
parent_parser.add_argument("--opset_version", type=int, default=14, help="minimum is 14")
parent_parser.add_argument("--no_encoder", action="store_true")
parent_parser.add_argument("--no_decoder", action="store_true")
parent_parser.add_argument("--no_chain", action="store_true")
parent_parser.add_argument("--no_inference", action="store_true")
required_args = parent_parser.add_argument_group("required input arguments")
required_args.add_argument(
"-m",
"--model_dir",
type=str,
required=True,
help="The directory contains input huggingface model. \
An official model like facebook/bart-base is also acceptable.",
)
print_args(parent_parser.parse_args())
return parent_parser.parse_args()
if __name__ == "__main__":
args = user_command()
if args.opset_version < 14:
raise ValueError(f"The minimum supported opset version is 14! The given one was {args.opset_version}.")
isExist = os.path.exists(args.output) # noqa: N816
if not isExist:
os.makedirs(args.output)
# beam search op only supports CPU for now
args.device = "cpu"
logger.info("ENV: CPU ...")
if not args.input_text:
args.input_text = (
"PG&E stated it scheduled the blackouts in response to forecasts for high winds "
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
if not args.no_encoder:
logger.info("========== EXPORTING ENCODER ==========")
export_summarization_edinit.export_encoder(args)
if not args.no_decoder:
logger.info("========== EXPORTING DECODER ==========")
export_summarization_enc_dec_past.export_decoder(args)
if not args.no_chain:
logger.info("========== CONVERTING MODELS ==========")
chain_enc_dec_with_beamsearch.convert_model(args)
if not args.no_inference:
logger.info("========== INFERENCING WITH ONNX MODEL ==========")
onnx_inference.run_inference(args)
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,329 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# This script evaluates accuracy of ONNX models for question-answering task on SQuAD data set.
# Example to evaluate raw and optimized model for CUDA in Linux:
# pip3 install datasets evaluate optimum transformers onnxruntime-gpu
#
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding
#
# python3 -m onnxruntime.transformers.optimizer \
# --input ./bert-large-uncased-whole-word-masking-finetuned-squad/model.onnx \
# --output ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
#
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding \
# --onnx ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
#
# Snippet of example output in A100:
# {'exact': 86.65089877010406, 'f1': 92.99433524952254, 'total': 10570, 'HasAns_exact': 86.65089877010406
# 'total_time_in_seconds': 81.69239814393222, 'samples_per_second': 129.387804008115,
# 'latency_in_seconds': 0.007728703703304846, 'provider': 'CUDAExecutionProvider',
# 'pretrained_model_name': 'bert-large-uncased-whole-word-masking-finetuned-squad',
# 'batch_size': 1, 'sequence_length': 384, 'use_io_binding': True}
import argparse
import csv
import os
import time
try:
from importlib.metadata import PackageNotFoundError, version
except ImportError:
from importlib_metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Any
from datasets import load_dataset
from evaluate import evaluator
from optimum.onnxruntime import ORTModelForQuestionAnswering
from optimum.version import __version__ as optimum_version
from packaging import version as version_check
from transformers import AutoTokenizer, pipeline
if version_check.parse(optimum_version) < version_check.parse("1.13.1"):
raise ImportError(f"Please install optimum>=1.13.1. Current version: {optimum_version}.")
PRETRAINED_SQUAD_MODELS = [
"bert-large-uncased-whole-word-masking-finetuned-squad",
"deepset/roberta-base-squad2",
"distilbert-base-cased-distilled-squad",
]
def get_package_version(package_name: str):
try:
return version(package_name)
except PackageNotFoundError:
return None
def load_onnx_model(
model_id: str, onnx_path: str | None = None, provider="CUDAExecutionProvider", use_io_binding: bool = False
):
"""Load onnx model given pretrained model name and optional ONNX model path. If onnx_path is None,
the default onnx model from optimum will be used.
Args:
model_id (str): pretrained model name or checkpoint path
onnx_path (Optional[str], optional): path of onnx model to evaluate. Defaults to None.
Returns:
model: ORTModel for the onnx model
onnx_path: the path of onnx model
"""
if onnx_path is None:
# Export onnx to a sub-directory named by the model id
model = ORTModelForQuestionAnswering.from_pretrained(
model_id, export=True, provider=provider, use_io_binding=use_io_binding
)
save_onnx_dir = os.path.join(".", model_id)
model.save_pretrained(save_onnx_dir)
onnx_path = os.path.join(save_onnx_dir, "model.onnx")
print("Model is exported to onnx file:", onnx_path)
else:
model = ORTModelForQuestionAnswering.from_pretrained(
os.path.dirname(onnx_path),
file_name=Path(onnx_path).name,
provider=provider,
use_io_binding=use_io_binding,
# provider_options={"enable_skip_layer_norm_strict_mode": True},
)
return model, onnx_path
def output_details(results: list[dict[str, Any]], csv_filename: str):
"""Output a CSV file with detail of each test results.
Args:
results (List[Dict[str, Any]]): list of JSON results.
csv_filename (str): path of output CSV file
"""
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
column_names = [
"pretrained_model_name",
"onnx_path",
"provider",
"disable_fused_attention",
"batch_size",
"sequence_length",
"use_io_binding",
"exact",
"f1",
"total",
"HasAns_exact",
"HasAns_f1",
"HasAns_total",
"best_exact",
"best_exact_thresh",
"best_f1",
"best_f1_thresh",
"total_time_in_seconds",
"samples_per_second",
"latency_in_seconds",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for result in results:
csv_writer.writerow(result)
csv_file.flush()
print(f"Detail results are saved to csv file: {csv_filename}")
def output_summary(results: list[dict[str, Any]], csv_filename: str, metric_name: str):
"""Output a CSV file with summary of a metric on combinations of batch_size and sequence_length.
Args:
results (List[Dict[str, Any]]): list of JSON results.
csv_filename (str): path of output CSV file
metric_name (str): the metric to summarize
"""
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
header_names = [
"pretrained_model_name",
"onnx_path",
"provider",
"disable_fused_attention",
"use_io_binding",
]
model_list = list({result["onnx_path"] for result in results})
model_list.sort()
batch_sizes = list({result["batch_size"] for result in results})
batch_sizes.sort()
sequence_lengths = list({result["sequence_length"] for result in results})
sequence_lengths.sort()
key_names = []
for sequence_length in sequence_lengths:
for batch_size in batch_sizes:
key_names.append(f"b{batch_size}_s{sequence_length}")
csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + key_names)
csv_writer.writeheader()
for model in model_list:
row = {}
# Metric value for given pair of batch_size and sequence_length.
# Assume that (onnx_path, batch_size and sequence_length) are unique so keep first occurrence only.
values = {}
values.update(dict.fromkeys(key_names, ""))
for result in results:
if result["onnx_path"] == model and result[metric_name]:
headers = {k: v for k, v in result.items() if k in header_names}
if not row:
row.update(headers)
batch_size = result["batch_size"]
sequence_length = result["sequence_length"]
key = f"b{batch_size}_s{sequence_length}"
if key in key_names:
values[key] = result[metric_name]
if row:
for key in key_names:
row[key] = values.get(key, "")
csv_writer.writerow(row)
csv_file.flush()
print(f"Summary results for {metric_name} are saved to csv file: {csv_filename}")
def main():
args = parse_arguments()
print(args)
for name in ["onnxruntime-gpu", "onnxruntime", "onnx", "torch", "transformers", "optimum", "datasets", "evaluate"]:
package_version = get_package_version(name)
if package_version:
print(f"{name} version", package_version)
pretrained_model_name = args.model_name
if args.onnx and not os.path.exists(args.onnx):
raise RuntimeError(f"Onnx model path does not exist: {args.onnx}")
disable_fused_attention = os.environ.get("ORT_DISABLE_FUSED_ATTENTION", "0") == "1"
all_results = []
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
for sequence_length in args.sequence_lengths:
tokenizer.model_max_length = sequence_length
tokenizer.doc_stride = min(sequence_length // 2, 128)
if args.onnx is None:
print("Exporting onnx model. It might take a few minutes...")
start_time = time.time()
ort_model, onnx_path = load_onnx_model(pretrained_model_name, args.onnx, args.provider, args.use_io_binding)
latency = time.time() - start_time
print(f"Onnx model exported or loaded in {latency:.1f} seconds")
print(ort_model.config)
if sequence_length > ort_model.config.max_position_embeddings:
raise RuntimeError("sequence length should not be larger than {ort_model.config.max_position_embeddings}")
qa_pipeline = pipeline(
"question-answering", model=ort_model, tokenizer=tokenizer, question_first=True, batch_size=args.batch_size
)
task_evaluator = evaluator("question-answering")
print("Loading dataset...")
start_time = time.time()
squad_dataset = load_dataset("squad", split=f"validation[:{args.total}]" if args.total > 0 else "validation")
latency = time.time() - start_time
print(f"Dataset loaded in {latency:.1f} seconds")
print("Evaluating squad_v2 with ORT. It might take a few minutes...")
start_time = time.time()
result = task_evaluator.compute(
model_or_pipeline=qa_pipeline,
data=squad_dataset,
metric="squad_v2",
squad_v2_format=True,
)
latency = time.time() - start_time
print(f"Evaluation done in {latency:.1f} seconds")
result["provider"] = args.provider
result["disable_fused_attention"] = disable_fused_attention
result["pretrained_model_name"] = pretrained_model_name
result["onnx_path"] = onnx_path
result["batch_size"] = args.batch_size
result["sequence_length"] = sequence_length
result["use_io_binding"] = args.use_io_binding
print(result)
all_results.append(result)
output_details(all_results, "detail.csv")
for metric_name in ["f1", "exact", "samples_per_second"]:
output_summary(all_results, f"{metric_name}.csv", metric_name)
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=False,
type=str,
default=PRETRAINED_SQUAD_MODELS[0],
help=f"Checkpoint directory or pre-trained model names in the list: {PRETRAINED_SQUAD_MODELS}",
)
parser.add_argument(
"-s",
"--sequence_lengths",
nargs="+",
type=int,
default=[384],
help="Sequence lengths for onnx model inputs. It could have multiple values.",
)
parser.add_argument(
"-b",
"--batch_size",
type=int,
default=1,
help="batch size for inference.",
)
parser.add_argument("-t", "--total", type=int, default=0, help="Total samples to test. 0 means all samples.")
parser.add_argument(
"--onnx",
required=False,
type=str,
default=None,
help="Optional onnx model path. If not specified, optimum will be used to export onnx model for testing.",
)
parser.add_argument(
"--provider",
required=False,
default="CUDAExecutionProvider",
help="Select which Execution Provider to use for runs. Default is CUDAExecutionProvider.",
)
parser.add_argument("--use_io_binding", required=False, action="store_true", help="Use IO Binding for GPU.")
parser.set_defaults(use_io_binding=False)
args = parser.parse_args(argv)
return args
if __name__ == "__main__":
main()
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,413 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script benchmarks gpt2 model with past state.
# For gpt2 model without past state, use benchmark.py to measure performance.
import argparse
import csv
import logging
import os
from datetime import datetime
import psutil
import torch
from benchmark_helper import (
Precision,
create_onnxruntime_session,
get_ort_environment_variables,
prepare_environment,
setup_logger,
)
from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
from packaging import version
from quantize_helper import QuantizeHelper
from transformers import AutoConfig
from transformers import __version__ as transformers_version
logger = logging.getLogger("")
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
type=str,
help="Model path, or pretrained model name selected in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
)
parser.add_argument(
"--model_class",
required=False,
type=str,
default="GPT2LMHeadModel",
choices=list(MODEL_CLASSES.keys()),
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--onnx_dir",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Directory to store onnx models",
)
parser.add_argument(
"--test_times",
required=False,
default=100,
type=int,
help="Number of repeat times to get average inference latency.",
)
parser.add_argument(
"-v",
"--validate_onnx",
required=False,
action="store_true",
help="Validate ONNX model",
)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument(
"--stage",
type=int,
default=0,
required=False,
choices=[0, 1, 2],
help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
"1 - decode the first token when past_sequence_length is zero; "
"2 - decode the remaining tokens when past_sequence_length is not zero; "
"0 - one onnx model for both stages 1 and 2. "
"Note that we will optimize 1 and 2 differently for best performance.",
)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"-p",
"--precision",
type=Precision,
default=Precision.FLOAT32,
choices=list(Precision),
help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization",
)
parser.add_argument("--torchscript", required=False, action="store_true", help="use Torchscript")
parser.set_defaults(torchscript=False)
parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1], help="batch size")
parser.add_argument(
"--sequence_lengths",
nargs="+",
type=int,
default=[1],
help="sequence lengths (excluding past)",
)
parser.add_argument(
"-s",
"--past_sequence_lengths",
nargs="+",
type=int,
default=[8, 16, 32, 64, 128, 256],
help="past sequence lengths",
)
parser.add_argument(
"-r",
"--result_csv",
required=False,
default=None,
help="CSV file for saving summary results.",
)
parser.add_argument("--thread_num", required=False, type=int, default=-1, help="Threads to use")
parser.add_argument("--include_copy_output_latency", required=False, action="store_true")
parser.set_defaults(include_copy_output_latency=False)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument("--output_torch_latency", required=False, action="store_true")
parser.set_defaults(output_torch_latency=False)
parser.add_argument("--disable_io_binding", required=False, action="store_true")
parser.set_defaults(disable_io_binding=False)
args = parser.parse_args(argv)
return args
def main(args):
if version.parse(transformers_version) < version.parse(
"3.1.0"
): # past_key_values name does not exist in 3.0.2 or older
raise RuntimeError("This tool requires transformers 3.1.0 or later.")
logger.info(f"Arguments:{args}")
if args.precision == Precision.FLOAT16:
assert args.optimize_onnx and args.use_gpu, "fp16 requires --optimize_onnx --use_gpu"
if args.precision == Precision.INT8:
assert not args.use_gpu, "quantization only supports CPU"
if args.stage == 1:
assert args.past_sequence_lengths == [0], "past_sequence_lengths shall be 0 for stage==1 (init decoder)"
torch.set_num_threads(psutil.cpu_count(logical=True) if args.thread_num <= 0 else args.thread_num)
print(torch.__config__.parallel_info())
cache_dir = args.cache_dir
output_dir = args.onnx_dir
prepare_environment(cache_dir, output_dir, args.use_gpu)
model_class = MODEL_CLASSES[args.model_class][0]
gpt2helper = Gpt2Helper
config = AutoConfig.from_pretrained(args.model_name_or_path, torchscript=args.torchscript, cache_dir=cache_dir)
model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
# This script does not support float16 for PyTorch.
# if args.float16:
# model.half()
device = torch.device("cuda:0" if args.use_gpu else "cpu")
model.to(device)
use_external_data_format = config.n_layer > 24 # TODO: find a way to check model size > 2GB
onnx_model_paths = gpt2helper.get_onnx_paths(
output_dir,
args.model_name_or_path,
args.model_class,
has_past=True,
new_folder=use_external_data_format,
)
onnx_model_path = onnx_model_paths["raw"]
use_padding = MODEL_CLASSES[args.model_class][2]
gpt2helper.export_onnx(
model,
device,
onnx_model_path,
args.verbose,
use_external_data_format,
has_position_ids=use_padding,
has_attention_mask=use_padding,
)
if args.optimize_onnx or args.precision != Precision.FLOAT32:
onnx_model_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
gpt2helper.optimize_onnx(
onnx_model_paths["raw"],
onnx_model_path,
args.precision == Precision.FLOAT16,
model.config.num_attention_heads,
model.config.hidden_size,
use_external_data_format,
auto_mixed_precision=True,
stage=args.stage,
)
if args.precision == Precision.INT8:
logger.info("quantizing model...")
QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_paths["int8"], use_external_data_format)
model = QuantizeHelper.quantize_torch_model(model)
logger.info("finished quantizing model")
onnx_model_path = onnx_model_paths["int8"]
if args.torchscript:
model = gpt2helper.torchscript(
model,
config,
device,
has_position_ids=use_padding,
has_attention_mask=use_padding,
)
session = create_onnxruntime_session(
onnx_model_path,
args.use_gpu,
enable_all_optimization=False,
num_threads=args.thread_num,
verbose=args.verbose,
)
if session is None:
return
# Allocate output buffers for IO Binding
max_output_shapes = gpt2helper.get_output_shapes(
max(args.batch_sizes),
max(args.past_sequence_lengths),
max(args.sequence_lengths),
config,
args.model_class,
)
output_buffers = gpt2helper.get_output_buffers(max_output_shapes, device, args.precision == Precision.FLOAT16)
csv_filename = args.result_csv or "benchmark_result_{}.csv".format(datetime.now().strftime("%Y%m%d-%H%M%S"))
with open(csv_filename, mode="a", newline="") as csv_file:
column_names = [
"model_name",
"model_class",
"stage",
"environment_variables",
"gpu",
"precision",
"optimizer",
"torchscript",
"batch_size",
"sequence_length",
"past_sequence_length",
"disable_io_binding",
"torch_latency",
"onnxruntime_latency",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for batch_size in args.batch_sizes:
for sequence_length in args.sequence_lengths:
for past_sequence_length in args.past_sequence_lengths:
assert batch_size > 0 and sequence_length > 0 and past_sequence_length >= 0
logger.debug(
"Running test for batch_size=%d sequence_length=%d past_sequence_length=%d ...",
batch_size,
sequence_length,
past_sequence_length,
)
dummy_inputs = gpt2helper.get_dummy_inputs(
batch_size,
past_sequence_length,
sequence_length,
config.num_attention_heads,
config.hidden_size,
config.n_layer,
config.vocab_size,
device,
float16=(args.precision == Precision.FLOAT16),
has_position_ids=use_padding,
has_attention_mask=use_padding,
)
output_shapes = gpt2helper.get_output_shapes(
batch_size,
past_sequence_length,
sequence_length,
config,
args.model_class,
)
try:
if args.validate_onnx or args.output_torch_latency:
outputs, torch_latency = gpt2helper.pytorch_inference(model, dummy_inputs, args.test_times)
# Dump Torch output shape
for i, value in enumerate(outputs):
if isinstance(value, tuple):
logger.debug(
f"torch output {i} is tuple of size {len(value)}, shape {value[0].shape}"
)
else:
logger.debug(f"torch output {i} shape {value.shape}")
else:
outputs = None
torch_latency = None
if args.disable_io_binding:
ort_outputs, ort_latency = gpt2helper.onnxruntime_inference(
session, dummy_inputs, args.test_times
)
else:
ort_outputs, ort_latency = gpt2helper.onnxruntime_inference_with_binded_io(
session,
dummy_inputs,
output_buffers,
output_shapes,
args.test_times,
return_numpy=False,
include_copy_output_latency=args.include_copy_output_latency,
)
if args.validate_onnx:
copy_outputs = ort_outputs
if not args.disable_io_binding:
# Results of IO binding might be in GPU. Copy outputs to CPU for comparison.
copy_outputs = []
for output in ort_outputs:
copy_outputs.append(output.cpu().numpy())
if gpt2helper.compare_outputs(
outputs,
copy_outputs,
model_class=args.model_class,
rtol=DEFAULT_TOLERANCE[args.precision],
atol=DEFAULT_TOLERANCE[args.precision],
):
logger.info(
f"Pytorch and ONNX Runtime outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]})."
)
logger.info(
"batch_size=%d, sequence_length=%d, past_sequence_length=%d, onnxruntime_latency=%.2f %s %s",
batch_size,
sequence_length,
past_sequence_length,
ort_latency,
"(disable_io_binding)" if args.disable_io_binding else "",
", torch_latency={torch_latency}" if torch_latency else "",
)
row = {
"model_name": args.model_name_or_path,
"model_class": args.model_class,
"stage": args.stage,
"environment_variables": get_ort_environment_variables(),
"gpu": args.use_gpu,
"precision": args.precision,
"optimizer": args.optimize_onnx,
"torchscript": args.torchscript,
"batch_size": batch_size,
"sequence_length": sequence_length,
"past_sequence_length": past_sequence_length,
"disable_io_binding": args.disable_io_binding,
"torch_latency": f"{torch_latency:.2f}" if torch_latency else "None",
"onnxruntime_latency": f"{ort_latency:.2f}",
}
csv_writer.writerow(row)
except Exception:
logger.error("Exception", exc_info=True) # noqa: G201
return None
logger.info(f"Results are saved to file {csv_filename}")
return csv_filename
if __name__ == "__main__":
args = parse_arguments()
setup_logger(args.verbose)
main(args)
@@ -0,0 +1,558 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
"""
This converts GPT2 model to onnx. Examples:
(1) Convert pretrained model 'gpt2' to ONNX
python convert_to_onnx.py -m gpt2 --output gpt2.onnx
(2) Convert pretrained model 'distilgpt2' to ONNX, and use optimizer to get float16 model.
python convert_to_onnx.py -m distilgpt2 --output distilgpt2_fp16.onnx -o -p fp16
(3) Convert a model check point to ONNX, and run optimization and int8 quantization
python convert_to_onnx.py -m ./my_model_checkpoint/ --output my_model_int8.onnx -o -p int8
"""
import argparse
import csv
import json
import logging
import os
import shutil
import sys
from pathlib import Path
import numpy
import torch
from benchmark_helper import (
Precision,
create_onnxruntime_session,
get_ort_environment_variables,
prepare_environment,
setup_logger,
)
from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
from gpt2_tester import Gpt2Tester
from packaging import version
from quantize_helper import QuantizeHelper
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from onnxruntime import __version__ as ort_version
logger = logging.getLogger("")
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
)
parser.add_argument(
"--model_class",
required=False,
type=str,
default="GPT2LMHeadModel",
choices=list(MODEL_CLASSES.keys()),
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--output",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Output directory, or model path ends with .onnx",
)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"--provider",
required=False,
default=None,
choices=["dml", "rocm", "migraphx", "cuda", "tensorrt"],
help="use dml, rocm, cuda, tensorrt or migraphx for respective backend",
)
parser.add_argument(
"--tolerance",
required=False,
type=float,
default=0,
help="the absolute and relative tolerance for parity verification",
)
parser.add_argument(
"--input_test_file",
"-i",
required=False,
type=str,
default="",
help="Path to the file with inputs to test with",
)
parser.add_argument(
"-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=list(Precision),
help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision, and int8 for quantization",
)
parser.add_argument(
"-t",
"--test_cases",
required=False,
type=int,
default=1000,
help="Number of test cases per run for parity",
)
parser.add_argument(
"-r",
"--test_runs",
required=False,
type=int,
default=10,
help="Number of runs for parity. It is used for significance test.",
)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)
parser.add_argument("--overwrite", required=False, action="store_true")
parser.set_defaults(overwrite=False)
parser.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int32 instead of int64 for input_ids, position_ids and attention_mask.",
)
parser.set_defaults(use_int64_inputs=False)
parser.add_argument(
"-s",
"--stage",
type=int,
default=0,
required=False,
choices=[0, 1, 2],
help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
"1 - decode the first token when past_sequence_length is zero; "
"2 - decode the remaining tokens when past_sequence_length is not zero; "
"0 - one onnx model for both stages 1 and 2. "
"Note that we will optimize 1 and 2 differently for best performance.",
)
fp16_option_group = parser.add_argument_group(
'float to float16 conversion parameters that works when "--precision fp16" is specified'
)
fp16_option_group.add_argument(
"-a",
"--auto_mixed_precision",
required=False,
action="store_true",
help="Convert to mixed precision automatically. Other float16 conversion parameters will be ignored.",
)
fp16_option_group.set_defaults(auto_mixed_precision=False)
fp16_option_group.add_argument(
"--keep_io_types",
required=False,
action="store_true",
help="Use float32 for past inputs, present and logits outputs.",
)
fp16_option_group.set_defaults(keep_io_types=False)
fp16_option_group.add_argument(
"--io_block_list",
nargs="+",
default=[],
help="List of inputs or outputs in float32 instead of float16",
)
fp16_option_group.add_argument(
"--op_block_list",
nargs="+",
default=[],
help="List of operators (like Add LayerNormalization SkipLayerNormalization EmbedLayerNormalization FastGelu) "
"to compute in float32 instead of float16.",
)
fp16_option_group.add_argument(
"--node_block_list",
nargs="+",
default=[],
help="List of node names to compute in float32 instead of float16.",
)
fp16_option_group.add_argument(
"--force_fp16_initializers",
required=False,
action="store_true",
help="Convert all float initializers to float16.",
)
fp16_option_group.set_defaults(force_fp16_initializers=False)
args = parser.parse_args(argv)
return args
def get_onnx_model_size(onnx_path: str, use_external_data_format: bool):
if not use_external_data_format:
return os.path.getsize(onnx_path)
else:
return sum([f.stat().st_size for f in Path(onnx_path).parent.rglob("*")])
def get_latency_name(batch_size, sequence_length, past_sequence_length):
return f"average_latency(batch_size={batch_size},sequence_length={sequence_length},past_sequence_length={past_sequence_length})"
def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: str = "gpt2_parity_results.csv"):
result = {}
if version.parse(transformers_version) < version.parse(
"3.1.0"
): # past_key_values name does not exist in 3.0.2 or older
raise RuntimeError("This tool requires transformers 3.1.0 or later.")
args = parse_arguments(argv)
setup_logger(args.verbose)
if not experiment_name:
experiment_name = " ".join(argv if argv else sys.argv[1:])
if args.tolerance == 0:
args.tolerance = DEFAULT_TOLERANCE[args.precision]
logger.info(f"Arguments:{args}")
cache_dir = args.cache_dir
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
prepare_environment(cache_dir, output_dir, args.use_gpu)
if args.precision != Precision.FLOAT32:
assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 requires --use_gpu"
if args.precision == Precision.INT8:
assert not args.use_gpu, "quantization only supports CPU"
model_class = MODEL_CLASSES[args.model_class][0]
use_padding = MODEL_CLASSES[args.model_class][2]
gpt2helper = Gpt2Helper
config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=cache_dir)
model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
device = torch.device("cuda:0" if args.use_gpu else "cpu")
model.eval().to(device)
if (not args.use_external_data_format) and (config.n_layer > 24):
logger.info("Try --use_external_data_format when model size > 2GB")
onnx_model_paths = gpt2helper.get_onnx_paths(
output_dir,
args.model_name_or_path,
args.model_class,
new_folder=(args.precision == Precision.INT8),
remove_existing=["fp32", "fp16", "int8"],
) # Do not remove raw model to save time in parity test
raw_onnx_model = onnx_model_paths["raw"]
int_data_type = torch.int64 if args.use_int64_inputs else torch.int32
if os.path.exists(raw_onnx_model) and not args.overwrite:
logger.warning(f"Skip exporting ONNX model since it existed: {raw_onnx_model}")
else:
logger.info(f"Exporting ONNX model to {raw_onnx_model}")
gpt2helper.export_onnx(
model,
device,
raw_onnx_model,
args.verbose,
args.use_external_data_format,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
)
fp16_params = {"keep_io_types": args.keep_io_types}
if args.io_block_list:
fp16_params["keep_io_types"] = args.io_block_list
if args.node_block_list:
fp16_params["node_block_list"] = args.node_block_list
if args.op_block_list:
fp16_params["op_block_list"] = args.op_block_list
if args.force_fp16_initializers:
fp16_params["force_fp16_initializers"] = args.force_fp16_initializers
is_io_float16 = args.precision == Precision.FLOAT16 and not args.keep_io_types
optimized_ops = ""
all_ops = ""
if args.optimize_onnx or args.precision != Precision.FLOAT32:
output_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
logger.info(f"Optimizing model to {output_path}")
m = gpt2helper.optimize_onnx(
raw_onnx_model,
output_path,
args.precision == Precision.FLOAT16,
model.config.num_attention_heads,
model.config.hidden_size,
args.use_external_data_format,
auto_mixed_precision=args.auto_mixed_precision,
stage=args.stage,
**fp16_params,
)
nodes = m.nodes()
op_list = {node.op_type for node in nodes}
all_ops = ",".join(op_list)
# print optimized operators
optimized_op_counter = m.get_fused_operator_statistics()
if optimized_op_counter:
optimized_ops = ",".join([key for key in optimized_op_counter if optimized_op_counter[key] > 0])
else:
output_path = raw_onnx_model
if args.precision == Precision.INT8:
logger.info("quantizing model...")
QuantizeHelper.quantize_onnx_model(output_path, onnx_model_paths["int8"], args.use_external_data_format)
model = QuantizeHelper.quantize_torch_model(model)
logger.info("finished quantizing model")
output_path = onnx_model_paths["int8"]
if args.output.endswith(".onnx") and output_path != args.output and not args.use_external_data_format:
shutil.move(output_path, args.output)
output_path = args.output
logger.info(f"Output path: {output_path}")
model_size_in_MB = int(get_onnx_model_size(output_path, args.use_external_data_format) / 1024 / 1024) # noqa: N806
provider = args.provider
session = create_onnxruntime_session(
output_path, args.use_gpu, provider, enable_all_optimization=True, verbose=args.verbose
)
if args.model_class == "GPT2LMHeadModel" and session is not None:
parity_result = gpt2helper.test_parity(
session,
model,
device,
is_io_float16,
rtol=args.tolerance,
atol=args.tolerance,
model_class=args.model_class,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
test_cases_per_run=args.test_cases,
total_runs=args.test_runs,
stage=args.stage,
verbose=args.verbose,
)
# An example configuration for testing performance
batch_size = 8
sequence_length = 32 if args.stage == 1 else 1
past_sequence_length = 0 if args.stage == 1 else 32
latency = gpt2helper.test_performance(
session,
model,
device,
is_io_float16,
total_runs=100,
use_io_binding=True,
model_class=args.model_class,
has_position_ids=use_padding,
has_attention_mask=use_padding,
input_ids_dtype=int_data_type,
position_ids_dtype=int_data_type,
attention_mask_dtype=int_data_type,
batch_size=batch_size,
sequence_length=sequence_length,
past_sequence_length=past_sequence_length,
)
if args.precision == Precision.FLOAT16:
logger.info(f"fp16 conversion parameters:{fp16_params}")
# Write results to file
latency_name = get_latency_name(batch_size, sequence_length, past_sequence_length)
csv_file_existed = os.path.exists(csv_filename)
with open(csv_filename, mode="a", newline="") as csv_file:
column_names = [
"experiment",
"run_id",
"model_name",
"model_class",
"stage",
"gpu",
"precision",
"optimizer",
"test_cases",
"runs",
"keep_io_types",
"io_block_list",
"op_block_list",
"node_block_list",
"force_fp16_initializers",
"auto_mixed_precision",
"optimized_operators",
"operators",
"environment_variables",
"onnxruntime",
latency_name,
"top1_match_rate",
"onnx_size_in_MB",
"diff_50_percentile",
"diff_90_percentile",
"diff_95_percentile",
"diff_99_percentile",
"diff_pass_rate",
"nan_rate",
"top1_match_rate_per_run",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
if not csv_file_existed:
csv_writer.writeheader()
row = {
"experiment": experiment_name,
"run_id": run_id,
"model_name": args.model_name_or_path,
"model_class": args.model_class,
"stage": args.stage,
"gpu": args.use_gpu,
"precision": args.precision,
"optimizer": args.optimize_onnx,
"test_cases": args.test_cases,
"runs": args.test_runs,
"keep_io_types": args.keep_io_types,
"io_block_list": args.io_block_list,
"op_block_list": args.op_block_list,
"node_block_list": args.node_block_list,
"force_fp16_initializers": args.force_fp16_initializers,
"auto_mixed_precision": args.auto_mixed_precision,
"optimized_operators": optimized_ops,
"operators": all_ops,
"environment_variables": get_ort_environment_variables(),
"onnxruntime": ort_version,
latency_name: f"{latency:.2f}",
"diff_50_percentile": parity_result["max_diff_percentile_50"],
"diff_90_percentile": parity_result["max_diff_percentile_90"],
"diff_95_percentile": parity_result["max_diff_percentile_95"],
"diff_99_percentile": parity_result["max_diff_percentile_99"],
"diff_pass_rate": parity_result["diff_pass_rate"],
"nan_rate": parity_result["nan_rate"],
"top1_match_rate": parity_result["top1_match_rate"],
"top1_match_rate_per_run": parity_result["top1_match_rate_per_run"],
"onnx_size_in_MB": f"{model_size_in_MB}",
}
logger.info(f"result: {row}")
result.update(row)
csv_writer.writerow(row)
if args.input_test_file:
test_inputs = []
# Each line of test file is a JSON string like:
# {"input_ids": [[14698, 257, 1310, 13688, 319, 326]]}
with open(args.input_test_file) as read_f:
for _, line in enumerate(read_f):
line = line.rstrip() # noqa: PLW2901
data = json.loads(line)
input_ids = torch.from_numpy(numpy.asarray(data["input_ids"], dtype=numpy.int64)).to(device)
if use_padding:
if "attention_mask" in data:
numpy_float = numpy.float16 if is_io_float16 else numpy.float32
attention_mask = torch.from_numpy(numpy.asarray(data["attention_mask"], dtype=numpy_float)).to(
device
)
else:
padding = -1
attention_mask = (input_ids != padding).type(torch.float16 if is_io_float16 else torch.float32)
input_ids.masked_fill_(input_ids == padding, 0)
if "position_ids" in data:
position_ids = torch.from_numpy(numpy.asarray(data["position_ids"], dtype=numpy.int64)).to(
device
)
else:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(position_ids < 0, 0)
inputs = {
"input_ids": input_ids.to(int_data_type),
"position_ids": position_ids.to(int_data_type),
"attention_mask": attention_mask.to(int_data_type),
}
else:
inputs = {"input_ids": input_ids.to(int_data_type)}
test_inputs.append(inputs)
Gpt2Tester.test_generation(
session,
model,
device,
test_inputs,
precision=args.precision,
model_class=args.model_class,
top_k=20,
top_k_no_order=True,
max_steps=24,
max_inputs=0,
verbose=args.verbose,
save_test_data=3,
save_test_data_dir=Path(output_path).parent,
)
logger.info(f"Done. Output model: {output_path}")
return result
if __name__ == "__main__":
main()
@@ -0,0 +1,513 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script uses different configurations in mixed precision conversion for GPT-2 model, and
# measures the inference latency, top 1 match rate (compared to PyTorch FP32 model) and ONNX model size.
# It outputs a csv file with Mann-Whitney U test and T-Test on each pair of experiments, where
# pvalue < 0.05 means two experiments have significant difference on top 1 match rate.
# User could use this script to select the best mixed precision model according to these metrics.
import argparse
import csv
import datetime
import json
import logging
import os
import onnx
import scipy.stats
from benchmark_helper import get_ort_environment_variables, setup_logger
from convert_to_onnx import main
from gpt2_helper import PRETRAINED_GPT2_MODELS, Gpt2Helper
from onnx_model import OnnxModel
logger = logging.getLogger("")
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name_or_path",
required=True,
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
)
parser.add_argument(
"--csv",
required=False,
type=str,
default="gpt2_parity_results.csv",
help="path of csv file to save the result",
)
parser.add_argument(
"--test_cases",
required=False,
type=int,
default=500,
help="number of test cases per run",
)
parser.add_argument("--runs", required=False, type=int, default=40, help="number of repeated runs")
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"--all",
required=False,
action="store_true",
help="run all combinations of mixed precision",
)
parser.set_defaults(all=False)
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument(
"--skip_test",
required=False,
action="store_true",
help="do not run test, and only rank experiments based on existing csv file",
)
parser.set_defaults(skip_test=False)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing csv file",
)
parser.set_defaults(overwrite=False)
args = parser.parse_args(argv)
return args
class ParityTask:
def __init__(self, test_cases, total_runs, csv_path):
self.total_runs = total_runs
self.test_cases = test_cases
self.csv_path = csv_path
self.results = []
self.run_id = 0
def run(self, argv, experiment_name):
start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
run_id = f"{start_time}_{self.run_id}"
self.run_id += 1
try:
result = main(
[*argv, "-t", f"{self.test_cases}", "-r", f"{self.total_runs}"],
experiment_name=experiment_name,
run_id=run_id,
csv_filename=self.csv_path,
)
if result:
self.results.append(result)
except Exception:
logger.exception(f"Failed to run experiment {experiment_name}")
result = None
return result
def load_results_from_csv(csv_path):
rows = []
import csv # noqa: PLC0415
with open(csv_path, newline="") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
rows.append(row) # noqa: PERF402
return rows
def get_latency(row):
for name in row:
if name.startswith("average_latency(batch_size="):
return float(row[name])
raise RuntimeError("Failed to get average_latency from output")
def score(row):
"""Scoring function based on 3 metrics. The larger score is better."""
latency_in_ms = get_latency(row)
top1_match_rate = float(row["top1_match_rate"])
onnx_size_in_MB = float(row["onnx_size_in_MB"]) # noqa: N806
# A simple scoring function: cost of 0.1ms latency ~ 0.1% match rate ~ 100MB size
return top1_match_rate * 1000 - latency_in_ms * 10 - onnx_size_in_MB / 100
def print_wins(wins, rows, test_name):
print()
print("*" * 10)
row_map = {}
for row in rows:
row_map[row["run_id"]] = row
sorted_wins = dict(
sorted(
wins.items(),
key=lambda item: (item[1], score(row_map[item[0]])),
reverse=True,
)
)
logger.debug(f"{test_name} Wins:{sorted_wins}")
logger.info(f"Based on {test_name} wins and a scoring function, the ranking:")
rank = 0
previous_value = -1
for count, (key, value) in enumerate(sorted_wins.items()):
if value != previous_value:
rank = count
previous_value = value
for row in rows:
if row["run_id"] == key:
logger.info(
"{:02d}: WINs={:02d}, run_id={}, latency={:5.2f}, top1_match={:.4f}, size={}_MB, experiment={}, {}".format( # noqa: G001
rank,
value,
key,
get_latency(row),
float(row["top1_match_rate"]),
row["onnx_size_in_MB"],
row["experiment"],
get_ort_environment_variables(),
)
)
break
def run_significance_test(rows, output_csv_path):
"""Run U test and T test."""
utest_wins = {}
ttest_wins = {}
for row in rows:
run_id = row["run_id"]
utest_wins[run_id] = 0
ttest_wins[run_id] = 0
with open(output_csv_path, "w", newline="") as csvfile:
column_names = [
"model_name",
"run_id_1",
"experiment_1",
"top1_match_rate_1",
"run_id_2",
"experiment_2",
"top1_match_rate_2",
"U_statistic",
"U_pvalue",
"T_statistic",
"T_pvalue",
]
writer = csv.DictWriter(csvfile, fieldnames=column_names)
writer.writeheader()
required_match_columns = ["model_name", "test_cases", "runs"]
num_results = len(rows)
for i in range(num_results - 1):
result1 = rows[i]
if isinstance(result1["top1_match_rate_per_run"], str):
a = json.loads(result1["top1_match_rate_per_run"])
else:
a = result1["top1_match_rate_per_run"]
for j in range(i + 1, num_results, 1):
result2 = rows[j]
all_matched = True
for column in required_match_columns:
if result1[column] != result2[column]:
all_matched = False
break
if not all_matched:
continue
if isinstance(result2["top1_match_rate_per_run"], str):
b = json.loads(result2["top1_match_rate_per_run"])
else:
b = result2["top1_match_rate_per_run"]
try:
utest_statistic, utest_pvalue = scipy.stats.mannwhitneyu(
a, b, use_continuity=True, alternative="two-sided"
) # TODO: shall we use one-sided: less or greater according to "top1_match_rate"
except ValueError: # ValueError: All numbers are identical in mannwhitneyu
utest_statistic = None
utest_pvalue = None
ttest_statistic, ttest_pvalue = scipy.stats.ttest_ind(a, b, axis=None, equal_var=True)
if utest_pvalue is not None and utest_pvalue < 0.05:
if float(result1["top1_match_rate"]) > float(result2["top1_match_rate"]):
utest_wins[result1["run_id"]] += 1
else:
utest_wins[result2["run_id"]] += 1
if ttest_pvalue < 0.05:
if float(result1["top1_match_rate"]) > float(result2["top1_match_rate"]):
ttest_wins[result1["run_id"]] += 1
else:
ttest_wins[result2["run_id"]] += 1
row = {
"model_name": result1["model_name"],
"run_id_1": result1["run_id"],
"experiment_1": result1["experiment"],
"top1_match_rate_1": float(result1["top1_match_rate"]),
"run_id_2": result2["run_id"],
"experiment_2": result2["experiment"],
"top1_match_rate_2": float(result2["top1_match_rate"]),
"U_statistic": utest_statistic,
"U_pvalue": utest_pvalue,
"T_statistic": ttest_statistic,
"T_pvalue": ttest_pvalue,
}
writer.writerow(row)
logger.info(f"U-Test and T-Test results are output to {output_csv_path}")
print_wins(utest_wins, rows, "U-Test")
print_wins(ttest_wins, rows, "T-Test")
def get_last_matmul_node_name(raw_onnx_model: str):
model = onnx.load(raw_onnx_model)
onnx_model = OnnxModel(model)
output_name_to_node = onnx_model.output_name_to_node()
assert model.graph.output[0].name in output_name_to_node
node = output_name_to_node[model.graph.output[0].name]
if node.op_type == "MatMul":
logger.info(f"Found last MatMul node for logits: {node.name}")
return node.name
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
return None
def get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list):
model = args.model_name_or_path
parameters = f"-m {model} -o --use_gpu -p fp16".split()
if args.use_external_data_format:
parameters.append("--use_external_data_format")
parameters += [
"--io_block_list",
"logits",
"--node_block_list",
last_matmul_node_name,
]
if op_block_list:
parameters.extend(["--op_block_list", *op_block_list])
return parameters
def run_candidate(
task: ParityTask,
args,
last_matmul_node_name,
op_block_list=["FastGelu", "LayerNormalization"], # noqa: B006
):
parameters = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list)
op_block_list_str = ",".join(sorted(op_block_list))
if op_block_list:
name = f"Mixed precision baseline + {op_block_list_str} in FP32"
else:
name = f"Mixed precision baseline (logits output and last MatMul node {last_matmul_node_name} in FP32)"
env_vars = get_ort_environment_variables()
if env_vars:
name = name + f" ({env_vars})"
task.run(parameters, name)
def get_baselines(args):
model = args.model_name_or_path
fp32_baseline = f"-m {model} -o -p fp32".split()
if args.use_gpu:
fp32_baseline.append("--use_gpu")
if args.use_external_data_format:
fp32_baseline.append("--use_external_data_format")
fp16_baseline = f"-m {model} -o --use_gpu -p fp16".split()
if args.use_external_data_format:
fp16_baseline.append("--use_external_data_format")
return fp32_baseline, fp16_baseline
def run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops):
"""Step 0 is to check which operator in FP16 causes most loss"""
fp32_logits = ["--io_block_list", "logits"]
task.run(fp16_baseline + fp32_logits, "FP16 except logits")
fp32_io = ["--keep_io_types"]
task.run(fp16_baseline + fp32_io, "Graph I/O FP32, Other FP16")
# Only weights in FP16
task.run(
fp16_baseline + fp32_io + ["--op_block_list"] + list(all_ops) + ["--force_fp16_initializers"],
"FP32 except weights in FP16",
)
optimized_ops_results = []
op_list = optimized_ops
for op in op_list:
op_block_list = ["--op_block_list"] + [o for o in op_list if o != op]
result = task.run(fp16_baseline + fp32_io + op_block_list, f"FP32 except {op} in FP16")
if result:
optimized_ops_results.append(result)
# Check which optimized operator causes the most loss in precision
min_result = min(optimized_ops_results, key=lambda y: y["top1_match_rate"])
print("step 0: optimized operator causes the most loss in precision", min_result)
def run_tuning_step1(task, mixed_precision_baseline, optimized_ops):
"""Step 1 is to figure out which optimized operator in FP32 could benefit most"""
for op in optimized_ops:
op_block_list = ["--op_block_list", op]
task.run(
mixed_precision_baseline + op_block_list,
f"Mixed precision baseline + {op} in FP32",
)
def run_tuning_step2(task, mixed_precision_baseline, optimized_ops):
"""Assumed that you have run step 0 and 1 to figure out that Logits FP32 and some operators shall be in FP32,
This step will try add one more operator.
"""
candidate_fp32_ops = ["FastGelu", "LayerNormalization", "SkipLayerNormalization"]
fp32_ops = [x for x in candidate_fp32_ops if x in optimized_ops]
for op in optimized_ops:
if op not in fp32_ops:
op_block_list = [*fp32_ops, op]
task.run(
[*mixed_precision_baseline, "--op_block_list", *op_block_list],
"Mixed precision baseline + {},{} in FP32".format(",".join(fp32_ops), op),
)
def run_parity(task: ParityTask, args):
onnx_model_paths = Gpt2Helper.get_onnx_paths(
"onnx_models",
args.model_name_or_path,
new_folder=args.use_external_data_format,
remove_existing=[],
)
fp32_baseline, fp16_baseline = get_baselines(args)
result = task.run(fp32_baseline, "FP32 baseline")
optimized_ops = []
if result and ("optimized_operators" in result) and result["optimized_operators"]:
optimized_ops = result["optimized_operators"].split(",")
else:
raise RuntimeError("Failed to get optimized operators")
all_ops = []
if result and ("operators" in result) and result["operators"]:
all_ops = result["operators"].split(",")
else:
raise RuntimeError("Failed to get operators")
# The following tests for fp16 requires GPU
if not args.use_gpu:
logger.info("skip mixed precision since --use_gpu is not specified")
return
task.run(fp16_baseline, "FP16 baseline")
last_matmul_node_name = get_last_matmul_node_name(onnx_model_paths["raw"])
# Mixed precision baseline
run_candidate(task, args, last_matmul_node_name, op_block_list=[])
def get_fp32_ops(x):
return [op for op in x if op in all_ops]
if args.all:
run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops)
mixed_precision_baseline = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list=[])
run_tuning_step1(task, mixed_precision_baseline, optimized_ops)
run_tuning_step2(task, mixed_precision_baseline, optimized_ops)
else:
run_candidate(
task,
args,
last_matmul_node_name,
op_block_list=get_fp32_ops(["SkipLayerNormalization", "LayerNormalization", "Add"]),
)
run_candidate(task, args, last_matmul_node_name, op_block_list=["FastGelu"])
# Run a few good candidates
run_candidate(
task,
args,
last_matmul_node_name,
op_block_list=get_fp32_ops(["FastGelu", "SkipLayerNormalization", "LayerNormalization", "Add"]),
)
run_candidate(
task,
args,
last_matmul_node_name,
op_block_list=get_fp32_ops(
["FastGelu", "EmbedLayerNormalization", "SkipLayerNormalization", "LayerNormalization", "Add"]
),
)
if __name__ == "__main__":
args = parse_arguments()
setup_logger(args.verbose)
if args.test_cases < 100 or args.runs < 20 or args.test_cases * args.runs < 10000:
logger.warning(
"Not enough test cases or runs to get stable results or test significance. "
"Recommend test_cases >= 100, runs >= 20, test_cases * runs >= 10000."
)
if os.path.exists(args.csv) and not args.skip_test:
if not args.overwrite:
raise RuntimeError(
f"Output file {args.csv} existed. Please remove the file, or use either --skip_test or --overwrite."
)
else:
logger.info("Remove existing file %s since --overwrite is specified", args.csv)
os.remove(args.csv)
task = ParityTask(args.test_cases, args.runs, args.csv)
if not args.skip_test:
run_parity(task, args)
try:
rows = load_results_from_csv(task.csv_path)
except Exception:
logger.exception(f"Failed to load csv {task.csv_path}")
rows = task.results
logger.info("Start running significance tests...")
summary_csv = task.csv_path.replace(".csv", ".stats.csv")
run_significance_test(rows, summary_csv)
@@ -0,0 +1,501 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps evaluation of GPT-2 model.
import logging
import math
import os
import statistics
import timeit
import numpy
import torch
from benchmark_helper import Precision
from gpt2_helper import Gpt2Helper, Gpt2Inputs
logger = logging.getLogger(__name__)
class Gpt2Metric:
def __init__(self, treatment_name, baseline_name="Torch", top_k=20):
assert top_k > 1 and top_k <= 100
self.baseline = baseline_name
self.treatment = treatment_name
self.name: str = f"{treatment_name} vs {baseline_name}"
self.top_k = top_k
self.top_1_error: int = 0
self.top_k_error: int = 0
self.total_samples: int = 0
self.max_logits_diff: float = 0 # for non-empty past state
self.max_logits_diff_no_past: float = 0 # for empty past state
self.batch_top1_error: torch.FloatTensor = None # top 1 error for current batch
self.batch_topk_error: torch.FloatTensor = None # top k error for current batch
self.seq_len_latency = {}
def print(self):
if self.baseline != self.treatment:
print("---")
print(f"Metrics for {self.treatment} (baseline={self.baseline}):")
if self.total_samples > 0:
top_1_error_rate = 100.0 * self.top_1_error / self.total_samples
top_k_error_rate = 100.0 * self.top_k_error / self.total_samples
print(
f"Total={self.total_samples} Top1Error={self.top_1_error} ({top_1_error_rate:.2f}%) Top{self.top_k}Error={self.top_k_error} ({top_k_error_rate:.2f}%)"
)
print("Max logits diffs:")
print(f"\twith past = {self.max_logits_diff:.6f}")
print(f"\tempty past = {self.max_logits_diff_no_past:.6f}")
else:
print(f"Metrics for {self.treatment} (baseline):")
if self.seq_len_latency:
print("Past sequence length range and average latency:")
total = 0
count = 0
for key in sorted(self.seq_len_latency.keys()):
average = statistics.mean(self.seq_len_latency[key]) * 1000.0
if key == 0:
print(f"\t{key}: \t{average:.2f} ms")
else:
print(f"\t[{2**key}, {2 ** (key + 1) - 1}]:\t{average:.2f} ms")
total += average * len(self.seq_len_latency[key])
count += len(self.seq_len_latency[key])
print(f"Average Latency: {total / count:.2f} ms")
def diff_logits(self, baseline_logits, treatment_logits, is_empty_past: bool):
diff = (baseline_logits - treatment_logits).abs().max()
if is_empty_past:
self.max_logits_diff_no_past = max(self.max_logits_diff_no_past, diff)
else:
self.max_logits_diff = max(self.max_logits_diff, diff)
return diff
def start_batch(self, batch_size: int):
self.total_samples += batch_size
self.batch_top1_error = torch.zeros((batch_size, 1), dtype=torch.bool)
self.batch_topk_error = torch.zeros((batch_size, 1), dtype=torch.bool)
def eval_batch(self, baseline, treatment, past_seq_len, verbose=True):
self._eval_topk(baseline.top_1_tokens, treatment.top_1_tokens, 1, verbose)
self._eval_topk(baseline.top_k_tokens, treatment.top_k_tokens, self.top_k, verbose)
max_diff = self.diff_logits(baseline.logits, treatment.logits, past_seq_len == 0)
if verbose:
print(f"Max logits diffs of {self.name}: {max_diff}")
def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True):
if not torch.all(torch.eq(baseline_topk, treatment_topk)):
if top_k == 1:
if verbose:
print(f"Generated tokens not matched for {self.name}")
self.batch_top1_error |= torch.eq(baseline_topk, treatment_topk).logical_not()
else:
if verbose:
print(
f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results"
)
self.batch_topk_error |= (
torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0
)
def end_batch(self):
self.top_1_error += self.batch_top1_error.sum()
self.top_k_error += self.batch_topk_error.sum()
def add_latency(self, past_seq_len, latency):
key = int(math.log2(past_seq_len)) + 1 if past_seq_len > 0 else 0
if key not in self.seq_len_latency:
self.seq_len_latency[key] = []
self.seq_len_latency[key].append(latency)
class Gpt2Tester:
def __init__(
self,
input_ids,
position_ids,
attention_mask,
num_attention_heads,
hidden_size,
num_layer,
device,
is_fp16=False,
top_k=20,
top_k_required_order=False,
):
self.batch_size = input_ids.shape[0]
self.input_length = input_ids.shape[1]
self.n_layer = num_layer
self.input_ids = input_ids
self.position_ids = position_ids
self.attention_mask = attention_mask
self.has_position_ids = position_ids is not None
self.has_attention_mask = attention_mask is not None
# Empty past state for first inference
self.past = []
past_shape = [
2,
self.batch_size,
num_attention_heads,
0,
hidden_size // num_attention_heads,
]
for _i in range(num_layer):
empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32)
self.past.append(empty_past.to(device))
self.logits = None
self.top_1_tokens = None
self.top_k_tokens = None
self.top_k = top_k
self.top_k_required_order = top_k_required_order
def get_inputs(self) -> Gpt2Inputs:
return Gpt2Inputs(self.input_ids, self.position_ids, self.attention_mask, self.past)
def save_test_data(self, session, output, save_test_data_dir, test_case_id):
from onnx import numpy_helper # noqa: PLC0415
path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id))
if os.path.exists(path):
print(f"Directory {path} existed. Skip saving test data")
return
os.makedirs(path, exist_ok=True)
def add_tensor(input_tensors, torch_tensor, name):
input_tensors.append(numpy_helper.from_array(torch_tensor.clone().cpu().numpy(), name))
input_tensors = []
add_tensor(input_tensors, self.input_ids, "input_ids")
if self.has_position_ids:
add_tensor(input_tensors, self.position_ids, "position_ids")
if self.has_attention_mask:
add_tensor(input_tensors, self.attention_mask, "attention_mask")
for i in range(self.n_layer):
add_tensor(input_tensors, self.past[i], "past_" + str(i))
for i, tensor in enumerate(input_tensors):
with open(os.path.join(path, f"input_{i}.pb"), "wb") as f:
f.write(tensor.SerializeToString())
output_names = [output.name for output in session.get_outputs()]
for i, _name in enumerate(output_names):
tensor = numpy_helper.from_array(
output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy()
)
with open(os.path.join(path, f"output_{i}.pb"), "wb") as f:
f.write(tensor.SerializeToString())
print(f"Test data saved to directory {path}")
def update(self, output, step, device):
"""
Update the inputs for next inference.
"""
self.logits = (
torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu()
)
self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits)
self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order)
self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device)
if self.has_position_ids:
self.position_ids = (
torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device)
)
if self.has_attention_mask:
self.attention_mask = torch.cat(
[
self.attention_mask,
torch.ones([self.batch_size, 1]).type_as(self.attention_mask),
],
1,
).to(device)
self.past = []
if isinstance(output[1], tuple): # past in torch output is tuple
self.past = list(output[1])
else:
for i in range(self.n_layer):
past_i = (
torch.from_numpy(output[i + 1])
if isinstance(output[i + 1], numpy.ndarray)
else output[i + 1].clone().detach()
)
self.past.append(past_i.to(device))
def diff(self, baseline):
"""
Compare inputs and logits output.
"""
print("start diff...")
if self.logits is not None:
max_io_diff = (self.logits - baseline.logits).abs().max()
if max_io_diff > 1e-4:
print(f"Max logits difference is too large: {max_io_diff}")
if not torch.all(self.input_ids == baseline.input_ids):
print("Input_ids is different", self.input_ids, baseline.input_ids)
if self.has_position_ids:
if not torch.all(self.position_ids == baseline.position_ids):
print(
"position_ids is different",
self.position_ids,
baseline.position_ids,
)
if self.has_attention_mask:
if not torch.all(self.attention_mask == baseline.attention_mask):
print(
"attention_mask is different",
self.attention_mask,
baseline.attention_mask,
)
assert len(self.past) == len(baseline.past)
for i, past_i in enumerate(self.past):
assert past_i.shape == baseline.past[i].shape
if past_i.nelement() > 0:
max_past_diff = (past_i - baseline.past[i]).abs().max()
if max_past_diff > 1e-4:
print(f"max_past_diff[{i}]={max_past_diff}")
@staticmethod
def predict_next_token(logits, top_k=1, required_order=False):
"""
Get top k topkens based on logits.
"""
# logits has shape (batch_size, seq_len, vocab_size)
# last token logits has shape (batch_size, vocab_size)
lastTokenLogits = logits[:, -1] # noqa: N806
if top_k == 1:
generatedTokens = torch.argmax(lastTokenLogits, 1, True) # noqa: N806
return generatedTokens
else:
topk = torch.argsort(lastTokenLogits, -1, descending=True)[:, :top_k]
if not required_order:
sorted_topk, _ = topk.sort()
return sorted_topk
return topk
@staticmethod
def diff_present(onnx_output, onnx_io_output, n_layer):
"""
Compare the present outputs of two outputs from ONNX Runtime.
"""
present_diff_max = []
for i in range(n_layer):
onnx_present_i = (
torch.from_numpy(onnx_output[i + 1])
if isinstance(onnx_output[i + 1], numpy.ndarray)
else onnx_output[i + 1]
)
onnx_io_present_i = (
torch.from_numpy(onnx_io_output[i + 1])
if isinstance(onnx_io_output[i + 1], numpy.ndarray)
else onnx_io_output[i + 1]
)
max_diff = (onnx_present_i - onnx_io_present_i).abs().max()
present_diff_max.append(max_diff)
print(f"present_diff_max={present_diff_max}")
@staticmethod
def is_quantized_onnx_model(onnx_model_path):
"""
Returns True if the ONNX model is quantized.
"""
from onnx import load # noqa: PLC0415
model = load(onnx_model_path)
from onnxruntime.quantization.quantize import __producer__ as quantize_producer # noqa: PLC0415
return model.producer_name == quantize_producer
@staticmethod
def test_generation(
session,
model,
device,
test_inputs,
precision=Precision.FLOAT32,
model_class="Gpt2LMHeadModel",
top_k=20,
top_k_no_order=True,
max_steps=24,
max_inputs=0,
verbose=False,
save_test_data=0,
save_test_data_dir=".",
):
"""
Test Generation using greedy beam search (without sampling) to compare PyTorch and ONNX model.
It will print top 1 and top k errors on the given test inputs.
"""
print(
f"start test generation: (top_k={top_k} top_k_no_order={top_k_no_order} max_steps={max_steps} test_inputs={len(test_inputs)} max_inputs={max_inputs})"
)
n_layer = model.config.n_layer
n_head = model.config.n_head
n_embd = model.config.n_embd
eos_token_id = model.config.eos_token_id
test_data_saved = 0
is_float16 = precision == Precision.FLOAT16
if is_float16:
assert "float16" in session.get_outputs()[0].type
# We will still use fp32 torch model as baseline when onnx model if fp16
model.eval().to(device)
# Allocate initial buffers for IO Binding of ONNX Runtimne. The buffer size will automatically increase later.
init_output_shapes = Gpt2Helper.get_output_shapes(
batch_size=4,
past_sequence_length=128,
sequence_length=32,
config=model.config,
model_class=model_class,
)
output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16)
baseline_name = "Torch"
treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx"
torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k)
onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k)
onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k)
for i, inputs in enumerate(test_inputs):
if max_inputs > 0 and i == max_inputs:
break
if i % 10 == 0:
print(f"{i}")
input_ids = inputs["input_ids"]
position_ids = inputs.get("position_ids", None)
attention_mask = inputs.get("attention_mask", None)
onnx_runner = Gpt2Tester(
input_ids,
position_ids,
attention_mask,
n_head,
n_embd,
n_layer,
device,
is_float16,
top_k,
not top_k_no_order,
)
onnx_io_runner = Gpt2Tester(
input_ids,
position_ids,
attention_mask,
n_head,
n_embd,
n_layer,
device,
is_float16,
top_k,
not top_k_no_order,
)
torch_runner = Gpt2Tester(
input_ids,
position_ids,
attention_mask,
n_head,
n_embd,
n_layer,
device,
False,
top_k,
not top_k_no_order,
) # Torch model baseline is fp32
batch_size = torch_runner.batch_size
onnx_metric.start_batch(batch_size)
onnx_io_metric.start_batch(batch_size)
with torch.no_grad():
done = torch.zeros(batch_size, dtype=torch.bool)
for step in range(max_steps):
seq_len = list(onnx_runner.input_ids.size())[1]
past_seq_len = list(onnx_runner.past[0].size())[3]
start_time = timeit.default_timer()
pytorch_output = Gpt2Helper.pytorch_inference(model, torch_runner.get_inputs())
torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time)
torch_runner.update(pytorch_output, step, device)
onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference(
session, onnx_runner.get_inputs(), total_runs=1
)
onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
onnx_runner.update(onnx_output, step, device)
output_shapes = Gpt2Helper.get_output_shapes(
batch_size,
past_seq_len,
seq_len,
model.config,
model_class=model_class,
)
Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes)
(
onnx_io_output,
avg_latency_ms,
) = Gpt2Helper.onnxruntime_inference_with_binded_io(
session,
onnx_io_runner.get_inputs(),
output_buffers,
output_shapes,
total_runs=1,
return_numpy=False,
include_copy_output_latency=True,
)
onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
if test_data_saved < save_test_data:
onnx_io_runner.save_test_data(session, onnx_io_output, save_test_data_dir, test_data_saved)
test_data_saved += 1
onnx_io_runner.update(onnx_io_output, step, device)
if verbose:
onnx_runner.diff(onnx_io_runner)
Gpt2Tester.diff_present(onnx_output, onnx_io_output, n_layer)
print("Top 1 tokens:")
print("\tTorch", torch_runner.top_1_tokens)
print("\tONNX", onnx_runner.top_1_tokens)
print("\tONNX with IO binding", onnx_io_runner.top_1_tokens)
onnx_metric.eval_batch(torch_runner, onnx_runner, past_seq_len, verbose=verbose)
onnx_io_metric.eval_batch(torch_runner, onnx_io_runner, past_seq_len, verbose=verbose)
done = done | (torch_runner.top_1_tokens == eos_token_id).any()
if torch.all(done):
break
onnx_metric.end_batch()
onnx_io_metric.end_batch()
torch_metric.print()
onnx_metric.print()
onnx_io_metric.print()
@@ -0,0 +1,146 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps debugging parity issue for two same onnx models with fp16 and fp32 format
# Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON
import math
import multiprocessing
import os
from pathlib import Path
import numpy
import torch
from benchmark_helper import create_onnxruntime_session
from gpt2_helper import Gpt2Helper
from onnx import TensorProto, numpy_helper
NON_ZERO_VALUE = str(1)
ZERO_VALUE = str(0)
def environ_setting_nodes(node_name_filter=None, node_type_filter=None):
# Set I/O data as default
os.environ["ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA"] = ZERO_VALUE
os.environ["ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA"] = NON_ZERO_VALUE
os.environ["ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA"] = NON_ZERO_VALUE
if node_name_filter is not None:
os.environ["ORT_DEBUG_NODE_IO_NAME_FILTER"] = node_name_filter
elif node_type_filter is not None:
os.environ["ORT_DEBUG_NODE_IO_OP_TYPE_FILTER"] = node_type_filter
else:
os.environ["ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK"] = NON_ZERO_VALUE
def environ_setting_paths(output_path):
# Set dumping values to files as default
os.environ["ORT_DEBUG_NODE_IO_DUMP_DATA_DESTINATION"] = "files"
os.environ["ORT_DEBUG_NODE_IO_OUTPUT_DIR"] = output_path
def environ_reset():
for flag in [
"ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA",
"ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA",
"ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA",
"ORT_DEBUG_NODE_IO_NAME_FILTER",
"ORT_DEBUG_NODE_IO_OP_TYPE_FILTER",
"ORT_DEBUG_NODE_IO_DUMP_DATA_TO_FILES",
"ORT_DEBUG_NODE_IO_OUTPUT_DIR",
"ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK",
]:
if flag in os.environ:
del os.environ[flag]
def inference(model_path, dummy_inputs, outputs_path, use_gpu):
environ_reset()
environ_setting_nodes()
environ_setting_paths(outputs_path)
session = create_onnxruntime_session(model_path, use_gpu, enable_all_optimization=False)
Gpt2Helper.onnxruntime_inference(session, dummy_inputs)
def generate_outputs_files(model_path, dummy_inputs, outputs_path, use_gpu):
dir_path = Path(outputs_path)
if dir_path.exists() and dir_path.is_dir():
import shutil # noqa: PLC0415
shutil.rmtree(outputs_path)
dir_path.mkdir(parents=True, exist_ok=True)
process = multiprocessing.Process(target=inference, args=(model_path, dummy_inputs, outputs_path, use_gpu))
process.start()
process.join()
def post_processing(outputs_path, outputs_path_other):
# Compare outputs with e.g. fp16 and fp32
record = {}
if_close = {}
import glob # noqa: PLC0415
for filename in glob.glob(os.path.join(outputs_path, "*.tensorproto")):
filename_other = os.path.join(outputs_path_other, Path(filename).name)
if not os.path.exists(filename_other):
continue
with open(filename, "rb") as f:
tensor = TensorProto()
tensor.ParseFromString(f.read())
array = numpy_helper.to_array(tensor)
with open(filename_other, "rb") as f: # noqa: PLW2901
tensor_other = TensorProto()
tensor_other.ParseFromString(f.read())
array_other = numpy_helper.to_array(tensor_other)
if array_other.size == 0:
continue
diff = numpy.average(numpy.abs(array_other - array) / (numpy.abs(array_other) + 1e-6))
if math.isnan(diff):
continue
record[Path(filename).name.split(".")[0]] = diff
if_close[Path(filename).name.split(".")[0]] = numpy.allclose(array, array_other, rtol=1e-04, atol=1e-04)
results = ["Node\tDiff\tClose"]
for k, v in sorted(record.items(), key=lambda x: x[1], reverse=True):
results.append(f"{k}\t{v}\t{if_close[k]}")
for line in results:
print(line)
if __name__ == "__main__":
# Below example shows how to use this helper to investigate parity issue of gpt-2 fp32 and fp16 onnx model
# Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON !!
multiprocessing.set_start_method("spawn")
# Generate Inputs
sequence_length = 8
past_sequence_length = 8
batch_size = 5
dummy_inputs_fp16 = Gpt2Helper.get_dummy_inputs(
batch_size,
past_sequence_length,
sequence_length,
12,
768,
12,
50257,
device=torch.device("cpu"),
float16=True,
)
dummy_inputs_fp32 = dummy_inputs_fp16.to_fp32()
# Get GPT-2 model from huggingface using convert_to_onnx.py
os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp32.onnx -o -p fp32 --use_gpu")
os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu")
# Specify the directory to dump the node's I/O
outputs_path_fp32_gpu = "./fp32_gpu"
outputs_path_fp16_gpu = "./fp16_gpu"
generate_outputs_files("./gpt2_fp32.onnx", dummy_inputs_fp32, outputs_path_fp32_gpu, use_gpu=True)
generate_outputs_files("./gpt2_fp16.onnx", dummy_inputs_fp16, outputs_path_fp16_gpu, use_gpu=True)
# Compare each node's I/O value and sort based on average rtol
post_processing(outputs_path_fp16_gpu, outputs_path_fp32_gpu)
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,703 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import gc
import itertools
import logging
import os
import sys
import time
import numpy as np
import onnx
import psutil
import torch
from benchmark_helper import measure_memory, setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings_as_ortvalues,
get_merged_sample_with_past_kv_inputs,
get_msft_sample_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from optimum.onnxruntime import ORTModelForCausalLM
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import onnxruntime as ort
logger = logging.getLogger(__name__)
# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
def get_ort_model_inputs_len(args, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
return 0
if args.benchmark_type == "hf-ort":
try:
# New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
return len(model.inputs_names)
except Exception:
# Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
return len(model.decoder.input_names)
return len(model.get_inputs())
def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
init_inputs, iter_inputs = None, None
# For past_present_share_buffer:
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
# Set max_seq_len to config value for other models
max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
elif args.benchmark_type in {"hf-ort"}:
if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
# Using split models in Optimum (e.g. created by Optimum export)
init_inputs = get_sample_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
return_dict=True,
)
iter_inputs = get_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
args.sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
)
else:
# Using merged model in Optimum (e.g. created by convert_to_onnx export)
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="pt",
return_dict=True,
)
elif args.benchmark_type == "ort-convert-to-onnx":
# Microsoft export from convert_to_onnx
init_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=args.sequence_length,
past_seq_len=0,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
iter_inputs = get_merged_sample_with_past_kv_inputs(
args.config,
args.target_device,
args.batch_size,
seq_len=1,
past_seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
engine="ort",
return_dict=True,
world_size=args.world_size,
)
elif args.benchmark_type == "ort-msft":
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos]
init_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=0,
seq_len=args.sequence_length,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
iter_inputs = get_msft_sample_inputs(
args.config,
args.batch_size,
past_seq_len=args.sequence_length,
seq_len=1,
max_seq_len=max_seq_len,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
split_kv=split_kv,
)
else:
raise Exception("Unable to auto-detect inputs for provided model")
return init_inputs, iter_inputs
def get_model(args: argparse.Namespace):
model, sess_options = None, None
start_time, end_time = None, None
# There are multiple sources that the model could come from:
# 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
# 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
# 3) Benchmark LLaMA-2 from local download of model
# 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
# 5) Benchmark LLaMA-2 from convert_to_onnx
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
start_time = time.time()
model = AutoModelForCausalLM.from_pretrained(
source,
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_cache=True,
cache_dir=args.cache_dir,
).to(args.target_device)
end_time = time.time()
if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
if args.benchmark_type == "hf-ort":
# Optimum export or convert_to_onnx.py export
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
decoder_file_name = None
decoder_with_past_file_name = None
for filename in os.listdir(args.hf_ort_dir_path):
if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
continue
if "decoder_model" in filename or filename == "model.onnx":
decoder_file_name = filename
if "decoder_with_past_model" in filename:
decoder_with_past_file_name = filename
if "decoder_merged_model" in filename:
decoder_file_name = filename
decoder_with_past_file_name = filename
start_time = time.time()
model = ORTModelForCausalLM.from_pretrained(
args.hf_ort_dir_path,
decoder_file_name=decoder_file_name,
decoder_with_past_file_name=decoder_with_past_file_name,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_io_binding=True, # Large perf gain even for cpu due to avoiding output copy.
use_merged=(True if decoder_file_name == "model.onnx" else None),
provider=provider,
provider_options=provider_options,
session_options=sess_options,
)
end_time = time.time()
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
# Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path.format(args.rank),
sess_options,
providers=[args.execution_provider],
)
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
return model
def time_fn(args, fn, inputs):
# Warm up
warmup_range = (
range(args.warmup_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
if args.verbose:
outputs = fn(inputs)
logger.info(outputs)
input_sync = lambda *kwargs: ( # noqa: E731
args.io_binding.synchronize_inputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
else lambda *kwargs: (
torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
else lambda *kwargs: None
)
) # no-op function
output_sync = lambda *kwargs: ( # noqa: E731
args.io_binding.synchronize_outputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} # ORT synchronize
else lambda *kwargs: (
torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available() # PyTorch synchronize
else lambda *kwargs: None
)
) # no-op function
for _ in warmup_range:
input_sync()
fn(inputs)
output_sync()
# Benchmark
total_time = 0
bench_range = (
range(args.num_runs)
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
input_sync()
start_time = time.time()
fn(inputs)
output_sync()
end_time = time.time()
total_time += end_time - start_time
# Newline print after trange in order to print metrics on new lines without progress bar on same line
if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
logger.info("")
latency = total_time / args.num_runs
throughput = args.batch_size / latency
if args.rank == 0:
logger.info(f"Batch Size: {args.batch_size}")
logger.info(f"Sequence Length: {args.sequence_length}")
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} tps")
return
def profile_fn(args, fn, inputs, inputs_type):
# Filename prefix format:
# "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
with record_function("model_inference"):
fn(inputs)
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
filename = os.path.join(args.log_folder, f"{prefix}.log")
with open(filename, "w") as f:
f.write(prof_data)
else:
# Profile ORT kernels
fn(inputs)
# Set new log name for ORT profile log generated
filename = f"{prefix}.json"
return filename
def measure_fn(args, fn, inputs):
# Measure CPU usage
pid = os.getpid()
process = psutil.Process(pid)
process.cpu_percent(interval=0.1)
fn(inputs)
if args.rank == 0:
logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
# Flush output so memory usage is printed
sys.stdout.flush()
def run_hf_inference(args, init_inputs, iter_inputs, model):
# Inference steps to measure
def get_logits(inputs):
# Inference pass without decoding
outputs = model(**inputs)
return outputs
# Examples of other inference steps that can be measured:
# To use, uncomment the function and assign it to `generate_fn`
# def get_pred_ids(inputs):
# # Inference pass with predicted token ids generation
# predicted_ids = model.generate(**inputs)
# return predicted_ids
# def gen_and_dec(inputs):
# # Inference pass with generation and decoding
# predicted_ids = get_pred_ids(inputs)
# transcription = []
# for bs in range(args.batch_size):
# for rs in range(args.num_return_sequences):
# transcription.append(
# args.tokenizer.batch_decode(
# predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
# )[0]
# )
# return transcription
generate_fn = get_logits
if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(init_inputs)
generate_fn(iter_inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
if args.benchmark_type == "hf-ort":
# Turn profiling off to stop appending to log
old_logname = model.decoder_with_past.session.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# PyTorch evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
time_fn(args, generate_fn, init_inputs)
measure_fn(args, generate_fn, init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
time_fn(args, generate_fn, iter_inputs)
measure_fn(args, generate_fn, iter_inputs)
def run_ort_inference(args, init_inputs, iter_inputs, model):
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
)
setattr(args, "io_binding", io_binding) # noqa: B010
return io_binding, kv_cache_ortvalues
return inputs, kv_cache_ortvalues
def with_io_binding(io_binding):
# Inference pass with IO binding
model.run_with_iobinding(io_binding)
def without_io_binding(inputs):
# Inference pass without IO binding
outputs = model.run(None, inputs)
return outputs
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
kv_cache_ortvalues = {}
if args.profile:
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
# Turn profiling off to stop appending to log file
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
# Re-initialize model for new log file instead of appending to old log file
model = get_model(args)
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
# Turn profiling off to stop appending to log
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# ORT evaluations
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_init_inputs)
measure_fn(args, generate_fn, ort_init_inputs)
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
time_fn(args, generate_fn, ort_iter_inputs)
measure_fn(args, generate_fn, ort_iter_inputs)
def run_inference(args, init_inputs, iter_inputs, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, init_inputs, iter_inputs, model)
elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
run_ort_inference(args, init_inputs, iter_inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
def get_args(rank=0):
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=[
"hf-pt-eager",
"hf-pt-compile",
"hf-ort",
"ort-msft",
"ort-convert-to-onnx",
],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=True,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
)
# Args for choosing the model
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"--hf-pt-dir-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
type=str,
default="",
help="Path to ONNX model",
)
# Args for running and evaluating the model
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
default="32 64 128 256 512",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=10)
parser.add_argument("--seed", type=int, default=2)
# Args for decoding logic
parser.add_argument("--max-length", type=int, default=32)
parser.add_argument("--num-return-sequences", type=int, default=1)
# Args for accessing detailed info
parser.add_argument("--profile", default=False, action="store_true")
parser.add_argument(
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
)
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--cache-dir",
type=str,
required=True,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
args.device = "cuda"
# Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.sequence_lengths = args.sequence_lengths.split(" ")
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
# Check that only one (batch_size, sequence_length) combination is set for profiling
if args.profile:
assert len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1, (
"Please provide only one (batch_size, sequence_length) combination for profiling"
)
return args
def main():
rank = get_rank()
world_size = get_size()
args = get_args(rank)
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
args.rank = rank
args.world_size = world_size
tokenizer = AutoTokenizer.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
config = AutoConfig.from_pretrained(
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
)
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "tokenizer", tokenizer) # noqa: B010
setattr(args, "config", config) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
# Get model and model info
model = get_model(args)
ort_model_inputs_len = get_ort_model_inputs_len(args, model)
# Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010
else:
setattr(args, "use_buffer_share", False) # noqa: B010
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
if args.rank == 0:
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
setattr(args, "batch_size", int(batch_size)) # noqa: B010
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
run_inference(args, init_inputs, iter_inputs, model)
if __name__ == "__main__":
main()
@@ -0,0 +1,488 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import json
import logging
import os
import subprocess
import torch
from benchmark_helper import setup_logger
from metrics import BenchmarkRecord
logger = logging.getLogger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-b",
"--batch-sizes",
type=str,
default="1 2",
)
parser.add_argument(
"-s",
"--sequence-lengths",
type=str,
default="8 16 32 64 128 256 512",
)
parser.add_argument(
"-w",
"--warmup-runs",
type=int,
default=5,
)
parser.add_argument(
"-n",
"--num-runs",
type=int,
default=1000,
)
parser.add_argument(
"--hf-pt-eager",
default=False,
action="store_true",
help="Benchmark in PyTorch without `torch.compile`",
)
parser.add_argument(
"--hf-pt-compile",
default=False,
action="store_true",
help="Benchmark in PyTorch with `torch.compile`",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
parser.add_argument(
"--ort-msft-model-path",
type=str,
default="",
help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx",
)
parser.add_argument(
"--ort-convert-to-onnx-model-path",
type=str,
default="",
help="Path to ONNX model from convert_to_onnx",
)
parser.add_argument(
"--cache-dir",
type=str,
default="./model_cache",
help="Cache dir where Hugging Face files are stored",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name in Hugging Face",
)
parser.add_argument(
"--precision",
type=str,
required=True,
choices=["int4", "int8", "fp16", "fp32"],
help="Precision to run model",
)
parser.add_argument(
"--device",
type=str,
required=True,
choices=["cpu", "cuda", "rocm"],
help="Device to benchmark models",
)
parser.add_argument(
"--device-id",
type=int,
default=0,
help="GPU device ID",
)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="Print detailed logs",
)
parser.add_argument(
"--timeout",
type=int,
default=10,
help="Number of mins to attempt the benchmark before moving on",
)
parser.add_argument(
"--log-folder",
type=str,
default=None,
help="Path to folder to save logs and results",
)
args = parser.parse_args()
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
log_folder_name = f"./{args.model_size}_{args.precision}"
if not args.log_folder:
args.log_folder = log_folder_name
os.makedirs(args.log_folder, exist_ok=True)
# Convert timeout value to secs
args.timeout *= 60
return args
def process_log_file(device_id, log_file, base_results):
entries = []
batch_size, sequence_length, step = None, None, None
latency_s, latency_ms, throughput, memory = None, None, None, None
batch_pattern = "Batch Size: "
sequence_pattern = "Sequence Length: "
prompt_step_pattern = "to get past_key_values"
per_token_step_pattern = "with past_key_values"
latency_pattern = "Latency: "
throughput_pattern = "Throughput: "
memory_pattern = "peak="
with open(log_file) as f:
for input_line in f:
line = input_line.replace("\n", "")
if batch_pattern in line:
batch_size = int(line[len(batch_pattern) :])
elif sequence_pattern in line:
sequence_length = int(line[len(sequence_pattern) :])
elif prompt_step_pattern in line:
step = "prompt"
elif per_token_step_pattern in line:
step = "per-token"
elif latency_pattern in line:
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
latency_ms = latency_s * 1000
elif throughput_pattern in line:
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
elif memory_pattern in line:
if "CPU" in line:
# Example format for log entry:
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
else:
# Example format for log entry:
# GPU memory usage: before=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 69637.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] peak=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 73861.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}]
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
usage = json.loads(peak)[device_id]["max_used_MB"]
memory = float(usage) / 1000
# Append log entry to list of entries
entry = base_results + [ # noqa: RUF005
batch_size,
sequence_length,
step,
latency_s,
latency_ms,
throughput,
memory,
]
entries.append(entry)
return entries
def save_results(results, filename):
import pandas as pd # noqa: PLC0415
df = pd.DataFrame(
results,
columns=[
"Warmup Runs",
"Measured Runs",
"Model Name",
"Engine",
"Precision",
"Device",
"Batch Size",
"Sequence Length",
"Step",
"Latency (s)",
"Latency (ms)",
"Throughput (tps)",
"Memory (GB)",
],
)
# Set column types
df["Warmup Runs"] = df["Warmup Runs"].astype("int")
df["Measured Runs"] = df["Measured Runs"].astype("int")
df["Batch Size"] = df["Batch Size"].astype("int")
df["Sequence Length"] = df["Sequence Length"].astype("int")
df["Latency (s)"] = df["Latency (s)"].astype("float")
df["Latency (ms)"] = df["Latency (ms)"].astype("float")
df["Throughput (tps)"] = df["Throughput (tps)"].astype("float")
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
# get package name and version
import pkg_resources # noqa: PLC0415
installed_packages = pkg_resources.working_set
installed_packages_list = sorted(
[f"{i.key}=={i.version}" for i in installed_packages if i.key in ["onnxruntime", "onnxruntime-gpu"]]
)
ort_pkg_name = ""
ort_pkg_version = ""
if installed_packages_list:
ort_pkg_name = installed_packages_list[0].split("==")[0]
ort_pkg_version = installed_packages_list[0].split("==")[1]
# Save results to csv with standard format
records = []
for _, row in df.iterrows():
if row["Engine"] in ["optimum-ort", "onnxruntime"]:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], "onnxruntime", row["Device"], ort_pkg_name, ort_pkg_version
)
elif row["Engine"] in ["pytorch-eager", "pytorch-compile"]:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], "pytorch", row["Device"], torch.__name__, torch.__version__
)
else:
record = BenchmarkRecord(row["Model Name"], row["Precision"], row["Engine"], row["Device"], "", "")
record.config.warmup_runs = row["Warmup Runs"]
record.config.measured_runs = row["Measured Runs"]
record.config.batch_size = row["Batch Size"]
record.config.seq_length = row["Sequence Length"]
record.config.customized["measure_step"] = row["Step"]
record.config.customized["engine"] = row["Engine"]
record.metrics.customized["latency_s_mean"] = row["Latency (s)"]
record.metrics.latency_ms_mean = row["Latency (ms)"]
record.metrics.customized["throughput_tps"] = row["Throughput (tps)"]
record.metrics.max_memory_usage_GB = row["Memory (GB)"]
records.append(record)
BenchmarkRecord.save_as_csv(filename, records)
BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
logger.info(f"Results saved in {filename}!")
def benchmark(args, benchmark_cmd, engine):
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
log_path = os.path.join(args.log_folder, log_filename)
with open(log_path, "w") as log_file:
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
try:
process.wait(args.timeout)
except subprocess.TimeoutExpired:
process.kill()
# Create entries for csv
logger.info("Gathering data from log files...")
base_results = [args.warmup_runs, args.num_runs, args.model_name, engine, args.precision, args.device]
results = process_log_file(args.device_id, log_path, base_results)
return results
def main():
args = get_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
all_results = []
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
# Benchmark PyTorch without torch.compile
if args.hf_pt_eager:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-pt-eager",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch without torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-eager")
all_results.extend(results)
# Benchmark PyTorch with torch.compile
if args.hf_pt_compile:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-pt-compile",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark PyTorch with torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-compile")
all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
if args.hf_ort_dir_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"hf-ort",
"--hf-ort-dir-path",
args.hf_ort_dir_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
"--auth",
]
logger.info("Benchmark Optimum + ONNX Runtime")
results = benchmark(args, benchmark_cmd, "optimum-ort")
all_results.extend(results)
# Benchmark Microsoft model in ONNX Runtime
if args.ort_msft_model_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"ort-msft",
"--ort-model-path",
args.ort_msft_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark Microsoft model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "ort-msft")
all_results.extend(results)
# Benchmark convert_to_onnx model in ONNX Runtime
if args.ort_convert_to_onnx_model_path:
benchmark_cmd = [
"python",
"-m",
"models.llama.benchmark",
"--benchmark-type",
"ort-convert-to-onnx",
"--ort-model-path",
args.ort_convert_to_onnx_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--batch-sizes",
args.batch_sizes,
"--sequence-lengths",
args.sequence_lengths,
"--device",
args.device,
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
"--cache-dir",
args.cache_dir,
]
logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime")
all_results.extend(results)
csv_file = f"{args.model_size}_{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_results, os.path.join(args.log_folder, csv_file))
if __name__ == "__main__":
main()
@@ -0,0 +1,608 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This is an end-to-end benchmarking script for the Hugging Face LLaMA-2 model.
#
# Prerequisites:
# 1) Install `huggingface-cli`:
#
# $ pip install huggingface_hub
#
# 2) Authenticate with Hugging Face's CLI:
#
# $ huggingface-cli login
#
# 3) Accept Meta's license in Hugging Face to access the models at https://huggingface.co/meta-llama/
#
# 4) Install the latest ONNX Runtime version
#
# $ pip install onnxruntime-gpu
#
# 5) Install flash attention v2
#
# $ pip install flash-attn --no-build-isolation
#
# 6) Install bitsandbytes
#
# $ pip install bitsandbytes
from __future__ import annotations
import argparse
import datetime
import gc
import itertools
import json
import logging
import os
import textwrap
import time
import numpy as np
import pandas as pd
import torch
from benchmark_helper import setup_logger
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import onnxruntime as ort
logger = logging.getLogger(__name__)
def get_model(args: argparse.Namespace):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
model = None
if args.onnx_precision == "int4" and args.device == "cuda":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
max_memory={args.device_id: "80GB"},
)
else:
try:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)
except Exception as e:
# When flash_attention or sdpa doesn't support a model, it throws an exception.
# Rather than stopping a process, run as eager mode.
print("Try to load a model using eager mode: ", e)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="eager",
).to(args.target_device)
model.eval()
if args.benchmark_type == "pt-compile":
model = torch.compile(model)
else:
sess_options = ort.SessionOptions()
ep = (
("CUDAExecutionProvider", {"device_id": args.device_id})
if args.device == "cuda"
else "CPUExecutionProvider"
)
model = ort.InferenceSession(args.onnx_model_path, sess_options=sess_options, providers=[ep])
return model
def run_inference(args, model, runs, inputs, outputs):
if args.benchmark_type == "pt-compile":
with torch.no_grad():
outputs = model(**inputs)
# Synchronize inputs
io_binding = None
if args.benchmark_type in {"pt-eager", "pt-compile"}:
if args.device != "cpu":
torch.cuda.synchronize(args.target_device)
else:
io_binding = add_io_bindings_as_tensors(model, inputs, outputs, args.use_fp16, args.use_buffer_share)
io_binding.synchronize_inputs()
# Run inference
start = time.perf_counter()
for _ in range(runs):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
with torch.no_grad():
outputs = model(**inputs)
if args.device != "cpu":
torch.cuda.synchronize(args.target_device)
else:
model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
end = time.perf_counter()
avg = (end - start) / runs
return avg, outputs
def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
clear_cache()
inputs, outputs = get_initial_inputs_and_outputs(
config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
)
_, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
return inputs, outputs
def clear_cache():
gc.collect()
torch.cuda.empty_cache()
def save_results(results, filename, gen_length):
df = pd.DataFrame(
results,
columns=[
"Batch Size",
"Prompt Length",
"Prompt Processing Latency (ms)",
"Prompt Processing Throughput (tps)",
"Sampling Latency (ms)",
"Sampling Throughput (tps)",
"First Token Generated Latency (ms)",
"First Token Generated Throughput (tps)",
f"Average Latency of First {gen_length // 2} Tokens Generated (ms)",
f"Average Throughput of First {gen_length // 2} Tokens Generated (tps)",
f"Average Latency of First {gen_length} Tokens Generated (ms)",
f"Average Throughput of First {gen_length} Tokens Generated (tps)",
"Wall-Clock Latency (s)",
"Wall-Clock Throughput (tps)",
],
)
df.to_csv(filename, index=False)
logger.info(f"Results saved in {filename}!")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=["pt-eager", "pt-compile", "ort"],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=False,
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
)
parser.add_argument(
"-a",
"--auth",
default=False,
action="store_true",
help="Use Hugging Face authentication token to access model",
)
parser.add_argument(
"-t",
"--trust",
default=False,
action="store_true",
help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
)
parser.add_argument(
"-c",
"--cache-dir",
type=str,
default=os.path.join(".", "model_cache"),
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(model_name, cache_dir=cache_dir)`.",
)
parser.add_argument(
"--hf-dir-path",
type=str,
default="",
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(folder_path)`.",
)
parser.add_argument(
"-o",
"--onnx-model-path",
required=False,
help="Path to ONNX model",
)
parser.add_argument(
"-f",
"--prompts-file",
required=True,
default=os.path.join(".", "models", "llama", "prompts.json"),
help="JSON file containing entries in the format 'prompt length: prompt' where prompt length = tokenized length of prompt",
)
parser.add_argument(
"--use_buffer_share",
default=False,
action="store_true",
help="Use when GroupQueryAttention (GQA) is in ONNX model",
)
(
parser.add_argument(
"--anomaly-filtering",
default=False,
action="store_true",
help="Use this flag to filter anomaly accelerator times for tokens generated. \
This may give more accurate latency and throughput metrics for tokens generated. \
Wall-clock metrics are still reported with anomaly times though.",
),
)
parser.add_argument(
"-b",
"--batch-sizes",
default="1 2",
)
parser.add_argument(
"-s",
"--prompt-lengths",
default="16 64 256 1024",
)
parser.add_argument(
"-p",
"--precision",
required=True,
type=str,
default="fp32",
choices=["int4", "int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"-g",
"--generation-length",
type=int,
default=256,
help="Number of new tokens to generate",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=100)
parser.add_argument("--seed", type=int, default=2)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
# Check that paths have been specified for any benchmarking with ORT
if args.benchmark_type == "ort":
assert args.onnx_model_path, "Please specify a path to `--onnx-model-path`"
args.batch_sizes = args.batch_sizes.split(" ")
args.prompt_lengths = args.prompt_lengths.split(" ")
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
setattr(args, "onnx_precision", args.precision) # noqa: B010
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32
engine = "ort" if args.benchmark_type == "ort" else "pt"
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "torch_dtype", torch_dtype) # noqa: B010
setattr(args, "engine", engine) # noqa: B010
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
args.use_buffer_share = args.use_buffer_share and engine == "ort"
return args
def main():
args = get_args()
setup_logger(False)
logger.info(args.__dict__)
# Get prompts and prompt sizes
size_to_prompt = None
with open(args.prompts_file) as f:
size_to_prompt = json.load(f, object_hook=lambda d: {int(k): v for k, v in d.items()})
# Get config, tokenizer, and model
config = AutoConfig.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.trust,
)
tokenizer = AutoTokenizer.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.trust,
)
model = get_model(args)
all_csv_metrics = []
for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
batch_size, prompt_length = int(batch_size), int(prompt_length) # noqa: PLW2901
logger.info(f"Running batch size = {batch_size}, prompt length = {prompt_length}")
clear_cache()
max_length = prompt_length + args.generation_length
if prompt_length not in size_to_prompt:
raise NotImplementedError(
textwrap.dedent(
f"""
A prompt of size {prompt_length} was not found in '{args.prompts_file}'. There are a couple of solutions to fix this.
1) You can change one of the keys in '{args.prompts_file}' to be {prompt_length}.
If {prompt_length} < actual prompt's length, the benchmark E2E tool will repeat the first word in the prompt until {prompt_length} = actual prompt's length.
If {prompt_length} > actual prompt's length, the benchmark E2E tool will automatically trim the actual prompt's length so that {prompt_length} = actual prompt's length.
2) You can add a new key-value entry in '{args.prompts_file}' of the form '{prompt_length}': 'your prompt goes here'.
"""
)
)
prompt = [size_to_prompt[prompt_length]] * batch_size
csv_metrics = [batch_size, prompt_length]
try:
# Measure prompt processing
logger.info("Measuring prompt processing...")
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)
# Calculate prompt metrics
accelerator_prompt_latency_ms = accelerator_prompt_latency_s * 1000
accelerator_prompt_thrpt = batch_size * (prompt_length / accelerator_prompt_latency_s)
logger.info(f"Average Latency of Prompt Processing: {accelerator_prompt_latency_ms} ms")
logger.info(
f"Average Throughput of Prompt Processing: {batch_size * (prompt_length / accelerator_prompt_latency_s)} tps"
)
csv_metrics.extend([accelerator_prompt_latency_ms, accelerator_prompt_thrpt])
# Measure token generation
logger.info("Measuring token generation...")
clear_cache()
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
all_token_ids = inputs["input_ids"].clone()
current_length = all_token_ids.shape[-1]
num_heads = config.num_key_value_heads
head_size = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
has_eos = torch.zeros(batch_size, device=args.target_device, dtype=torch.bool)
# 0th entry will have prompt accelerator time, 1st entry onwards will have token generation accelerator time
accelerator_times = []
sampling_times = [] # cost to sample after each model run
wall_clock_start_time = time.perf_counter()
while current_length <= max_length:
# Run inference
accelerator_time_latency_s, outputs = run_inference(args, model, 1, inputs, outputs)
accelerator_times.append(accelerator_time_latency_s)
# Sample with argmax (greedy search)
sampling_start_time = time.perf_counter()
if outputs["logits"].shape[1] > 1:
prompt_end_indices = inputs["attention_mask"].sum(1) - 1
idxs = (
prompt_end_indices.unsqueeze(dim=1)
.repeat(1, config.vocab_size)
.view(batch_size, 1, config.vocab_size)
)
next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze()
else:
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
# Check if we previously reached EOS token id or if generated token id is EOS token id
has_eos = has_eos | next_tokens == tokenizer.eos_token_id
# Determine which new tokens to add to list of all token ids
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
sampling_end_time = time.perf_counter()
sampling_times.append(sampling_end_time - sampling_start_time)
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
current_length += 1
# Update inputs for next inference run
inputs["input_ids"] = tokens_to_add
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
)
if "position_ids" in inputs:
inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
outputs["logits"].zero_()
# Update KV caches for next inference run
if args.engine == "pt":
# Update KV caches for PyTorch
inputs["past_key_values"] = outputs["past_key_values"]
elif not args.use_buffer_share:
# Update KV caches for ONNX Runtime if buffer sharing is not used
for i in range(config.num_hidden_layers):
inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]
new_sequence_length = inputs["attention_mask"].shape[1]
for i in range(config.num_hidden_layers):
present_key = torch.zeros(
batch_size,
num_heads,
new_sequence_length,
head_size,
device=args.target_device,
dtype=args.torch_dtype,
)
present_value = torch.zeros(
batch_size,
num_heads,
new_sequence_length,
head_size,
device=args.target_device,
dtype=args.torch_dtype,
)
outputs.update(
{
f"present.{i}.key": present_key.contiguous(),
f"present.{i}.value": present_value.contiguous(),
}
)
wall_clock_end_time = time.perf_counter()
# Filter out any anomaly accelerator times (e.g. for `torch.compile`)
accelerator_times.pop(0) # Remove prompt processing time
if args.anomaly_filtering:
anomaly_threshold_factor = 10
min_time_s = min(accelerator_times)
orig_size = len(accelerator_times)
accelerator_times = list(
filter(lambda acc_time: acc_time < anomaly_threshold_factor * min_time_s, accelerator_times)
)
new_size = len(accelerator_times)
logger.info(
f"Filtered out {orig_size - new_size} anomaly accelerator times that are {anomaly_threshold_factor}x greater than {min_time_s * 1000} ms..."
)
#######################################################
# Calculate sampling and first token generated metrics
#######################################################
# Calculate sampling metrics
avg_sampling_latency_s = sum(sampling_times) / len(sampling_times)
avg_sampling_latency_ms = avg_sampling_latency_s * 1000
avg_sampling_thrpt = batch_size * (1 / avg_sampling_latency_s)
logger.info(f"Average Latency of Sampling: {avg_sampling_latency_ms} ms")
logger.info(f"Average Throughput of Sampling: {avg_sampling_thrpt} tps")
# Calculate first token generated metrics
first_token_latency_s = accelerator_times[0]
first_token_latency_ms = first_token_latency_s * 1000
first_token_thrpt = batch_size * (1 / first_token_latency_s)
logger.info(f"Latency of First Token Generated: {first_token_latency_ms} ms")
logger.info(f"Throughput of First Token Generated: {first_token_thrpt} tps")
####################################################
# Calculate first `halfway` token generated metrics
####################################################
halfway = args.generation_length // 2
halfway_token_latency_s = sum(accelerator_times[:halfway]) / len(accelerator_times[:halfway])
halfway_token_latency_ms = halfway_token_latency_s * 1000
halfway_token_thrpt = batch_size * (1 / halfway_token_latency_s)
logger.info(f"Average Latency of First {halfway} Tokens Generated: {halfway_token_latency_ms} ms")
logger.info(f"Average Throughput of First {halfway} Tokens Generated: {halfway_token_thrpt} tps")
#########################################
# Calculate all tokens generated metrics
#########################################
all_token_latency_s = sum(accelerator_times) / len(accelerator_times)
all_token_latency_ms = all_token_latency_s * 1000
all_token_thrpt = batch_size * (1 / all_token_latency_s)
logger.info(
f"Average Latency of First {args.generation_length} Tokens Generated: {all_token_latency_ms} ms"
)
logger.info(f"Average Throughput of First {args.generation_length} Tokens Generated: {all_token_thrpt} tps")
###############################
# Calculate wall clock metrics
###############################
wall_clock_latency_s = wall_clock_end_time - wall_clock_start_time
wall_clock_thrpt = batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)
logger.info(f"Wall-Clock Latency: {wall_clock_latency_s} s")
logger.info(
f"Wall-Clock Throughput: {batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)} tps"
)
# Add metrics to CSV
logger.info("Adding results to CSV")
csv_metrics.extend(
[
avg_sampling_latency_ms,
avg_sampling_thrpt,
first_token_latency_ms,
first_token_thrpt,
halfway_token_latency_ms,
halfway_token_thrpt,
all_token_latency_ms,
all_token_thrpt,
wall_clock_latency_s,
wall_clock_thrpt,
]
)
all_csv_metrics.append(csv_metrics)
except Exception as e:
logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")
filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_csv_metrics, filename, args.generation_length)
if __name__ == "__main__":
main()
@@ -0,0 +1,57 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import os
import torch.distributed as dist
def init_dist():
if "LOCAL_RANK" in os.environ:
int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
else:
# don't need to do init for single process
pass
def _get_comm():
try:
from mpi4py import MPI # noqa: PLC0415
comm = MPI.COMM_WORLD
return comm
except ImportError:
return None
def get_rank():
comm = _get_comm()
return comm.Get_rank() if comm is not None else 0
def get_size():
comm = _get_comm()
return comm.Get_size() if comm is not None else 1
def barrier():
comm = _get_comm()
if comm is not None:
comm.Barrier()
def print_out(*args):
if get_rank() == 0:
print(*args)
@@ -0,0 +1,504 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer
from transformers.cache_utils import DynamicCache
from onnxruntime import InferenceSession, OrtValue
# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if use_past_kv:
# Shape: (batch_size, 1)
position_ids = position_ids[:, -1].unsqueeze(-1)
# Shape: (batch_size, sequence_length)
return position_ids
# Inputs for first pass to get initial past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, sequence_length)
# position_ids: (batch_size, sequence_length)
def get_sample_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
engine: str = "pt",
return_dict: bool = False,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
if not return_dict:
# For export
return (input_ids, attention_mask, position_ids)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return inputs
# Inputs for subsequent passes with past_key_values
# input_ids: (batch_size, 1)
# attention_mask: (batch_size, past_sequence_length + 1)
# position_ids: (batch_size, 1)
# past_key: (batch_size, num_heads, past_sequence_length, head_size)
# past_value: (batch_size, num_heads, past_sequence_length, head_size)
def get_sample_with_past_kv_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
past_seq_len: int,
use_fp16: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
# position_ids is of shape (batch_size, 1)
position_ids = get_position_ids(attention_mask, use_past_kv=True)
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv) if engine == "ort" else [(kv[0].to(device), kv[1].to(device)) for kv in past_kv]
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Inputs for all passes with past_key_values
# input_ids: (batch_size, sequence_length)
# attention_mask: (batch_size, past_sequence_length + sequence_length)
# position_ids: (batch_size, sequence_length)
# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
# For models with GQA, kv_sequence_length = max_sequence_length
# For models without GQA, kv_sequence_length = past_sequence_length
def get_merged_sample_with_past_kv_inputs(
config: AutoConfig,
device: torch.device,
batch_size: int,
seq_len: int,
past_seq_len: int,
max_seq_len: int,
use_fp16: bool = False,
use_buffer_share: bool = False,
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
past_kv = (
flatten_past_kv_inputs(past_kv) if engine == "ort" else [(kv[0].to(device), kv[1].to(device)) for kv in past_kv]
)
if not return_dict:
# For export
assert isinstance(past_kv, list)
return (input_ids, attention_mask, position_ids, past_kv)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
if use_buffer_share:
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
else:
assert isinstance(past_kv, list)
inputs["past_key_values"] = past_kv
return inputs
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(
config: AutoConfig,
batch_size: int,
past_seq_len: int,
seq_len: int,
max_seq_len: int,
use_fp16: bool,
use_buffer_share: bool,
split_kv: bool,
):
np_dtype = np.float16 if use_fp16 else np.float32
head_size = config.hidden_size // config.num_attention_heads
if not split_kv:
ort_inputs = {
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
"attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
"k_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"v_cache": np.random.rand(
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
).astype(np_dtype),
"pos": np.array(past_seq_len, dtype=np.int64),
}
else:
ort_inputs = {
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
"attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
np.int32
),
"pos": np.array(past_seq_len, dtype=np.int64),
}
for i in range(config.num_hidden_layers):
ort_inputs.update(
{
f"k_{i}_cache": np.random.rand(
batch_size, config.num_attention_heads, past_seq_len, head_size
).astype(np_dtype),
f"v_{i}_cache": np.random.rand(
batch_size, config.num_attention_heads, past_seq_len, head_size
).astype(np_dtype),
}
)
if use_buffer_share:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
num_heads = config.num_key_value_heads // world_size
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
past_kv = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
)
for _ in range(config.num_hidden_layers)
]
return past_kv
# Convert list of past_key_values to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
past_kv = {}
for i, (past_k, past_v) in enumerate(past_key_values):
if isinstance(past_key_values, DynamicCache):
past_kv[f"past_key_values_key_cache_{i}"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values_value_cache_{i}"] = past_v.detach().cpu().numpy()
else:
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
return past_kv
# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
pt_inputs: dict,
use_buffer_share: bool = False,
past_seq_len: int = 0,
max_seq_len: int = 2048,
):
ort_inputs = {}
for k, v in pt_inputs.items():
if isinstance(v, np.ndarray):
ort_inputs[k] = v
elif k == "past_key_values":
ort_inputs.update(flatten_past_kv_inputs(v))
else:
ort_inputs[k] = v.detach().cpu().numpy()
# Reshape KV caches if using past-present-share-buffer
if use_buffer_share:
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
return ort_inputs
# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
for k, v in ort_inputs.items():
# Allocate new buffers with max_sequence_length for GQA
if "cache" in k or "past_key_values" in k:
# Copy v (BxSxPxH) into new_v (BxSxMxH)
batch_size, num_heads, _, head_size = v.shape
new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
ort_inputs[k] = new_v
return ort_inputs
# Verify ONNX Runtime inputs with model
def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
# Check that all model inputs will be provided
model_inputs = {model_input.name for model_input in model.get_inputs()}
user_inputs = set(ort_inputs.keys())
missing_inputs = model_inputs - user_inputs
if len(missing_inputs):
print(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")
# Remove unnecessary inputs from model inputs
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
del ort_inputs[unnecessary_input]
return ort_inputs
# Add IO bindings for execution providers using OrtValue
# Use when you need to run inference once or twice to save memory
def add_io_bindings_as_ortvalues(
model: InferenceSession,
ort_inputs: dict,
device: str,
device_id: int,
use_buffer_share: bool,
kv_cache_ortvalues: dict,
):
io_binding = model.io_binding()
model_inputs = {i.name for i in model.get_inputs()}
for k, v in ort_inputs.items():
# Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
# GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
# but `position_ids` is used as a PyTorch model input
if k not in model_inputs:
continue
# Bind OrtValue inputs to device
if use_buffer_share and ("cache" in k or "past_key_values" in k):
if k not in kv_cache_ortvalues:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
kv_cache_ortvalues[k] = v_device
else:
kv_cache_ortvalues[k].update_inplace(v)
io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
else:
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
io_binding.bind_ortvalue_input(k, v_device)
for output in model.get_outputs():
name = output.name
if use_buffer_share and ("out" in name or "present" in name):
# Bind present KV cache outputs to past KV cache inputs in order to buffer share
input_name = name.replace("out", "cache").replace("present", "past_key_values")
io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
else:
io_binding.bind_output(name, device_type=device, device_id=device_id)
return io_binding, kv_cache_ortvalues
# Add IO bindings for execution providers using PyTorch tensors
# Use when you need to run inference many times
def add_io_bindings_as_tensors(
model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
):
# Verify model inputs
inputs = verify_ort_inputs(model, inputs)
device = None
pt_to_np = {
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.float16": np.float16,
"torch.float32": np.float32,
}
# Bind inputs/outputs to IO binding
io_binding = model.io_binding()
for k, v in inputs.items():
io_binding.bind_input(
name=k,
device_type=v.device.type,
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
device = v.device
for output in model.get_outputs():
name = output.name
# Bind KV cache outputs to KV cache inputs
v = (
inputs[name.replace("present", "past_key_values")]
if use_buffer_share and "present" in name
else outputs[name]
)
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
return io_binding
# Get actual inputs when using real data (instead of sample data) and initialize outputs
def get_initial_inputs_and_outputs(
config: AutoConfig,
tokenizer: AutoTokenizer,
requested_length: int,
prompt: list[str],
device: torch.device,
use_fp16: bool,
use_buffer_share: bool,
engine: str,
):
tokenizer.pad_token = tokenizer.eos_token
encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
torch_dtype = torch.float16 if use_fp16 else torch.float32
# input_ids: pad token id is 0
# attention_mask: pad token id is 0
# position_ids: pad token id is 1
input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
position_ids = get_position_ids(attention_mask, use_past_kv=False)
# Check if tokenized prompt length matches the requested prompt length
tokenized_length = input_ids.shape[-1]
if tokenized_length > requested_length:
# Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
input_ids = input_ids[:, :requested_length]
attention_mask = attention_mask[:, :requested_length]
position_ids = get_position_ids(attention_mask, use_past_kv=False)
elif tokenized_length < requested_length:
# Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
for _ in range(requested_length - tokenized_length):
input_ids = torch.hstack((input_ids_first_col, input_ids))
attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
position_ids = get_position_ids(attention_mask, use_past_kv=False)
tokenized_length = input_ids.shape[-1]
assert tokenized_length == requested_length
# Create inputs
inputs = {
"input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
"attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
"position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
}
if engine != "ort":
inputs["past_key_values"] = []
# Get shape of KV cache inputs
batch_size, sequence_length = input_ids.shape
max_sequence_length = config.max_position_embeddings
num_heads = config.num_key_value_heads
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
# Create KV cache inputs
for i in range(config.num_hidden_layers):
past_key = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
past_value = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
if engine == "ort":
inputs.update(
{
f"past_key_values.{i}.key": past_key.contiguous(),
f"past_key_values.{i}.value": past_value.contiguous(),
}
)
else:
inputs["past_key_values"].append((past_key, past_value))
outputs = None
if engine == "ort":
# Create outputs
logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
outputs = {"logits": logits.contiguous()}
if not use_buffer_share:
for i in range(config.num_hidden_layers):
present_key = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
present_value = torch.zeros(
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
)
outputs.update(
{f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
)
return inputs, outputs
@@ -0,0 +1,343 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations
import argparse
import logging
import os
import time
import numpy as np
import packaging.version as pv
import torch
from benchmark_helper import setup_logger
from dist_settings import get_rank, get_size
from llama_inputs import (
add_io_bindings_as_ortvalues,
convert_inputs_for_ort,
get_merged_sample_with_past_kv_inputs,
get_sample_inputs,
get_sample_with_past_kv_inputs,
verify_ort_inputs,
)
from llama_torch import setup_torch_model
from models.torch_export_patches.cache_helper import make_dynamic_cache
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from transformers.cache_utils import DynamicCache
import onnxruntime as ort
logger = logging.getLogger("")
def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
max_sequence_length = config.max_position_embeddings
return past_sequence_length, curr_sequence_length, max_sequence_length
def get_inputs(args: argparse.Namespace, config: AutoConfig):
# Dummy values for parity
world_size = get_size()
batch_size = 2
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)
if args.merged:
inputs = get_merged_sample_with_past_kv_inputs(
config,
args.device,
batch_size,
seq_len=sequence_length,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
use_fp16=args.use_fp16,
use_buffer_share=args.use_buffer_share,
return_dict=True,
world_size=world_size,
)
elif args.use_past_kv:
inputs = get_sample_with_past_kv_inputs(
config,
args.device,
batch_size,
sequence_length,
use_fp16=args.use_fp16,
return_dict=True,
world_size=world_size,
)
else:
inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
return inputs
def torch_deepcopy(value):
if isinstance(value, (int, float, str)):
return value
if isinstance(value, tuple):
return tuple(torch_deepcopy(v) for v in value)
if isinstance(value, list):
return [torch_deepcopy(v) for v in value]
if isinstance(value, set):
return {torch_deepcopy(v) for v in value}
if isinstance(value, dict):
return {k: torch_deepcopy(v) for k, v in value.items()}
if isinstance(value, np.ndarray):
return value.copy()
if hasattr(value, "clone"):
return value.clone()
if isinstance(value, DynamicCache):
return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
# We should have a code using serialization, deserialization assuming a model
# cannot be exported without them.
raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
def verify_parity(
args: argparse.Namespace,
location: str,
use_auth_token: bool,
kv_cache_ortvalues: dict,
pytorch_model: None | torch.nn.Module = None,
config: None | AutoConfig = None,
):
# If it's running in a machine where GPU memory < 36GB, it should unload the model in GPU in time and free the GPU memory for ORT.
py_model = pytorch_model
if py_model is None:
config, py_model = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)
inputs = get_inputs(args, config)
if "past_key_values" in inputs and pv.Version(transformers_version) >= pv.Version("4.45"):
# Using DynamicCache
inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"])
# Run inference with PyTorch
inputs_after_deepcopy = torch_deepcopy(inputs)
if args.execution_provider != "cpu":
torch.cuda.synchronize()
start_time = time.time()
# If there is a cache in the inputs, we need to make a copy as the model modifies them inplace.
# DynamicCache inherits from torch.nn.Module in some version of transformers.
# We need to make the copy manually.
pt_outputs = py_model(**inputs_after_deepcopy).logits.detach().cpu().numpy()
if args.execution_provider != "cpu":
torch.cuda.synchronize()
end_time = time.time()
logger.info(f"PyTorch took {end_time - start_time} s")
if args.small_gpu and py_model is not None:
del py_model
torch.cuda.empty_cache()
# Run inference with ORT
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
inputs = convert_inputs_for_ort(
inputs,
use_buffer_share=args.use_buffer_share,
past_seq_len=past_sequence_length,
max_seq_len=max_sequence_length,
)
ep = f"{args.execution_provider.upper()}ExecutionProvider"
if ep == "CUDAExecutionProvider":
ep = (ep, {"device_id": args.rank})
ort_model = ort.InferenceSession(
args.onnx_model_path,
sess_options=ort.SessionOptions(),
providers=[ep],
)
inputs = verify_ort_inputs(ort_model, inputs)
# Add IO bindings for non-CPU execution providers
if args.execution_provider != "cpu":
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
ort_model,
ort_inputs=inputs,
device=args.execution_provider,
device_id=int(args.rank),
use_buffer_share=args.use_buffer_share,
kv_cache_ortvalues=kv_cache_ortvalues,
)
io_binding.synchronize_inputs()
start_time = time.time()
ort_model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
end_time = time.time()
ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
del ort_model
else:
start_time = time.time()
ort_outputs = ort_model.run(None, inputs)
end_time = time.time()
ort_outputs = ort_outputs[0] # Get logits
logger.info(f"ONNX Runtime took {end_time - start_time} s")
# Compare PyTorch and ONNX Runtime accuracy
tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
if not parity:
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
return kv_cache_ortvalues
def get_args(argv: list[str]):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model_name",
required=False,
help="Model name in Hugging Face",
)
parser.add_argument(
"-t",
"--torch_model_directory",
required=False,
default=os.path.join("."),
help="Path to folder containing PyTorch model and associated files if saved on disk",
)
parser.add_argument(
"-o",
"--onnx_model_path",
required=True,
default=os.path.join("."),
help="Path to ONNX model (with external data files saved in the same folder as the model)",
)
parser.add_argument(
"-ep",
"--execution_provider",
required=False,
default="cpu",
choices=["cpu", "cuda", "rocm"],
help="Execution provider to verify parity with",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Print verbose logs",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"-p",
"--use_past_kv",
action="store_true",
help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
)
parser.set_defaults(use_past_kv=False)
parser.add_argument(
"-g",
"--use_buffer_share",
action="store_true",
help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
)
parser.set_defaults(use_buffer_share=False)
parser.add_argument(
"--merged",
action="store_true",
help="Use merged model (i.e. decoder_merged_model.onnx).",
)
parser.set_defaults(merged=False)
parser.add_argument(
"-fp",
"--precision",
required=True,
choices=["int4", "int8", "fp16", "fp32"],
help="Precision of model",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./model_cache",
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
)
# The argument is used for CI mainly, because the CI machine has 24G GPU memory at most.
parser.add_argument(
"--small_gpu",
action="store_true",
help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ",
)
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
args.precision = (
"fp32"
if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
else "fp16"
)
return args
def main(argv: list[str] = []): # noqa: B006
args = get_args(argv)
setup_logger(args.verbose)
logger.info(f"Arguments: {args}")
rank = get_rank()
# Load model and config
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
args.rank = rank
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
use_auth_token = args.torch_model_directory == os.path.join(".")
location = args.model_name if use_auth_token else args.torch_model_directory
kv_cache_ortvalues = {}
if not args.merged:
verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
else:
config = llama = None
if not args.small_gpu:
config, llama = setup_torch_model(
args,
location,
use_auth_token,
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
device=args.device,
)
# Verify prompt processing in merged model (decoder_model.onnx)
args.use_past_kv = False
kv_cache_ortvalues = verify_parity(
args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
)
# Verify token generation in merged model (decoder_with_past_model.onnx)
args.use_past_kv = True
verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)
if __name__ == "__main__":
seed = 2
np.random.seed(seed)
torch.manual_seed(seed)
main()
@@ -0,0 +1,47 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import torch
from dist_settings import barrier, get_rank, get_size
from transformers import AutoConfig, AutoModelForCausalLM
logger = logging.getLogger("")
def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=None):
world_size = get_size()
logger.info(f"world_size: {world_size}")
rank = get_rank()
barrier()
if not os.path.exists(args.cache_dir):
os.makedirs(args.cache_dir, exist_ok=True)
for i in range(world_size):
if i == rank % (world_size):
l_config = AutoConfig.from_pretrained(
location, use_auth_token=auth, cache_dir=args.cache_dir, trust_remote_code=auth
)
l_config.use_cache = True
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
llama = AutoModelForCausalLM.from_pretrained(
location,
use_auth_token=auth,
trust_remote_code=auth,
config=l_config,
torch_dtype=torch_dtype,
cache_dir=args.cache_dir,
)
if world_size > 1:
llama.parallel_model()
if device:
llama.to(device)
llama.eval()
llama.requires_grad_(False)
barrier()
return l_config, llama
@@ -0,0 +1,108 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import numpy as np
import torch
from benchmark_helper import create_onnxruntime_session
from datasets import load_dataset
from llama_inputs import get_position_ids
from torch.nn.functional import pad
from torch.utils.data import DataLoader
from transformers import LlamaTokenizer
class QuantKVDataLoader:
def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
self.batch_size = 1
self.pad_max = args.pad_max
tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
dataset = load_dataset(args.smooth_quant_dataset, split="train")
dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
self.dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
collate_fn=self.collate_batch,
)
self.decoder_model = (
create_onnxruntime_session(
onnx_model_path,
args.execution_provider != "cpu", # use_gpu
provider=args.execution_provider,
verbose=args.verbose,
)
if onnx_model_path
else None
)
def collate_batch(self, batch):
input_ids_batched = []
attention_mask_batched = []
position_ids_batched = []
labels = []
for text in batch:
# Set inputs for model
input_ids = text["input_ids"]
attention_mask = torch.ones(len(input_ids))
position_ids = get_position_ids(attention_mask, use_past_kv=False)
label = len(input_ids) - 1
# Pad input data because all model inputs must have same shape
pad_len = self.pad_max - input_ids.shape[0]
input_ids = pad(input_ids, (0, pad_len), value=1)
attention_mask = pad(attention_mask, (0, pad_len), value=0)
position_ids = pad(position_ids, (0, pad_len), value=0)
input_ids_batched.append(input_ids)
attention_mask_batched.append(attention_mask)
position_ids_batched.append(position_ids)
labels.append(label)
input_ids_batched = torch.vstack(input_ids_batched)
attention_mask_batched = torch.vstack(attention_mask_batched)
position_ids_batched = torch.vstack(position_ids_batched)
labels = torch.tensor(labels)
return (input_ids_batched, attention_mask_batched, position_ids_batched), labels
def __iter__(self):
try:
for (input_ids, attention_mask, position_ids), labels in self.dataloader:
# Inputs for decoder_model.onnx
inputs = {
"input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
"attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
"position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
}
label = labels.detach().cpu().numpy()
if self.decoder_model is not None:
# Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
outputs = self.decoder_model.run(None, inputs)
for i in range(int((len(outputs) - 1) / 2)):
inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
past_sequence_length = inputs["past_key_values.0.key"].shape[2]
inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
inputs["position_ids"] = (
get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
)
# Yield (inputs, label) tuple for Intel's Neural Compressor:
# https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
yield (inputs, label)
except StopIteration:
return
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,821 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
#
# This script run benchmark of latency or peak memory usage of Longformer model inference.
# Please run convert_to_onnx.py to get onnx model before running benchmark.
#
# It is tested with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.11.0, transformers 4.18.0, CUDA 11.3 like:
# conda create -n gpu_env python=3.8
# conda activate gpu_env
# pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# pip3 install onnx transformers onnxruntime-gpu numpy sympy coloredlogs psutil py3nvml
# python benchmark_longformer.py
#
# When there is no parameter, pre-defined tests will run on the longformer-base-4096 model.
# Benchmark the latency:
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 512 1024 2048 4096 \
# --global_lengths 8 --onnx ./longformer-base-4096_fp16.onnx -t 100
#
# Benchmark GPU peak memory:
# export ORT_LONGFORMER_COMPACT_MEMORY=0
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
# --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
# export ORT_LONGFORMER_COMPACT_MEMORY=1
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
# --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
#
# By default, compact memory kernel is enabled. To disable it, set environment variable ORT_LONGFORMER_COMPACT_MEMORY=0.
import argparse
import csv
import logging
import math
import os
import re
import sys
import timeit
import traceback
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime
from typing import Any
import benchmark_helper
import numpy as np
import torch
from longformer_helper import PRETRAINED_LONGFORMER_MODELS, LongformerHelper, LongformerInputs
from transformers import LongformerModel
import onnxruntime
logger = logging.getLogger("")
def test_torch_latency(
device,
model,
model_name,
batch_sizes,
sequence_lengths,
global_lengths,
test_times,
num_threads,
) -> list[dict[str, Any]]:
if num_threads > 0:
torch.set_num_threads(num_threads)
results = []
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
for global_length in global_lengths:
logger.info(f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}")
inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
input_list = inputs.to_list()
_ = model(*input_list)
runtimes = timeit.repeat(lambda: model(*input_list), repeat=test_times, number=1) # noqa: B023
result = {
"engine": "torch", # TODO: test torchscript
"version": torch.__version__,
"device": "cuda",
"optimizer": "",
"precision": "fp32",
"io_binding": "",
"model_name": model_name,
"description": model_name + " [torch]",
"inputs": 3,
"threads": num_threads,
"batch_size": batch_size,
"sequence_length": sequence_length,
"global_length": global_length,
"datetime": str(datetime.now()),
"memory": "NA",
"diff_max": 0,
"diff_90_percentile": 0,
"diff_95_percentile": 0,
"diff_99_percentile": 0,
"use_compact_memory": "NA",
}
result.update(benchmark_helper.get_latency_result(runtimes, batch_size))
logger.info("%s", result)
results.append(result)
return results
def test_parity(device, model, ort_session, batch_size, sequence_length, global_length, verbose=True):
parameters = f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}"
logger.info(f"Comparing Torch and ORT outputs for {parameters}...")
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
ort_inputs = dummy_inputs.get_ort_inputs()
ort_outputs = ort_session.run(None, ort_inputs)
input_list = dummy_inputs.to_list()
torch_outputs = model(*input_list)
max_diff = np.amax(torch_outputs[0].cpu().numpy() - ort_outputs[0])
logger.info(f"last_state max diff = {max_diff}")
if verbose and (math.isnan(max_diff) or max_diff > 0.001):
print("torch last_state:", torch_outputs[0])
print("ort last_state:", ort_outputs[0])
return float(max_diff)
def test_ort_latency(
device,
model,
model_name,
description,
ort_session,
batch_sizes,
sequence_lengths,
global_lengths,
test_times,
num_threads,
optimizer=False,
precision="fp32",
disable_io_binding=False,
verbose=True,
use_compact_memory=False,
use_half4=False,
disable_parity=False,
) -> list[dict[str, Any]]:
results = []
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
for global_length in global_lengths:
assert global_length <= model.config.attention_window[0], (
"Limitation of current implementation: number of global token <= attention_window"
)
logger.info(
f"Testing batch_size={batch_size} sequence_length={sequence_length} global_length={global_length} "
f"optimizer={optimizer}, precision={precision} io_binding={not disable_io_binding}..."
)
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
# Run OnnxRuntime
ort_inputs = dummy_inputs.get_ort_inputs()
if verbose:
print(ort_inputs)
# run one query for warm up
ort_outputs = ort_session.run(None, ort_inputs)
result_template = {
"model_name": model_name,
"description": description,
"inputs": 3,
"engine": "OnnxRuntime",
"version": str(onnxruntime.__version__),
"device": "cuda",
"precision": str(precision),
"optimizer": int(optimizer),
"threads": int(num_threads),
"batch_size": int(batch_size),
"sequence_length": int(sequence_length),
"global_length": int(global_length),
"test_times": int(test_times),
"datetime": str(datetime.now()),
"memory": "",
"diff_max": None,
"diff_90_percentile": None,
"diff_95_percentile": None,
"diff_99_percentile": None,
"use_compact_memory": use_compact_memory,
"use_half4": use_half4,
}
if not disable_io_binding:
max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size
max_pooler_size = max(batch_sizes) * max(sequence_lengths)
result = benchmark_helper.inference_ort_with_io_binding(
ort_session,
ort_inputs,
result_template=result_template,
repeat_times=test_times,
ort_output_names=["last_state", "pooler"],
ort_outputs=ort_outputs,
output_buffers=[],
output_buffer_max_sizes=[max_last_state_size, max_pooler_size],
batch_size=batch_size,
device=device,
data_type=np.longlong, # input data type
)
else:
result = benchmark_helper.inference_ort(
ort_session,
ort_inputs,
result_template=result_template,
repeat_times=test_times,
batch_size=batch_size,
)
# measure result difference between PyTorch and OnnxRuntime
if not disable_parity:
diff_results = [
test_parity(
device,
model,
ort_session,
batch_size,
sequence_length,
global_length,
verbose,
)
for _ in range(test_times)
]
result["diff_max"] = max(diff_results)
result["diff_90_percentile"] = np.percentile(diff_results, 90)
result["diff_95_percentile"] = np.percentile(diff_results, 95)
result["diff_99_percentile"] = np.percentile(diff_results, 99)
results.append(result)
return results
def test_ort_memory(
device,
onnx_model_path,
batch_size,
sequence_length,
global_length,
test_times,
num_threads,
) -> dict[str, Any]:
logger.info(
f"Testing memory for model={onnx_model_path}, batch_size={batch_size}, sequence_length={sequence_length}, "
f"global_length={global_length}, test_times={test_times}, num_threads={num_threads}"
)
def inference():
# Update Arena strategy so that we can measure the minimum memory required
cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
provider_options = {"CUDAExecutionProvider": cuda_provider_options}
session = benchmark_helper.create_onnxruntime_session(
onnx_model_path,
use_gpu=True,
enable_all_optimization=True,
num_threads=num_threads,
provider_options=provider_options,
)
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
batch_size, sequence_length, global_length, device
)
ort_inputs = dummy_inputs.get_ort_inputs()
for _ in range(test_times):
_ = session.run(None, ort_inputs)
memory_used = benchmark_helper.measure_memory(is_gpu=True, func=inference)
return {
"onnx_model": onnx_model_path,
"batch_size": batch_size,
"sequence_length": sequence_length,
"global_length": global_length,
"test_times": test_times,
"num_threads": num_threads,
"memory": memory_used,
}
def load_torch_model(model_name, device):
torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS.get(model_name, model_name)
model = LongformerModel.from_pretrained(torch_model_name_or_dir)
model.to(device)
return model
def find_onnx_model(model_name, onnx_dir="."):
# Search onnx model in the following order: optimized fp16 model, optimized fp32 model, raw model
onnx_model_path = os.path.join(onnx_dir, model_name + ".onnx")
optimized_fp32_model = os.path.join(onnx_dir, model_name + "_fp32.onnx")
optimized_fp16_model = os.path.join(onnx_dir, model_name + "_fp16.onnx")
if os.path.isfile(optimized_fp16_model):
onnx_model_path = optimized_fp16_model
elif os.path.isfile(optimized_fp32_model):
onnx_model_path = optimized_fp32_model
return onnx_model_path
def test_memory(args, device) -> dict[str, Any]:
if len(args.batch_sizes) > 1:
raise RuntimeError("For memory test, only one batch_size (-b) is allowed.")
if len(args.sequence_lengths) > 1:
raise RuntimeError("For memory test, only one sequence_length (-s) is allowed.")
if len(args.global_lengths) > 1:
raise RuntimeError("For memory test, only one global_length (-g) is allowed.")
model_name = args.model
onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
torch.cuda.empty_cache()
return test_ort_memory(
device,
onnx_model_path,
args.batch_sizes[0],
args.sequence_lengths[0],
args.global_lengths[0],
args.test_times,
args.num_threads,
)
def test_ort(args, device) -> list[dict[str, Any]]:
model_name = args.model
onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx") # noqa: PIE810
precision = "fp32" if not onnx_model_path.endswith("_fp16.onnx") else "fp16"
model = load_torch_model(model_name, device)
num_threads = args.num_threads
cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
provider_options = {"CUDAExecutionProvider": cuda_provider_options}
session = benchmark_helper.create_onnxruntime_session(
onnx_model_path,
use_gpu=True,
enable_all_optimization=True,
num_threads=num_threads,
provider_options=provider_options,
)
if session is None:
raise RuntimeError(f"Failed to create ORT session from ONNX file {onnx_model_path}")
use_compact_memory = os.environ.get("ORT_LONGFORMER_COMPACT_MEMORY", "1") == "1"
description = onnx_model_path
if not use_compact_memory:
description += "[non_compact_memory]"
if args.use_half4:
description += "[half4]" if precision == "fp16" else "[float4]"
else:
description += "[half2]" if precision == "fp16" else "[float4]"
return test_ort_latency(
device,
model,
model_name,
description,
session,
args.batch_sizes,
args.sequence_lengths,
args.global_lengths,
args.test_times,
num_threads,
optimized,
precision,
args.disable_io_binding,
args.verbose,
use_compact_memory,
args.use_half4,
args.disable_parity,
)
def test_torch(args, device) -> list[dict[str, Any]]:
model = load_torch_model(args.model, device)
return test_torch_latency(
device,
model,
args.model,
args.batch_sizes,
args.sequence_lengths,
args.global_lengths,
args.test_times,
args.num_threads,
)
def test_latency(args, device) -> list[dict[str, Any]]:
if args.engine == "onnxruntime":
return test_ort(args, device)
return test_torch(args, device)
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
required=False,
type=str,
default="longformer-base-4096",
help="Checkpoint directory or pre-trained model names in the list: "
+ ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
)
parser.add_argument(
"-e",
"--engine",
required=False,
type=str,
default="onnxruntime",
choices=["onnxruntime", "torch"],
help="Engine to benchmark.",
)
parser.add_argument(
"-t",
"--test_times",
required=False,
default=1000,
type=int,
help="Number of repeat times to get average inference latency.",
)
parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])
# If --export_padding is not used in exporting onnx model, there is no padding in ONNX model,
# and you will need padding inputs by yourself before running onnx model.
# Here, we only test sequence length that is multiple of attention window size.
parser.add_argument(
"-s",
"--sequence_lengths",
nargs="+",
type=int,
default=[512, 1024, 2048, 4096],
help="Sequence lengths. It could have multiple values in latency test."
"If --export_padding is not used, sequence length shall be multiple of window size.",
)
parser.add_argument("--onnx", required=False, type=str, default=None, help="Onnx model path")
parser.add_argument(
"-g",
"--global_lengths",
nargs="+",
type=int,
default=[0],
help="Number of global tokens. It could have multiple values in latency test.",
)
parser.add_argument(
"-n",
"--num_threads",
required=False,
type=int,
default=0,
help="Threads to use.",
)
parser.add_argument(
"--disable_io_binding",
required=False,
action="store_true",
help="Do not use IO Binding.",
)
parser.add_argument(
"--memory",
required=False,
action="store_true",
help="Test memory usage instead of latency.",
)
parser.add_argument("--verbose", required=False, action="store_true", help="Print more information.")
parser.set_defaults(verbose=False)
parser.add_argument("--use_half4", required=False, action="store_true", help="Use half4 kernel.")
parser.set_defaults(use_half4=False)
parser.add_argument("--disable_parity", required=False, action="store_true", help="Do not run parity test.")
parser.set_defaults(disable_parity=False)
args = parser.parse_args(argv)
return args
def output_details(results, csv_filename):
latency_results = [result for result in results if "average_latency_ms" in result]
if len(latency_results) == 0:
print("No latency results for output.")
return
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
column_names = [
"engine",
"version",
"device",
"precision",
"optimizer",
"io_binding",
"model_name",
"inputs",
"threads",
"datetime",
"test_times",
"description",
"batch_size",
"sequence_length",
"global_length",
"use_compact_memory",
"use_half4",
"diff_max",
"diff_90_percentile",
"diff_95_percentile",
"diff_99_percentile",
"memory",
"QPS",
"average_latency_ms",
"latency_variance",
"latency_90_percentile",
"latency_95_percentile",
"latency_99_percentile",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
for result in latency_results:
print(result)
csv_writer.writerow(result)
csv_file.flush()
print(f"Detail results are saved to csv file: {csv_filename}")
def run(args) -> list[dict[str, Any]]:
torch.set_grad_enabled(False)
# set random seed manually to get deterministic results
benchmark_helper.set_random_seed(123)
# Currently, the longformer attention operator could only run in GPU (no CPU implementation yet).
device = torch.device("cuda:0")
if args.memory:
return [test_memory(args, device)] # Convert to List so that return type is same as test_latency
return test_latency(args, device)
def launch_test(arguments) -> list[dict[str, Any]]:
if not torch.cuda.is_available():
raise RuntimeError("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
with ProcessPoolExecutor() as executor:
results = list(executor.map(run, [arguments]))
assert len(results) == 1
return results[0]
def run_tests(
use_compact_memory=True,
run_torch=False,
run_memory=True,
use_io_binding=True,
use_fp16=True,
use_merged_qkv_weights=True,
use_half4=True,
batch_size=1,
):
compact_memory = "1" if use_compact_memory else "0"
os.environ["ORT_LONGFORMER_COMPACT_MEMORY"] = compact_memory
logger.info(f"ORT_LONGFORMER_COMPACT_MEMORY={compact_memory}")
os.environ["ORT_LONGFORMER_USE_HALF4"] = "1" if use_half4 else "0"
logger.info("ORT_LONGFORMER_USE_HALF4={}".format("1" if use_half4 else "0")) # noqa: G001
results = []
test_times = 1000
sequence_lengths = [4096, 2048, 1024, 512]
batch_sizes = [batch_size]
for model_name in ["longformer-base-4096"]:
for batch_size in batch_sizes:
for sequence_length in sequence_lengths:
for global_length in [16]:
if run_torch:
engine_name = "torch"
args = parse_arguments(
f"-e {engine_name} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} "
f"-t {test_times} -m {model_name}".split(" ")
)
results += run(args)
engine_name = "onnxruntime"
file_format = 1 if use_merged_qkv_weights else 0
onnx_path = (
f"{model_name}_f{file_format}_fp16.onnx"
if use_fp16
else f"{model_name}_f{file_format}_fp32.onnx"
)
if not os.path.exists(onnx_path):
raise RuntimeError(f"onnx file not exists:{onnx_path}")
arguments = (
f"-e {engine_name} --onnx {onnx_path} "
f"-b {batch_size} -s {sequence_length} -g {global_length} -m {model_name}"
)
if not use_io_binding:
arguments += " --disable_io_binding"
if use_half4:
arguments += " --use_half4"
# Disable parity test to avoid out of memory for large batch size
if batch_size >= 4:
arguments += " --disable_parity"
memory_results = None
try:
if run_memory:
args = parse_arguments(f"{arguments} -t 10 --memory".split(" "))
memory_results = launch_test(args)
args = parse_arguments(f"{arguments} -t {test_times}".split(" "))
latency_results = launch_test(args)
except KeyboardInterrupt as exc:
raise RuntimeError("Keyboard Interrupted") from exc
except Exception:
traceback.print_exc()
continue
if len(latency_results) == 1:
latency_results[0]["memory"] = memory_results[0]["memory"] if memory_results else "N/A"
else:
raise RuntimeError("length of latency_results should be 1")
logger.info("%s", latency_results)
results += latency_results
return results
def output_summary(results, csv_filename, data_field="average_latency_ms"):
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
header_names = [
"model_name",
"precision",
"engine",
"version",
"global_length",
"use_compact_memory",
"use_half4",
"description",
]
description_list = list({result["description"] for result in results})
description_list.sort()
batch_sizes = list({result["batch_size"] for result in results})
batch_sizes.sort()
sequence_lengths = list({result["sequence_length"] for result in results})
sequence_lengths.sort()
data_names = []
for sequence_length in sequence_lengths:
for batch_size in batch_sizes:
data_names.append(f"b{batch_size}_s{sequence_length}")
csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
csv_writer.writeheader()
for description in description_list:
row = {}
sum_latency = {}
sum_latency.update(dict.fromkeys(data_names, 0))
count_latency = {}
count_latency.update(dict.fromkeys(data_names, 0))
for result in results:
if result["description"] == description and result[data_field]:
headers = {k: v for k, v in result.items() if k in header_names}
if not row:
row.update(headers)
else:
for k in header_names:
if row[k] != headers[k]:
raise RuntimeError("Description shall be unique")
batch_size = result["batch_size"]
sequence_length = result["sequence_length"]
key = f"b{batch_size}_s{sequence_length}"
try:
latency = float(result[data_field])
except ValueError:
continue
sum_latency[key] += latency
count_latency[key] += 1
if row:
for key in data_names:
if key in count_latency and count_latency[key] > 0:
row[key] = sum_latency[key] / count_latency[key]
else:
row[key] = ""
csv_writer.writerow(row)
csv_file.flush()
def run_experiments(use_fp16, batch_size, is_baseline=False):
"""Run experiments to compare different algorithms on one batch size"""
test_results = run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=True,
use_half4=False,
batch_size=batch_size,
)
if is_baseline:
return test_results
if use_fp16:
test_results += run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=True,
use_half4=True,
batch_size=batch_size,
)
test_results += run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=False,
use_half4=True,
batch_size=batch_size,
)
test_results += run_tests(
use_fp16=use_fp16,
use_merged_qkv_weights=False,
use_half4=False,
batch_size=batch_size,
)
return test_results
def main():
torch.multiprocessing.set_start_method("spawn")
args = parse_arguments()
benchmark_helper.setup_logger(args.verbose)
if len(sys.argv) > 1:
test_results = launch_test(args)
time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
csv_filename = f"benchmark_detail_{time_stamp}.csv"
output_details(test_results, csv_filename)
return
gpu_list = benchmark_helper.get_gpu_info()
logger.info("GPU info: %s", gpu_list)
fp16_batch_sizes = [16, 8, 4, 2, 1]
fp32_batch_sizes = [4, 2, 1]
if gpu_list and gpu_list[0]["total"] >= 32 * 1024 * 1024 * 1024: # 32 GB
fp16_batch_sizes = [64, 32, 16, 8, 4, 2, 1]
fp32_batch_sizes = [16, 8, 4, 2, 1]
gpu_name = re.sub(r"(?u)[^-\w.]", "_", gpu_list[0]["name"]) if gpu_list else "gpu"
is_baseline = os.environ.get("ORT_LONGFORMER_BASELINE", "0") == "1"
experiment_name = f"longformer_base_{gpu_name}" + ("_baseline" if is_baseline else "")
logger.info(
f"experiment_name={experiment_name}, fp16_batch_sizes={fp16_batch_sizes}, fp32_batch_sizes={fp32_batch_sizes}"
)
total_runs = 1
all_results = []
for _ in range(total_runs):
for batch_size in fp16_batch_sizes:
fp16_results = run_experiments(use_fp16=True, batch_size=batch_size, is_baseline=is_baseline)
output_details(fp16_results, "longformer_base_fp16.csv")
all_results += fp16_results
for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
all_results = []
for _ in range(total_runs):
for batch_size in fp32_batch_sizes:
fp32_results = run_experiments(use_fp16=False, batch_size=batch_size, is_baseline=is_baseline)
output_details(fp32_results, "longformer_base_fp32.csv")
all_results += fp32_results
for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
if __name__ == "__main__":
main()
@@ -0,0 +1,413 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script converts Longformer model from huggingface transformers 4.0 or later to ONNX.
# It translates LongformerSelfAttention to the LongformerAttention operator in ONNX Runtime.
#
# Before running this script, prepare a python environment in Linux with PyTorch 1.9.0 and other packages installed.
# Then run "python setup.py install" in ./torch_extensions directory. If your python version is not 3.8, you will need
# update this script with correct name of longformer_attention.cpython-*.so (search TODO below).
#
# It is tested in Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0.
# Warning: Using PyTorch 1.10 or newer version might encounter issue in exporting, but they are fine for benchmarking.
#
# Example commands to export longformer base model in Linux:
# conda create -n longformer python=3.8
# conda activate longformer
# python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
# python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0
# python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu
# cd ./torch_extensions
# rm -rf build
# python setup.py install
# cd ..
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx --no_merge_qkv
#
# GPU is not needed for this script. You can run it in CPU. For --optimize_onnx, you can use either onnxruntime or onnxruntime-gpu package.
#
# For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or newer version.
import argparse
import inspect
from pathlib import Path
import torch
import transformers
from longformer_helper import PRETRAINED_LONGFORMER_MODELS
from onnx import load_model
from onnx_model_bert import BertOnnxModel
from packaging import version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args
from torch_onnx_export_helper import torch_onnx_export
from transformers import LongformerModel, LongformerSelfAttention
# Supports format 0 or 1
weight_bias_format = 0
@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i")
def my_longformer_attention(
g,
input,
weight,
bias,
mask,
global_weight,
global_bias,
global_mask,
num_heads,
window,
):
return g.op(
"com.microsoft::LongformerAttention",
input,
weight,
bias,
mask,
global_weight,
global_bias,
global_mask,
num_heads_i=num_heads,
window_i=window,
)
# namespace is onnxruntime which is registered in longformer_attention.cpp
register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9)
# TODO: search the directory to find correct output filename of "python setup.py install" when python version is not 3.8
torch.ops.load_library(
r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so"
)
def parse_arguments():
"""Parse arguments
Returns:
args: Namespace
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"-m",
"--model",
required=False,
type=str,
default="longformer-base-4096",
help="Checkpoint directory or pre-trained model names in the list: "
+ ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
)
parser.add_argument(
"--export_padding",
required=False,
action="store_true",
help="Export padding logic to ONNX graph. If not enabled, user need pad input so that sequence length is multiple of window size.",
)
parser.set_defaults(export_padding=False)
parser.add_argument(
"--no_merge_qkv",
required=False,
action="store_true",
help="Stack the weights of q, k and v on dimension 0 instead of dimension 1.",
)
parser.set_defaults(no_merge_qkv=False)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model.",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument(
"-p",
"--precision",
required=False,
type=str,
default="fp32",
choices=["fp32", "fp16"],
help="Precision of model to run: fp32 for full precision, fp16 for mixed precision",
)
args = parser.parse_args()
return args
# Create a dummy input for ONNX export.
def get_dummy_inputs(config, export_padding, device):
# When sequence length is multiple of windows size, there is no padding logic in ONNX graph
sequence_length = config.attention_window[0] + 1 if export_padding else config.attention_window[0]
# Create dummy inputs
input_ids = torch.arange(sequence_length).unsqueeze(0).to(device)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
attention_mask[:, sequence_length - 1] = 0 # last token is masked
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
global_attention_mask[:, 0] = 1 # first token is global token
return input_ids, attention_mask, global_attention_mask
# A new function to replace LongformerSelfAttention.forward
# For transformers 4.0.0
def my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
):
global_mask = is_index_global_attn.int()
# The following check is based on the dummy inputs (only the first token is global).
assert (
len(global_mask.shape) == 2
and global_mask.shape[0] == 1
and global_mask.count_nonzero().item() == 1
and global_mask.tolist()[0][0] == 1
)
input_mask = is_index_masked.float()
# TODO: The filtering value may be -10000.0 or -inf. Check the huggingface implementation.
input_mask = input_mask.masked_fill(is_index_masked, -10000.0)
# Yet another way to generate input_mask = torch.masked_fill(attention_mask, is_index_global_attn, 0.0)
# TODO: add postprocessing of ONNX model to calculate based on graph input: input_mask = (attention_mask - 1) * 10000.0
# TODO: add postprocessing of ONNX model to use graph input directly: global_mask = global_attention_mask
# The following check is based on the dummy inputs (only the last token is masked).
assert (
len(input_mask.shape) == 2
and input_mask.shape[0] == 1
and input_mask.count_nonzero().item() == 1
and input_mask.tolist()[0][-1] == -10000.0
)
weight = torch.stack(
(
self.query.weight.transpose(0, 1),
self.key.weight.transpose(0, 1),
self.value.weight.transpose(0, 1),
),
dim=weight_bias_format,
)
if weight_bias_format == 1:
# shape is (hidden_size, 3*hidden_size) for format 1, otherwise (3, hidden_size, hidden_size) by default
weight = weight.reshape(self.embed_dim, 3 * self.embed_dim)
global_weight = torch.stack(
(
self.query_global.weight.transpose(0, 1),
self.key_global.weight.transpose(0, 1),
self.value_global.weight.transpose(0, 1),
),
dim=weight_bias_format,
)
if weight_bias_format == 1:
global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim)
if weight_bias_format == 1:
bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0)
bias = bias.reshape(3 * self.embed_dim)
global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0)
global_bias = global_bias.reshape(3 * self.embed_dim)
else:
bias = torch.stack(
(self.query.bias, self.key.bias, self.value.bias, self.key_global.bias, self.value_global.bias), dim=0
)
bias = bias.reshape(5 * self.embed_dim)
global_bias = self.query_global.bias
global_bias = global_bias.reshape(1 * self.embed_dim)
attn_output = torch.ops.onnxruntime.LongformerAttention(
hidden_states,
weight,
bias,
input_mask,
global_weight,
global_bias,
global_mask,
self.num_heads,
self.one_sided_attn_window_size,
)
assert attn_output.size() == hidden_states.size(), "Unexpected size"
outputs = (attn_output,)
return outputs
# For transformers 4.3.0
def my_longformer_self_attention_forward_4_3(
self,
hidden_states,
attention_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
assert output_attentions is False
return my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
# For transformers 4.3.2 or later versions
def my_longformer_self_attention_forward_4_3_2(
self,
hidden_states,
attention_mask=None,
layer_head_mask=None,
is_index_masked=None,
is_index_global_attn=None,
is_global_attn=None,
output_attentions=False,
):
assert output_attentions is False
assert layer_head_mask is None
return my_longformer_self_attention_forward_4(
self,
hidden_states,
attention_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
)
def export_longformer(model: LongformerModel, onnx_model_path: str, export_padding: bool):
"""Export longformer model to ONNX
Args:
model (LongformerModel): longformer model
onnx_model_path (str): output onnx path
export_padding (bool): whether export padding logic to ONNX so that input string can be any length.
Raises:
RuntimeError: This tool requires transformers 4.0.0 or later.
RuntimeError: LongformerSelfAttention.forward arguments are different.
"""
input_ids, attention_mask, global_attention_mask = get_dummy_inputs(
model.config, export_padding, device=torch.device("cpu")
)
_ = model(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
)
if version.parse(transformers.__version__) < version.parse("4.0.0"):
raise RuntimeError("This tool requires transformers 4.0.0 or later.")
# Here we replace LongformerSelfAttention.forward using our implementation for exporting ONNX model
key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args)
args_to_func = {
"self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2,
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3,
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4,
}
if key not in args_to_func:
print(
"Current arguments",
inspect.getfullargspec(LongformerSelfAttention.forward).args,
)
raise RuntimeError(
"LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)."
)
# Store for restoring later
original_forward = LongformerSelfAttention.forward
LongformerSelfAttention.forward = args_to_func[key]
example_inputs = (input_ids, attention_mask, global_attention_mask)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
example_inputs,
onnx_model_path,
opset_version=12,
input_names=["input_ids", "attention_mask", "global_attention_mask"],
output_names=["last_state", "pooler"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"global_attention_mask": {0: "batch_size", 1: "sequence_length"},
"last_state": {0: "batch_size", 1: "sequence_length"},
"pooler": {0: "batch_size", 1: "sequence_length"},
},
custom_opsets={"com.microsoft": 1},
)
print(f"ONNX model exported to {onnx_model_path}")
# Restore original implementation:
LongformerSelfAttention.forward = original_forward
def optimize_longformer(onnx_model_path: str, fp32_model_path: str, fp16_model_path=None):
"""Optimize longformer onnx model
Args:
onnx_model_path (str): path of original ONNX model.
fp32_model_path (str): path of optimized fp32 model.
fp16_model_path (str, optional): path of optimized fp16 model. Defaults to None.
"""
model = load_model(onnx_model_path, format=None, load_external_data=True)
optimizer = BertOnnxModel(model)
optimizer.optimize()
use_external_data_format = False
if fp32_model_path:
optimizer.save_model_to_file(fp32_model_path, use_external_data_format)
print(f"optimized fp32 model saved to {fp32_model_path}")
if fp16_model_path:
optimizer.convert_float_to_float16(keep_io_types=True)
optimizer.save_model_to_file(fp16_model_path, use_external_data_format)
print(f"optimized fp16 model saved to {fp16_model_path}")
def main(args):
model_name = args.model
onnx_model_path = model_name + ".onnx"
global weight_bias_format # noqa: PLW0603
weight_bias_format = 0 if args.no_merge_qkv else 1
model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name])
export_longformer(model, onnx_model_path, args.export_padding)
if args.optimize_onnx or args.precision != "fp32":
fp32_model_path = model_name + f"_f{weight_bias_format}" + "_fp32.onnx"
fp16_model_path = model_name + f"_f{weight_bias_format}" + "_fp16.onnx" if args.precision == "fp16" else None
optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path)
if __name__ == "__main__":
args = parse_arguments()
main(args)
@@ -0,0 +1,347 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Generate test data for a longformer model, so that we can use onnxruntime_perf_test.exe to evaluate the inference latency.
import argparse
import os
import random
from pathlib import Path
import numpy as np
from bert_test_data import fake_input_ids_data, fake_input_mask_data, output_test_data
from onnx import ModelProto, TensorProto
from onnx_model import OnnxModel
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
parser.add_argument(
"--output_dir",
required=False,
type=str,
default=None,
help="output test data path. If not specified, .",
)
parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
parser.add_argument(
"--sequence_length",
required=False,
type=int,
default=128,
help="maximum sequence length of input",
)
parser.add_argument(
"-a",
"--average_sequence_length",
default=-1,
type=int,
help="average sequence length excluding padding",
)
parser.add_argument(
"-r",
"--random_sequence_length",
required=False,
action="store_true",
help="use uniform random instead of fixed sequence length",
)
parser.set_defaults(random_sequence_length=False)
parser.add_argument(
"--global_tokens",
required=False,
type=int,
default=10,
help="number of global tokens",
)
parser.add_argument(
"--input_ids_name",
required=False,
type=str,
default=None,
help="input name for input ids",
)
parser.add_argument(
"--input_mask_name",
required=False,
type=str,
default=None,
help="input name for attention mask",
)
parser.add_argument(
"--global_mask_name",
required=False,
type=str,
default=None,
help="input name for global attention mask",
)
parser.add_argument(
"--samples",
required=False,
type=int,
default=1,
help="number of test cases to be generated",
)
parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
parser.add_argument(
"--verbose",
required=False,
action="store_true",
help="print verbose information",
)
parser.set_defaults(verbose=False)
args = parser.parse_args()
return args
def get_longformer_inputs(onnx_file, input_ids_name=None, input_mask_name=None, global_mask_name=None):
"""
Get graph inputs for longformer model.
"""
model = ModelProto()
with open(onnx_file, "rb") as f:
model.ParseFromString(f.read())
onnx_model = OnnxModel(model)
graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
if input_ids_name is not None:
input_ids = onnx_model.find_graph_input(input_ids_name)
if input_ids is None:
raise ValueError(f"Graph does not have input named {input_ids_name}")
input_mask = None
if input_mask_name:
input_mask = onnx_model.find_graph_input(input_mask_name)
if input_mask is None:
raise ValueError(f"Graph does not have input named {input_mask_name}")
global_mask = None
if global_mask_name:
global_mask = onnx_model.find_graph_input(global_mask_name)
if global_mask is None:
raise ValueError(f"Graph does not have input named {global_mask_name}")
expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else 0)
if len(graph_inputs) != expected_inputs:
raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
return input_ids, input_mask, global_mask
if len(graph_inputs) != 3:
raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")
# Try guess the inputs based on naming.
input_ids = None
input_mask = None
global_mask = None
for input in graph_inputs:
input_name_lower = input.name.lower()
if "global" in input_name_lower:
global_mask = input
elif "mask" in input_name_lower:
input_mask = input
else:
input_ids = input
if input_ids and input_mask and global_mask:
return input_ids, input_mask, global_mask
raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_tokens):
"""
Fake data based on the graph input of segment_ids.
Args:
segment_ids (TensorProto): graph input of input tensor.
Returns:
data (np.array): the data for input tensor
"""
data_type = global_mask.type.tensor_type.elem_type
assert data_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
if num_global_tokens > 0:
assert num_global_tokens <= sequence_length
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
temp = np.ones((batch_size, num_global_tokens), dtype=np.int32)
data[: temp.shape[0], : temp.shape[1]] = temp
else:
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
if data_type == TensorProto.FLOAT:
data = np.float32(data)
elif data_type == TensorProto.INT64:
data = np.int64(data)
return data
def fake_test_data(
batch_size,
sequence_length,
test_cases,
dictionary_size,
verbose,
random_seed,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
):
"""
Generate fake input data for test.
"""
assert input_ids is not None
np.random.seed(random_seed)
random.seed(random_seed)
all_inputs = []
for _ in range(test_cases):
input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
inputs = {input_ids.name: input_1}
if input_mask:
inputs[input_mask.name] = fake_input_mask_data(
input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length
)
if global_mask:
inputs[global_mask.name] = fake_global_mask_data(
global_mask, batch_size, sequence_length, num_global_tokens
)
if verbose and len(all_inputs) == 0:
print("Example inputs", inputs)
all_inputs.append(inputs)
return all_inputs
def generate_test_data(
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
):
dictionary_size = 10000
all_inputs = fake_test_data(
batch_size,
sequence_length,
test_cases,
dictionary_size,
verbose,
seed,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
)
if len(all_inputs) != test_cases:
print("Failed to create test data for test.")
return all_inputs
def create_longformer_test_data(
model,
output_dir,
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids_name,
input_mask_name,
global_mask_name,
num_global_tokens,
average_sequence_length,
random_sequence_length,
):
input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name, global_mask_name)
all_inputs = generate_test_data(
batch_size,
sequence_length,
test_cases,
seed,
verbose,
input_ids,
input_mask,
global_mask,
num_global_tokens,
average_sequence_length,
random_sequence_length,
)
for i, inputs in enumerate(all_inputs):
output_test_data(output_dir, i, inputs)
def main():
args = parse_arguments()
output_dir = args.output_dir
if output_dir is None:
# Default output directory is a sub-directory under the directory of model.
output_dir = os.path.join(
Path(args.model).parent,
f"b{args.batch_size}_s{args.sequence_length}_g{args.global_tokens}",
)
if output_dir is not None:
# create the output directory if not existed
path = Path(output_dir)
path.mkdir(parents=True, exist_ok=True)
else:
print("Directory existed. test data files will be overwritten.")
if args.average_sequence_length <= 0:
args.average_sequence_length = args.sequence_length
create_longformer_test_data(
args.model,
output_dir,
args.batch_size,
args.sequence_length,
args.samples,
args.seed,
args.verbose,
args.input_ids_name,
args.input_mask_name,
args.global_mask_name,
args.global_tokens,
args.average_sequence_length,
)
print("Test data is saved to directory:", output_dir)
if __name__ == "__main__":
main()
@@ -0,0 +1,76 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This script helps creating dummy inputs for Longformer model.
import logging
import numpy
import torch
logger = logging.getLogger(__name__)
PRETRAINED_LONGFORMER_MODELS = {
"longformer-base-4096": "allenai/longformer-base-4096",
"longformer-large-4096": "allenai/longformer-large-4096",
"longformer-random-tiny": "patrickvonplaten/longformer-random-tiny", # A tiny model for debugging
}
class LongformerInputs:
def __init__(self, input_ids, attention_mask, global_attention_mask):
self.input_ids: torch.LongTensor = input_ids
self.attention_mask: torch.FloatTensor | torch.HalfTensor = attention_mask
self.global_attention_mask: torch.FloatTensor | torch.HalfTensor = global_attention_mask
def to_list(self) -> list:
return [v for v in [self.input_ids, self.attention_mask, self.global_attention_mask] if v is not None]
def to_tuple(self) -> tuple:
return tuple(v for v in self.to_list())
def get_ort_inputs(self) -> dict:
return {
"input_ids": numpy.ascontiguousarray(self.input_ids.cpu().numpy()),
"attention_mask": numpy.ascontiguousarray(self.attention_mask.cpu().numpy()),
"global_attention_mask": numpy.ascontiguousarray(self.global_attention_mask.cpu().numpy()),
}
class LongformerHelper:
"""A helper class for Longformer model conversion, inference and verification."""
@staticmethod
def get_dummy_inputs(
batch_size: int,
sequence_length: int,
num_global_tokens: int,
device: torch.device,
vocab_size: int = 100,
) -> LongformerInputs:
"""Create random inputs for Longformer model.
Returns torch tensors of input_ids, attention_mask and global_attention_mask tensors.
"""
input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=torch.long,
device=device,
)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
global_token_index = list(range(num_global_tokens))
global_attention_mask[:, global_token_index] = 1
return LongformerInputs(input_ids, attention_mask, global_attention_mask)
@staticmethod
def get_output_shapes(batch_size: int, sequence_length: int, hidden_size: int) -> dict[str, list[int]]:
"""Returns a dictionary with output name as key, and shape as value."""
return {
"last_state": [batch_size, sequence_length, hidden_size],
"pooler": [batch_size, sequence_length],
}
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,582 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations
import argparse
import logging
import os
from pathlib import Path
import onnx
import torch
from benchmark_helper import Precision
from fusion_options import AttentionOpType
from onnx_model import OnnxModel
from packaging import version
from transformers import AutoConfig, AutoModelForCausalLM
from onnxruntime import __version__ as ort_version
if version.parse(ort_version) < version.parse("1.22.0"):
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer as MatMulNBitsQuantizer
else:
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
class ConvertPhi2ToONNX:
def __init__(
self,
device: torch.device,
model_class: str = "microsoft/phi-2",
cache_dir: str = "./cache",
):
self.model_class = model_class
self.device = device
self.cache_dir = cache_dir
self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir)
self.phi_model = None
self.batch_size = 2
self.sequence_length = 8
self.attn_op_type = None
self.precision = None
self.block_size = 16
self.accuracy_level = None
def set_quantization_params(self, block_size: int, accuracy_level: int | None):
self.block_size = block_size
self.accuracy_level = accuracy_level
def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
self.attn_op_type = attn_op_type
self.precision = precision
def erase_onnx_model(self, onnx_path: str) -> None:
assert onnx_path.endswith(".onnx")
if not os.path.exists(onnx_path):
return
model = onnx.load_model(onnx_path, load_external_data=False)
onnx_data_path = None
for initializer in model.graph.initializer:
if initializer.data_location == 1 and initializer.external_data[0].key == "location":
onnx_data_path = "./" + initializer.external_data[0].value
break
logging.info(f"Erasing {onnx_path}...")
os.remove(onnx_path)
if onnx_data_path is not None:
onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path)
logging.info(f"Erasing {onnx_data_path}...")
os.remove(onnx_data_path)
def get_phi2_torch_model(self):
logging.info("Loading phi2 torch model...")
if self.phi_model is not None:
return
self.phi_model = AutoModelForCausalLM.from_pretrained(
self.model_class, trust_remote_code=True, cache_dir=self.cache_dir
)
self.phi_model.eval()
self.phi_model.to(self.device)
def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):
input_ids = torch.randint(
low=0,
high=self.phi_config.vocab_size,
size=(batch_size, sequence_length),
dtype=torch.int64,
device=self.device,
)
self.get_phi2_torch_model()
torch_inputs = self.phi_model.prepare_inputs_for_generation(
input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"]
)
return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"]
def dynamo_export(self, onnx_path: str):
input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length)
self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
from torch._dynamo import config # noqa: PLC0415
config.capture_scalar_outputs = True
logging.info("Exporting Phi2 torch model to ONNX...")
torch.onnx.dynamo_export(
self.phi_model,
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
).save(onnx_path)
onnx.checker.check_model(onnx_path)
onnx.shape_inference.infer_shapes_path(onnx_path)
def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
from fusion_options import FusionOptions # noqa: PLC0415
from optimizer import optimize_model # noqa: PLC0415
optimization_options = FusionOptions("phi")
optimization_options.set_attention_op_type(self.attn_op_type)
optimizer = optimize_model(
onnx_path,
model_type="phi",
num_heads=self.phi_config.num_attention_heads,
hidden_size=self.phi_config.hidden_size,
opt_level=0,
optimization_options=optimization_options,
only_onnxruntime=False,
)
fused_op_count = optimizer.get_fused_operator_statistics()
if optimizer.is_fully_optimized(fused_op_count):
logging.info("Model is fully optimized.")
else:
logging.info("Model is not fully optimized.")
if self.precision == Precision.FLOAT32:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
if (
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
# We keep last three layers of Attention as float32 or bfloat16 to avoid overflow.
node_block_list = (
[
"Attention_29",
"Attention_30",
"Attention_31",
]
if self.attn_op_type != AttentionOpType.PagedAttention
else []
) # TODO: temp setting for paged attention
logging.info("Converting onnx model to float16/bfloat16...")
optimizer.convert_float_to_float16(
keep_io_types=False,
node_block_list=node_block_list,
use_symbolic_shape_infer=True,
use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention,
)
logging.info("Converting onnx model to float16/bfloat16 done.")
if self.precision == Precision.FLOAT16:
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
return
else:
assert self.precision == Precision.INT4
quant = MatMulNBitsQuantizer(
model=optimizer.model,
block_size=self.block_size,
is_symmetric=True,
accuracy_level=self.accuracy_level,
)
quant.process()
quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True)
# This function currently only works for phi2 model
def convert_to_use_cuda_graph(self, in_onnx_path: str, out_onnx_path: str):
onnx_model = OnnxModel(onnx.load(in_onnx_path, load_external_data=True))
from onnx import TensorProto, helper # noqa: PLC0415
graph = onnx_model.graph()
new_inputs = []
for vi in graph.input:
if "attention_mask" in vi.name:
vi_seqlen_k = helper.make_tensor_value_info(
"seqlens_k",
elem_type=TensorProto.INT32,
shape=["batch_size"],
)
vi_total_seq_len = helper.make_tensor_value_info(
"total_sequence_length",
elem_type=TensorProto.INT32,
shape=[1],
)
new_inputs.extend([vi_seqlen_k, vi_total_seq_len])
else:
new_inputs.append(vi)
graph.ClearField("input")
graph.input.extend(new_inputs)
gqas = onnx_model.get_nodes_by_op_type("GroupQueryAttention")
gqa = gqas[0]
seqlens_path = onnx_model.match_parent_path(
gqa,
["Cast", "Sub", "ReduceSum", "Cast"],
[5, 0, 0, 0],
)
if seqlens_path is None:
raise RuntimeError("Failed to find seqlens path for GroupQueryAttention node.")
total_seq_len_path = onnx_model.match_parent_path(
gqa,
["Cast", "Gather", "Shape"],
[6, 0, 0],
)
if total_seq_len_path is None:
raise RuntimeError("Failed to find total_seq_len path for GroupQueryAttention node.")
onnx_model.remove_nodes(seqlens_path)
onnx_model.remove_nodes(total_seq_len_path)
for gqa in gqas:
gqa.input[5] = "seqlens_k"
gqa.input[6] = "total_sequence_length"
onnx_model.save(onnx_model.model, out_onnx_path, save_as_external_data=True)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--fp32_cpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for CPU",
)
parser.add_argument(
"--int4_cpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for CPU",
)
parser.add_argument(
"--fp32_gpu",
required=False,
action="store_true",
help="Generate fp32 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--int4_gpu",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs",
)
parser.add_argument(
"--fp16_gpu_sm8x",
required=False,
action="store_true",
help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--int4_gpu_sm8x",
required=False,
action="store_true",
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
)
parser.add_argument(
"--fp16_vllm",
required=False,
action="store_true",
help="Generate fp16 ONNX model for ORT VLLM",
)
parser.add_argument(
"--int4_vllm",
required=False,
action="store_true",
help="Generate int4 ONNX model for ORT VLLM",
)
parser.add_argument(
"--use_cuda_graph",
required=False,
action="store_true",
help="Use CUDA Graph in decoding process",
)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing ONNX models",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default="./cache",
help="The cache directory for the pytorch model",
)
parser.add_argument(
"--device_id",
required=False,
type=int,
default=0,
help="The device id for the pytorch model",
)
parser.add_argument(
"--run_example",
required=False,
action="store_true",
help="Run ORT inference example",
)
parser.add_argument(
"--run_benchmark",
required=False,
action="store_true",
help="Run ORT benchmark",
)
parser.add_argument(
"--skip_export",
required=False,
action="store_true",
help="Skip exporting ONNX model",
)
parser.add_argument(
"--output_dir",
type=str,
help="The output directory for the ONNX models",
default="phi2_onnx_models",
)
parser.add_argument(
"--block_size",
required=False,
default=16,
type=int,
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py for details.",
)
parser.add_argument(
"--int4_accuracy_level",
required=False,
type=int,
help="Accuracy level of the 4-bit quantized MatMul computation. "
"Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
"(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
)
args = parser.parse_args()
return args
def main():
args = parse_arguments()
device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu")
converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir)
converter.set_quantization_params(args.block_size, args.int4_accuracy_level)
output_dir = args.output_dir
if not os.path.exists(output_dir):
os.makedirs(output_dir)
original_onnx_path = os.path.join(output_dir, "phi2_original.onnx")
if not args.skip_export:
if not os.path.exists(original_onnx_path) or args.overwrite:
converter.dynamo_export(original_onnx_path)
model_type_to_args = {
"fp32_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"),
),
"int4_cpu": (
AttentionOpType.MultiHeadAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"),
),
"fp32_gpu": (
AttentionOpType.Attention,
Precision.FLOAT32,
os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"),
),
"fp16_gpu": (
AttentionOpType.Attention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"),
),
"int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")),
"fp16_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"),
),
"int4_gpu_sm8x": (
AttentionOpType.GroupQueryAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
),
"fp16_vllm": (
AttentionOpType.PagedAttention,
Precision.FLOAT16,
os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"),
),
"int4_vllm": (
AttentionOpType.PagedAttention,
Precision.INT4,
os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"),
),
}
if not args.skip_export:
from multiprocessing import Process # noqa: PLC0415
def run_optimize_phi2_onnx(
converter: ConvertPhi2ToONNX,
original_onnx_path: str,
attention_type: AttentionOpType,
precision: Precision,
optimized_onnx_path: str,
):
converter.init_attn_type_and_precision(attention_type, precision)
converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path)
if args.use_cuda_graph:
assert args.fp16_gpu_sm8x or args.int4_gpu_sm8x
converter.convert_to_use_cuda_graph(optimized_onnx_path, optimized_onnx_path)
processes = []
if args.fp32_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"])
)
)
if args.int4_cpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"])
)
)
if args.fp32_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"])
)
)
if args.fp16_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"])
)
)
if args.int4_gpu:
processes.append(
Process(
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"])
)
)
if args.fp16_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]),
)
)
if args.int4_gpu_sm8x:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]),
)
)
if args.fp16_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]),
)
)
if args.int4_vllm:
processes.append(
Process(
target=run_optimize_phi2_onnx,
args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]),
)
)
[p.start() for p in processes]
[p.join() for p in processes]
if args.run_example or args.run_benchmark:
from inference_example import run_phi2 # noqa: PLC0415
if args.fp16_gpu_sm8x:
logging.info("Running fp16_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
use_cuda_graph=args.use_cuda_graph,
run_benchmark=args.run_benchmark,
)
if args.int4_gpu_sm8x:
logging.info("Running int4_gpu_sm8x example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2],
use_buffer_share=True,
device_id=args.device_id,
use_step=True,
use_cuda_graph=args.use_cuda_graph,
run_benchmark=args.run_benchmark,
)
if args.fp32_gpu:
logging.info("Running fp32_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp32_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
use_fp16=False,
run_benchmark=args.run_benchmark,
)
if args.fp16_gpu:
logging.info("Running fp16_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["fp16_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
run_benchmark=args.run_benchmark,
)
if args.int4_gpu:
logging.info("Running int4_gpu example...")
run_phi2(
onnx_model_path=model_type_to_args["int4_gpu"][2],
use_buffer_share=False,
device_id=args.device_id,
packed_kv=True,
run_benchmark=args.run_benchmark,
)
if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
if __name__ == "__main__":
main()
@@ -0,0 +1,414 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import time
import numpy as np
import torch
from transformers import AutoTokenizer
import onnxruntime as ort
pt_to_np = {
"torch.int32": np.int32,
"torch.int64": np.int64,
"torch.float32": np.float32,
"torch.float16": np.float16,
}
def cuda_memcpy(dst, src):
from cuda import cudart # noqa: PLC0415
cudart.cudaMemcpy(
dst.data_ptr(),
src.data_ptr(),
src.element_size() * src.nelement(),
cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
)
class ORTGenerator:
def __init__(self, decoder_path):
self.onnx_decoder_path = decoder_path
self.num_heads = 32
self.head_size = 80
self.num_layers = 32
self.max_sequence_length = 2048
self.device_id = 0
self.use_cuda_graph = False
self.use_traced_inputs = False
self.static_inputs_map = {}
def append_static_inputs(self, batch_size):
# Only use this function with GQA and with use_cuda_graph=True
if batch_size in self.static_inputs_map:
return
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda", self.device_id)
static_io = {}
static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
for i in range(self.num_layers):
cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
self.static_inputs_map[batch_size] = static_io
def get_initial_inputs_and_outputs(self, encodings_dict):
self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
batch_size, sequence_length = input_ids.shape
self.use_traced_inputs = (
self.use_cuda_graph
and (batch_size in self.static_inputs_map)
and self.use_buffer_share
and not self.packed_kv
)
step = (
torch.tensor([0], device=self.device, dtype=torch.int64)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["step"]
)
seqlens_k = (
torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["seqlens_k"]
)
cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
total_seq_length = (
torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
if not self.use_traced_inputs
else self.static_inputs_map[batch_size]["total_sequence_length"]
)
total_seq_length[0] = sequence_length
inputs = {
"input_ids": input_ids.contiguous(),
"attention_mask": attention_mask.contiguous(),
}
if self.use_step:
inputs["step"] = step.contiguous()
if self.use_cuda_graph:
inputs["seqlens_k"] = seqlens_k.contiguous()
inputs["total_sequence_length"] = total_seq_length.contiguous()
del inputs["attention_mask"]
past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
past_shape = (
(2, batch_size, self.num_heads, past_seq_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, past_seq_length, self.head_size)
)
if not self.use_traced_inputs:
for i in range(self.num_layers):
past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
(
inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
if not self.packed_kv
else inputs.update({f"past_{i}": past.contiguous()})
)
else:
for i in range(self.num_layers):
inputs.update(
{
f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
}
)
logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
outputs = {"logits": logits.contiguous()}
if not self.use_buffer_share:
present_shape = (
(2, batch_size, self.num_heads, sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
(
outputs.update(
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
)
if not self.packed_kv
else outputs.update({f"present_{i}": present.contiguous()})
)
return inputs, outputs
def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
io_binding = model.io_binding()
device = None
for k, v in inputs.items():
io_binding.bind_input(
name=k,
device_type=v.device.type,
device_id=0 if v.device.type == "cpu" else v.device.index,
element_type=pt_to_np[repr(v.dtype)],
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
device = v.device
for output in model.get_outputs():
name = output.name
if self.use_buffer_share and "present" in name:
v = inputs[name.replace("present", "past")]
io_binding.bind_output(
name=name,
device_type=v.device.type,
device_id=v.device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
else:
v = outputs[name]
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if self.use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
return io_binding
def create_session(
self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
):
self.device_id = device_id
sess_options = ort.SessionOptions()
sess_options.log_verbosity_level = 4
sess_options.log_severity_level = 4
self.use_cuda_graph = use_cuda_graph
ep = (
("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
if self.device_id >= 0
else "CPUExecutionProvider"
)
self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
self.ro = ort.RunOptions()
self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
self.use_fp16 = use_fp16
self.use_buffer_share = use_buffer_share
self.packed_kv = packed_kv
self.use_step = use_step
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
self.tokenizer.pad_token = "[PAD]"
def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
all_token_ids = inputs["input_ids"].clone()
batch_size, sequence_length = all_token_ids.shape
current_length = sequence_length
has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
if benchmark:
latency = []
prompt_run = True
while current_length < max_length:
io_binding = self.apply_io_binding(self.sess, inputs, outputs)
if benchmark:
start = time.time()
io_binding.synchronize_inputs()
if prompt_run:
if self.use_cuda_graph:
# Disable CUDA graph for the prompt run
self.ro.add_run_config_entry("gpu_graph_id", "-1")
self.sess.run_with_iobinding(io_binding, self.ro)
if self.use_cuda_graph:
# Enable CUDA graph for the decoding run
self.ro.add_run_config_entry(
"gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
)
prompt_run = False
else:
self.sess.run_with_iobinding(io_binding, self.ro)
io_binding.synchronize_outputs()
if benchmark:
end = time.time()
latency.append(end - start)
# Sample with argmax (greedy search)
next_token_logits = outputs["logits"][:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
# Check if we previously reached EOS token id or if generated token id is EOS token id
has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id
# Determine which new tokens to add to list of all token ids
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
# Return early if all batch entries have reached EOS token id
if torch.all(has_eos):
break
# Update inputs for next inference run
current_length += 1
inputs["input_ids"] = tokens_to_add.to(torch.int32)
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
if self.use_step:
inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
inputs["step"] = self.static_inputs_map[batch_size]["step"]
if self.use_cuda_graph:
previous_seqlens_k = inputs["seqlens_k"]
inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
inputs["total_sequence_length"][0] = current_length
if self.use_traced_inputs:
cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
else:
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
).to(torch.int32)
# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
if self.use_traced_inputs:
outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
outputs["logits"].zero_()
if not self.use_buffer_share:
for i in range(self.num_layers):
if not self.packed_kv:
inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
else:
inputs[f"past_{i}"] = outputs[f"present_{i}"]
new_sequence_length = inputs["attention_mask"].shape[1]
present_shape = (
(2, batch_size, self.num_heads, new_sequence_length, self.head_size)
if self.packed_kv
else (batch_size, self.num_heads, new_sequence_length, self.head_size)
)
for i in range(self.num_layers):
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
(
outputs.update(
{
f"present_key_{i}": present.contiguous(),
f"present_value_{i}": present.clone().contiguous(),
}
)
if not self.packed_kv
else outputs.update({f"present_{i}": present.contiguous()})
)
if benchmark:
print(
f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
)
print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
return
texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
return texts
def generate(self, prompt, max_length, cuda_graph_annotation):
encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
batch_size, sequence_length = prompt_shape
max_length = sequence_length + token_num
encodings_dict = {}
encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
# Warm up run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
# Benchmark run
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
def run_phi2(
onnx_model_path,
use_buffer_share,
device_id,
packed_kv=False,
use_fp16=True,
use_step=False,
use_cuda_graph=False,
run_benchmark=False,
):
generator = ORTGenerator(onnx_model_path)
generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)
def simple_run(prompt):
example_batch_size = len(prompt)
if use_cuda_graph:
generator.append_static_inputs(batch_size=example_batch_size)
texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)
for i in range(len(texts)):
print("Prompt: ", prompt[i])
print("Texts: ", texts[i])
prompt = [
'''```python
def print_prime(n):
"""
Print all primes between 1 and n
"""'''
]
if not run_benchmark:
simple_run(prompt)
# Run simple benchmark. Time the decoder only.
if run_benchmark:
token_num = 32
for batch_size in [1, 2, 4, 8]:
generator.append_static_inputs(batch_size)
for sequence_length in [16, 512]:
prompt_shape = (batch_size, sequence_length)
generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,638 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Benchmark performance of SAM2 encoder with ORT or PyTorch. See benchmark_sam2.sh for usage.
"""
import argparse
import csv
import statistics
import time
from collections.abc import Mapping
from datetime import datetime
import torch
from image_decoder import SAM2ImageDecoder
from image_encoder import SAM2ImageEncoder
from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnxruntime.transformers.io_binding_helper import CudaSession
class TestConfig:
def __init__(
self,
model_type: str,
onnx_path: str,
sam2_dir: str,
device: torch.device,
component: str = "image_encoder",
provider="CPUExecutionProvider",
torch_compile_mode="max-autotune",
batch_size: int = 1,
height: int = 1024,
width: int = 1024,
num_labels: int = 1,
num_points: int = 1,
num_masks: int = 1,
multi_mask_output: bool = False,
use_tf32: bool = True,
enable_cuda_graph: bool = False,
dtype=torch.float32,
prefer_nhwc: bool = False,
warm_up: int = 5,
enable_nvtx_profile: bool = False,
enable_ort_profile: bool = False,
enable_torch_profile: bool = False,
repeats: int = 1000,
verbose: bool = False,
):
assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
assert height >= 160 and height <= 4096
assert width >= 160 and width <= 4096
self.model_type = model_type
self.onnx_path = onnx_path
self.sam2_dir = sam2_dir
self.component = component
self.provider = provider
self.torch_compile_mode = torch_compile_mode
self.batch_size = batch_size
self.height = height
self.width = width
self.num_labels = num_labels
self.num_points = num_points
self.num_masks = num_masks
self.multi_mask_output = multi_mask_output
self.device = device
self.use_tf32 = use_tf32
self.enable_cuda_graph = enable_cuda_graph
self.dtype = dtype
self.prefer_nhwc = prefer_nhwc
self.warm_up = warm_up
self.enable_nvtx_profile = enable_nvtx_profile
self.enable_ort_profile = enable_ort_profile
self.enable_torch_profile = enable_torch_profile
self.repeats = repeats
self.verbose = verbose
if self.component == "image_encoder":
assert self.height == 1024 and self.width == 1024, "Only image size 1024x1024 is allowed for image encoder."
def __repr__(self):
return f"{vars(self)}"
def shape_dict(self) -> Mapping[str, list[int]]:
if self.component == "image_encoder":
return encoder_shape_dict(self.batch_size, self.height, self.width)
else:
return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)
def random_inputs(self) -> Mapping[str, torch.Tensor]:
dtype = self.dtype
if self.component == "image_encoder":
return {"image": torch.randn(self.batch_size, 3, self.height, self.width, dtype=dtype, device=self.device)}
else:
return {
"image_features_0": torch.rand(1, 32, 256, 256, dtype=dtype, device=self.device),
"image_features_1": torch.rand(1, 64, 128, 128, dtype=dtype, device=self.device),
"image_embeddings": torch.rand(1, 256, 64, 64, dtype=dtype, device=self.device),
"point_coords": torch.randint(
0, 1024, (self.num_labels, self.num_points, 2), dtype=dtype, device=self.device
),
"point_labels": torch.randint(
0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device
),
"input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=dtype, device=self.device),
"has_input_masks": torch.ones(self.num_labels, dtype=dtype, device=self.device),
"original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device),
}
def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
if config.verbose:
print(f"create session for {vars(config)}")
if config.provider == "CUDAExecutionProvider":
device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
provider_options["use_tf32"] = int(config.use_tf32)
if config.prefer_nhwc:
provider_options["prefer_nhwc"] = 1
providers = [(config.provider, provider_options), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
return ort_session
def create_session(config: TestConfig, session_options=None) -> CudaSession:
ort_session = create_ort_session(config, session_options)
cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
cuda_session.allocate_buffers(config.shape_dict())
return cuda_session
class OrtTestSession:
"""A wrapper of ORT session to test relevance and performance."""
def __init__(self, config: TestConfig, session_options=None):
self.ort_session = create_session(config, session_options)
self.feed_dict = config.random_inputs()
def infer(self):
return self.ort_session.infer(self.feed_dict)
def measure_latency(cuda_session: CudaSession, input_dict):
start = time.time()
_ = cuda_session.infer(input_dict)
end = time.time()
return end - start
def run_torch(config: TestConfig):
device_type = config.device.type
is_cuda = device_type == "cuda"
# Turn on TF32 for Ampere GPUs which could help when data type is float32.
if is_cuda and torch.cuda.get_device_properties(0).major >= 8 and config.use_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
enabled_auto_cast = is_cuda and config.dtype != torch.float32
ort_inputs = config.random_inputs()
with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
if config.component == "image_encoder":
if is_cuda and config.torch_compile_mode != "none":
sam2_model.image_encoder.forward = torch.compile(
sam2_model.image_encoder.forward,
mode=config.torch_compile_mode, # "reduce-overhead" if you want to reduce latency of first run.
fullgraph=True,
dynamic=False,
)
image_shape = config.shape_dict()["image"]
img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
sam2_encoder = SAM2ImageEncoder(sam2_model)
if is_cuda and config.torch_compile_mode != "none":
print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")
for _ in range(config.warm_up):
_image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
if is_cuda and config.enable_nvtx_profile:
import nvtx # noqa: PLC0415
from cuda import cudart # noqa: PLC0415
cudart.cudaProfilerStart()
print("Start nvtx profiling on encoder ...")
with nvtx.annotate("one_run"):
sam2_encoder(img, enable_nvtx_profile=True)
cudart.cudaProfilerStop()
if is_cuda and config.enable_torch_profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
print("Start torch profiling on encoder ...")
with torch.profiler.record_function("encoder"):
sam2_encoder(img)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("torch_image_encoder.json")
if config.repeats == 0:
return
print(f"Start {config.repeats} runs of performance tests...")
start = time.time()
for _ in range(config.repeats):
_image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
if is_cuda:
torch.cuda.synchronize()
else:
torch_inputs = (
ort_inputs["image_features_0"],
ort_inputs["image_features_1"],
ort_inputs["image_embeddings"],
ort_inputs["point_coords"],
ort_inputs["point_labels"],
ort_inputs["input_masks"],
ort_inputs["has_input_masks"],
ort_inputs["original_image_size"],
)
sam2_decoder = SAM2ImageDecoder(
sam2_model,
multimask_output=config.multi_mask_output,
)
if is_cuda and config.torch_compile_mode != "none":
sam2_decoder.forward = torch.compile(
sam2_decoder.forward,
mode=config.torch_compile_mode,
fullgraph=True,
dynamic=False,
)
# warm up
for _ in range(config.warm_up):
_masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
if is_cuda and config.enable_nvtx_profile:
import nvtx # noqa: PLC0415
from cuda import cudart # noqa: PLC0415
cudart.cudaProfilerStart()
print("Start nvtx profiling on decoder...")
with nvtx.annotate("one_run"):
sam2_decoder(*torch_inputs, enable_nvtx_profile=True)
cudart.cudaProfilerStop()
if is_cuda and config.enable_torch_profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
) as prof:
print("Start torch profiling on decoder ...")
with torch.profiler.record_function("decoder"):
sam2_decoder(*torch_inputs)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("torch_image_decoder.json")
if config.repeats == 0:
return
print(f"Start {config.repeats} runs of performance tests...")
start = time.time()
for _ in range(config.repeats):
_masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
if is_cuda:
torch.cuda.synchronize()
end = time.time()
return (end - start) / config.repeats
def run_test(
args: argparse.Namespace,
csv_writer: csv.DictWriter | None = None,
):
use_gpu: bool = args.use_gpu
enable_cuda_graph: bool = args.use_cuda_graph
repeats: int = args.repeats
if use_gpu:
device_id = torch.cuda.current_device()
device = torch.device("cuda", device_id)
provider = "CUDAExecutionProvider"
else:
device_id = 0
device = torch.device("cpu")
enable_cuda_graph = False
provider = "CPUExecutionProvider"
dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
config = TestConfig(
model_type=args.model_type,
onnx_path=args.onnx_path,
sam2_dir=args.sam2_dir,
component=args.component,
provider=provider,
batch_size=args.batch_size,
height=args.height,
width=args.width,
device=device,
use_tf32=True,
enable_cuda_graph=enable_cuda_graph,
dtype=dtypes[args.dtype],
prefer_nhwc=args.prefer_nhwc,
repeats=args.repeats,
warm_up=args.warm_up,
enable_nvtx_profile=args.enable_nvtx_profile,
enable_ort_profile=args.enable_ort_profile,
enable_torch_profile=args.enable_torch_profile,
torch_compile_mode=args.torch_compile_mode,
verbose=False,
)
if args.engine == "ort":
sess_options = SessionOptions()
sess_options.intra_op_num_threads = args.intra_op_num_threads
if config.enable_ort_profile:
sess_options.enable_profiling = True
sess_options.log_severity_level = 4
sess_options.log_verbosity_level = 0
session = create_session(config, sess_options)
input_dict = config.random_inputs()
# warm up session
try:
for _ in range(config.warm_up):
_ = measure_latency(session, input_dict)
except Exception as e:
print(f"Failed to run {config=}. Exception: {e}")
return
if config.enable_nvtx_profile:
import nvtx # noqa: PLC0415
from cuda import cudart # noqa: PLC0415
cudart.cudaProfilerStart()
with nvtx.annotate("one_run"):
_ = session.infer(input_dict)
cudart.cudaProfilerStop()
if config.enable_ort_profile:
session.ort_session.end_profiling()
if repeats == 0:
return
latency_list = []
for _ in range(repeats):
latency = measure_latency(session, input_dict)
latency_list.append(latency)
average_latency = statistics.mean(latency_list)
del session
else: # torch
with torch.no_grad():
try:
average_latency = run_torch(config)
except Exception as e:
print(f"Failed to run {config=}. Exception: {e}")
return
if repeats == 0:
return
engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
row = {
"model_type": args.model_type,
"component": args.component,
"dtype": args.dtype,
"use_gpu": use_gpu,
"enable_cuda_graph": enable_cuda_graph,
"prefer_nhwc": config.prefer_nhwc,
"use_tf32": config.use_tf32,
"batch_size": args.batch_size,
"height": args.height,
"width": args.width,
"multi_mask_output": args.multimask_output,
"num_labels": config.num_labels,
"num_points": config.num_points,
"num_masks": config.num_masks,
"intra_op_num_threads": args.intra_op_num_threads,
"warm_up": config.warm_up,
"repeats": repeats,
"enable_nvtx_profile": args.enable_nvtx_profile,
"torch_compile_mode": args.torch_compile_mode,
"engine": engine,
"average_latency": average_latency,
}
if csv_writer is not None:
csv_writer.writerow(row)
print(f"{vars(config)}")
print(f"{row}")
def run_perf_test(args):
features = "gpu" if args.use_gpu else "cpu"
csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
features,
args.engine,
datetime.now().strftime("%Y%m%d-%H%M%S"),
)
with open(csv_filename, mode="a", newline="") as csv_file:
column_names = [
"model_type",
"component",
"dtype",
"use_gpu",
"enable_cuda_graph",
"prefer_nhwc",
"use_tf32",
"batch_size",
"height",
"width",
"multi_mask_output",
"num_labels",
"num_points",
"num_masks",
"intra_op_num_threads",
"warm_up",
"repeats",
"enable_nvtx_profile",
"torch_compile_mode",
"engine",
"average_latency",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
csv_writer.writeheader()
run_test(args, csv_writer)
def _parse_arguments():
parser = argparse.ArgumentParser(description="Benchmark SMA2 for ONNX Runtime and PyTorch.")
parser.add_argument(
"--component",
required=False,
choices=["image_encoder", "image_decoder"],
default="image_encoder",
help="component to benchmark. Choices are image_encoder and image_decoder.",
)
parser.add_argument(
"--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
)
parser.add_argument(
"--use_gpu",
required=False,
action="store_true",
help="Use GPU for inference.",
)
parser.set_defaults(use_gpu=False)
parser.add_argument(
"--use_cuda_graph",
required=False,
action="store_true",
help="Use cuda graph in onnxruntime.",
)
parser.set_defaults(use_cuda_graph=False)
parser.add_argument(
"--intra_op_num_threads",
required=False,
type=int,
choices=[0, 1, 2, 4, 8, 16],
default=0,
help="intra_op_num_threads for onnxruntime. ",
)
parser.add_argument(
"--batch_size",
required=False,
type=int,
default=1,
help="batch size",
)
parser.add_argument(
"--height",
required=False,
type=int,
default=1024,
help="image height",
)
parser.add_argument(
"--width",
required=False,
type=int,
default=1024,
help="image width",
)
parser.add_argument(
"--repeats",
required=False,
type=int,
default=1000,
help="number of repeats for performance test. Default is 1000.",
)
parser.add_argument(
"--warm_up",
required=False,
type=int,
default=5,
help="number of runs for warm up. Default is 5.",
)
parser.add_argument(
"--engine",
required=False,
type=str,
default="ort",
choices=["ort", "torch"],
help="engine for inference",
)
parser.add_argument(
"--multimask_output",
required=False,
default=False,
action="store_true",
help="Export mask_decoder or image_decoder with multimask_output",
)
parser.add_argument(
"--prefer_nhwc",
required=False,
default=False,
action="store_true",
help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
)
parser.add_argument(
"--enable_nvtx_profile",
required=False,
default=False,
action="store_true",
help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
)
parser.add_argument(
"--enable_ort_profile",
required=False,
default=False,
action="store_true",
help="Enable ORT profiling.",
)
parser.add_argument(
"--enable_torch_profile",
required=False,
default=False,
action="store_true",
help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
)
parser.add_argument(
"--model_type",
required=False,
type=str,
default="sam2_hiera_large",
choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
help="sam2 model name",
)
parser.add_argument(
"--sam2_dir",
required=False,
type=str,
default="./segment-anything-2",
help="The directory of segment-anything-2 git root directory",
)
parser.add_argument(
"--onnx_path",
required=False,
type=str,
default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
help="path of onnx model",
)
parser.add_argument(
"--torch_compile_mode",
required=False,
type=str,
default=None,
choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
help="torch compile mode. none will disable torch compile.",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = _parse_arguments()
print(f"arguments:{args}")
if args.torch_compile_mode is None:
# image decoder will fail with compile modes other than "none".
args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"
if args.use_gpu:
assert torch.cuda.is_available()
if args.engine == "ort":
assert "CUDAExecutionProvider" in get_available_providers()
args.enable_torch_profile = False
else:
# Only support cuda profiling for now.
assert not args.enable_nvtx_profile
assert not args.enable_torch_profile
if args.enable_nvtx_profile or args.enable_torch_profile:
run_test(args)
else:
run_perf_test(args)
@@ -0,0 +1,270 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import os
import pathlib
import sys
import torch
from image_decoder import export_decoder_onnx, test_decoder_onnx
from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx
from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx
from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx
from sam2_demo import run_demo, show_all_images
from sam2_utils import load_sam2_model, sam2_onnx_path, setup_logger
def parse_arguments():
parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX")
parser.add_argument(
"--model_type",
required=False,
type=str,
choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
default="sam2_hiera_large",
help="The model type to export",
)
parser.add_argument(
"--components",
required=False,
nargs="+",
choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"],
default=["image_encoder", "image_decoder"],
help="Type of ONNX models to export. "
"Note that image_decoder is a combination of prompt_encoder and mask_decoder",
)
parser.add_argument(
"--output_dir",
type=str,
help="The output directory for the ONNX models",
default="sam2_onnx_models",
)
parser.add_argument(
"--dynamic_batch_axes",
required=False,
default=False,
action="store_true",
help="Export image_encoder with dynamic batch axes",
)
parser.add_argument(
"--multimask_output",
required=False,
default=False,
action="store_true",
help="Export mask_decoder or image_decoder with multimask_output",
)
parser.add_argument(
"--disable_dynamic_multimask_via_stability",
required=False,
action="store_true",
help="Disable mask_decoder dynamic_multimask_via_stability, and output first mask only."
"This option will be ignored when multimask_output is True",
)
parser.add_argument(
"--sam2_dir",
required=False,
type=str,
default="./segment-anything-2",
help="The directory of segment-anything-2 git repository",
)
parser.add_argument(
"--overwrite",
required=False,
default=False,
action="store_true",
help="Overwrite onnx model file if exists.",
)
parser.add_argument(
"--demo",
required=False,
default=False,
action="store_true",
help="Run demo with the exported ONNX models.",
)
parser.add_argument(
"--optimize",
required=False,
default=False,
action="store_true",
help="Optimize onnx models",
)
parser.add_argument(
"--dtype", required=False, choices=["fp32", "fp16"], default="fp32", help="Data type for inference."
)
parser.add_argument(
"--use_gpu",
required=False,
default=False,
action="store_true",
help="Optimize onnx models for GPU",
)
parser.add_argument(
"--dynamo",
required=False,
default=False,
action="store_true",
help="Use dynamo for exporting onnx model. Only image_encoder supports dynamo right now.",
)
parser.add_argument(
"--verbose",
required=False,
default=False,
action="store_true",
help="Print verbose information",
)
args = parser.parse_args()
return args
def optimize_sam2_model(onnx_model_path, optimized_model_path, float16: bool, use_gpu: bool):
print(f"Optimizing {onnx_model_path} to {optimized_model_path} with float16={float16} and use_gpu={use_gpu}...")
# Import from source directory.
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.insert(0, transformers_dir)
from optimizer import optimize_model # noqa: PLC0415
optimized_model = optimize_model(onnx_model_path, model_type="sam2", opt_level=1, use_gpu=use_gpu)
if float16:
optimized_model.convert_float_to_float16(keep_io_types=False)
optimized_model.save_model_to_file(optimized_model_path)
def main():
args = parse_arguments()
sam2_model = load_sam2_model(args.sam2_dir, args.model_type, device="cpu")
pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
for component in args.components:
onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output)
if component == "image_encoder":
if args.overwrite or not os.path.exists(onnx_model_path):
export_image_encoder_onnx(
sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose, args.dynamo
)
test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=args.dynamic_batch_axes)
elif component == "mask_decoder":
if args.overwrite or not os.path.exists(onnx_model_path):
export_mask_decoder_onnx(
sam2_model,
onnx_model_path,
args.multimask_output,
not args.disable_dynamic_multimask_via_stability,
args.verbose,
)
test_mask_decoder_onnx(
sam2_model,
onnx_model_path,
args.multimask_output,
not args.disable_dynamic_multimask_via_stability,
)
elif component == "prompt_encoder":
if args.overwrite or not os.path.exists(onnx_model_path):
export_prompt_encoder_onnx(sam2_model, onnx_model_path)
test_prompt_encoder_onnx(sam2_model, onnx_model_path)
else:
assert component == "image_decoder"
if args.overwrite or not os.path.exists(onnx_model_path):
export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
suffix = ""
convert_to_fp16 = args.dtype == "fp16"
if args.optimize:
suffix = f"_{args.dtype}_" + ("gpu" if args.use_gpu else "cpu")
for component in args.components:
onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output)
optimized_model_path = sam2_onnx_path(
args.output_dir, args.model_type, component, args.multimask_output, suffix
)
optimize_sam2_model(onnx_model_path, optimized_model_path, convert_to_fp16, args.use_gpu)
if args.demo:
# Export required ONNX models for demo if not already exported.
image_encoder_onnx_path = sam2_onnx_path(
args.output_dir, args.model_type, "image_encoder", args.multimask_output
)
if not os.path.exists(image_encoder_onnx_path):
export_image_encoder_onnx(sam2_model, image_encoder_onnx_path, args.dynamic_batch_axes, args.verbose)
image_decoder_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", False)
if not os.path.exists(image_decoder_onnx_path):
export_decoder_onnx(sam2_model, image_decoder_onnx_path, False)
image_decoder_multi_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", True)
if not os.path.exists(image_decoder_multi_onnx_path):
export_decoder_onnx(sam2_model, image_decoder_multi_onnx_path, True)
dtype = torch.float32 if args.dtype == "fp32" else torch.float16
if suffix:
optimized_image_encoder_onnx_path = image_encoder_onnx_path.replace(".onnx", f"{suffix}.onnx")
if not os.path.exists(optimized_image_encoder_onnx_path):
optimize_sam2_model(
image_encoder_onnx_path, optimized_image_encoder_onnx_path, convert_to_fp16, args.use_gpu
)
optimized_image_decoder_onnx_path = image_decoder_onnx_path.replace(".onnx", f"{suffix}.onnx")
if not os.path.exists(optimized_image_decoder_onnx_path):
optimize_sam2_model(
image_decoder_onnx_path, optimized_image_decoder_onnx_path, convert_to_fp16, args.use_gpu
)
optimized_image_decoder_multi_onnx_path = image_decoder_multi_onnx_path.replace(".onnx", f"{suffix}.onnx")
if not os.path.exists(optimized_image_decoder_multi_onnx_path):
optimize_sam2_model(
image_decoder_multi_onnx_path,
optimized_image_decoder_multi_onnx_path,
convert_to_fp16,
args.use_gpu,
)
# Use optimized models to run demo.
image_encoder_onnx_path = optimized_image_encoder_onnx_path
image_decoder_onnx_path = optimized_image_decoder_onnx_path
image_decoder_multi_onnx_path = optimized_image_decoder_multi_onnx_path
ort_image_files = run_demo(
args.sam2_dir,
args.model_type,
engine="ort",
dtype=dtype,
image_encoder_onnx_path=image_encoder_onnx_path,
image_decoder_onnx_path=image_decoder_onnx_path,
image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
use_gpu=args.use_gpu,
)
print("demo output files for ONNX Runtime:", ort_image_files)
# Get results from torch engine to compare.
torch_image_files = run_demo(args.sam2_dir, args.model_type, engine="torch", dtype=dtype, use_gpu=args.use_gpu)
print("demo output files for PyTorch:", torch_image_files)
show_all_images(ort_image_files, torch_image_files, suffix)
print(f"Combined demo output: sam2_demo{suffix}.png")
if __name__ == "__main__":
setup_logger(verbose=False)
with torch.no_grad():
main()
@@ -0,0 +1,272 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings
import torch
import torch.nn.functional as F
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
from mask_decoder import SAM2MaskDecoder
from prompt_encoder import SAM2PromptEncoder
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn
logger = logging.getLogger(__name__)
class SAM2ImageDecoder(nn.Module):
def __init__(
self,
sam_model: SAM2Base,
multimask_output: bool,
dynamic_multimask_via_stability: bool = True,
return_logits: bool = False,
mask_threshold: float = 0.0,
) -> None:
super().__init__()
self.prompt_encoder = SAM2PromptEncoder(sam_model)
self.mask_decoder = SAM2MaskDecoder(sam_model, multimask_output, dynamic_multimask_via_stability)
self.return_logits = return_logits
self.mask_threshold = mask_threshold
@torch.no_grad()
def forward(
self,
image_features_0: torch.Tensor,
image_features_1: torch.Tensor,
image_embeddings: torch.Tensor,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
input_masks: torch.Tensor,
has_input_masks: torch.Tensor,
original_image_size: torch.Tensor,
enable_nvtx_profile: bool = False,
):
"""
Decode masks from image features and prompts. Batched images are not supported. H=W=1024.
Args:
image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
coordinate in (x, y) format of the P input points in image of size 1024x1024.
point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
positive (foreground), 0 means negative (background), -1 means padding,
2 (box left upper corner), 3 (box right bottom corner).
input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
Typically coming from a previous iteration.
has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
original_image_size(torch.Tensor): [2]. original image size H_o, W_o.
enable_nvtx_profile (bool): enable NVTX profiling.
Returns:
masks (torch.Tensor): [1, M, H_o, W_o] where M=3 or 1. Masks of original image size.
iou_predictions (torch.Tensor): [1, M]. scores for M masks.
low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
"""
nvtx_helper = None
if enable_nvtx_profile:
from nvtx_helper import NvtxHelper # noqa: PLC0415
nvtx_helper = NvtxHelper(["prompt_encoder", "mask_decoder", "post_process"])
if nvtx_helper is not None:
nvtx_helper.start_profile("prompt_encoder", color="blue")
sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
if nvtx_helper is not None:
nvtx_helper.stop_profile("prompt_encoder")
nvtx_helper.start_profile("mask_decoder", color="red")
low_res_masks, iou_predictions = self.mask_decoder(
image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings
)
if nvtx_helper is not None:
nvtx_helper.stop_profile("mask_decoder")
nvtx_helper.start_profile("post_process", color="green")
# Interpolate the low resolution masks back to the original image size.
masks = F.interpolate(
low_res_masks,
(original_image_size[0], original_image_size[1]),
mode="bilinear",
align_corners=False, # Note that align_corners=True has less mismatches during comparing ORT and PyTorch.
)
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
if not self.return_logits:
masks = masks > self.mask_threshold
if nvtx_helper is not None:
nvtx_helper.stop_profile("post_process")
nvtx_helper.print_latency()
return masks, iou_predictions, low_res_masks
def export_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output: bool = False,
verbose: bool = False,
):
batch_size = 1
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
sam2_decoder = SAM2ImageDecoder(
sam2_model,
multimask_output=multimask_output,
dynamic_multimask_via_stability=True,
).cpu()
num_labels = 2
num_points = 3
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
original_image_size = torch.tensor([1200, 1800], dtype=torch.int32)
example_inputs = (
image_features_0,
image_features_1,
image_embeddings,
point_coords,
point_labels,
input_masks,
has_input_masks,
original_image_size,
)
logger.info("point_coords.shape: %s", point_coords.shape)
logger.info("point_labels.shape: %s", point_labels.shape)
logger.info("input_masks.shape: %s", input_masks.shape)
logger.info("has_input_masks.shape: %s", has_input_masks.shape)
logger.info("original_image_size.shape: %s", original_image_size.shape)
if verbose:
masks, iou_predictions, low_res_masks = sam2_decoder(*example_inputs)
logger.info("masks.shape: %s", masks.shape)
logger.info("iou_predictions.shape: %s", iou_predictions.shape)
logger.info("low_res_masks.shape: %s", low_res_masks.shape)
input_names = [
"image_features_0",
"image_features_1",
"image_embeddings",
"point_coords",
"point_labels",
"input_masks",
"has_input_masks",
"original_image_size",
]
output_names = ["masks", "iou_predictions", "low_res_masks"]
dynamic_axes = {
"point_coords": {0: "num_labels", 1: "num_points"},
"point_labels": {0: "num_labels", 1: "num_points"},
"input_masks": {0: "num_labels"},
"has_input_masks": {0: "num_labels"},
"masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"},
"low_res_masks": {0: "num_labels"},
"iou_predictions": {0: "num_labels"},
}
with warnings.catch_warnings():
if not verbose:
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
torch.onnx.export(
sam2_decoder,
example_inputs,
onnx_model_path,
export_params=True,
opset_version=16,
do_constant_folding=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)
logger.info("decoder onnx model saved to %s", onnx_model_path)
def test_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output=False,
):
batch_size = 1
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
sam2_image_decoder = SAM2ImageDecoder(
sam2_model,
multimask_output=multimask_output,
dynamic_multimask_via_stability=True,
).cpu()
num_labels = 1
num_points = 5
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.zeros(1, dtype=torch.float)
original_image_size = torch.tensor([1500, 1500], dtype=torch.int32)
example_inputs = (
image_features_0,
image_features_1,
image_embeddings,
point_coords,
point_labels,
input_masks,
has_input_masks,
original_image_size,
)
masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs)
import onnxruntime # noqa: PLC0415
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
inputs = {model_inputs[i].name: example_inputs[i].numpy() for i in range(len(model_inputs))}
outputs = ort_session.run(output_names, inputs)
for i, output_name in enumerate(output_names):
logger.info(f"{output_name}.shape: %s", outputs[i].shape)
ort_masks, ort_iou_predictions, ort_low_res_masks = outputs
if (
compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float())
and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions))
and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks))
):
print("onnx model has been verified:", onnx_model_path)
else:
print("onnx model verification failed:", onnx_model_path)
@@ -0,0 +1,236 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings
import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
from torch import nn
import onnxruntime
logger = logging.getLogger(__name__)
class SAM2ImageEncoder(nn.Module):
def __init__(self, sam_model: SAM2Base) -> None:
super().__init__()
self.model = sam_model
self.image_encoder = sam_model.image_encoder
self.no_mem_embed = sam_model.no_mem_embed
def forward(
self,
image: torch.Tensor,
enable_nvtx_profile: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Encodes images into features.
Only supports H=W=1024. If you want to use different image sizes like 512x512,
see https://github.com/facebookresearch/segment-anything-2/issues/138.
Args:
image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
enable_nvtx_profile (bool): enable NVTX profiling.
Returns:
image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
"""
nvtx_helper = None
if enable_nvtx_profile:
from nvtx_helper import NvtxHelper # noqa: PLC0415
nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
if nvtx_helper is not None:
nvtx_helper.start_profile("image_encoder")
backbone_out = self.image_encoder(image)
if nvtx_helper is not None:
nvtx_helper.stop_profile("image_encoder")
nvtx_helper.start_profile("post_process")
# precompute projected level 0 and level 1 features in SAM decoder
# to avoid running it again on every SAM click
backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
# Prepare and flatten visual features.
feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
# TODO: we should avoid this transpose since it will be transposed back to NCHW later.
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
feats = [
feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1], strict=False)
][::-1]
if nvtx_helper is not None:
nvtx_helper.stop_profile("post_process")
nvtx_helper.print_latency()
return feats[0], feats[1], feats[2]
def export_image_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
dynamic_batch_axes: bool = False,
verbose: bool = False,
dynamo: bool = False,
clear_dynamo_metadata: bool = False,
):
image = random_sam2_input_image()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
logger.info("image.shape: %s", image.shape)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
dynamic_axes = None
if dynamic_batch_axes:
dynamic_axes = {
"image": {0: "batch_size"},
"image_features_0": {0: "batch_size"},
"image_features_1": {0: "batch_size"},
"image_embeddings": {0: "batch_size"},
}
with warnings.catch_warnings():
if not verbose:
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
if not dynamo:
torch.onnx.export(
sam2_encoder,
image,
onnx_model_path,
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=["image"],
output_names=["image_features_0", "image_features_1", "image_embeddings"],
dynamic_axes=dynamic_axes,
)
else:
torch._dynamo.config.capture_scalar_outputs = True
ep = torch.export.export(
sam2_encoder,
args=(image,),
strict=False,
dynamic_shapes=[
{0: torch.export.Dim.AUTO},
],
)
onnx_program = torch.onnx.export(
ep,
(),
opset_version=17,
input_names=["image"],
output_names=["image_features_0", "image_features_1", "image_embeddings"],
dynamo=True,
)
onnx_program.optimize()
onnx_program.save(onnx_model_path + ".dynamo.onnx", external_data=False)
import onnx # noqa: PLC0415
from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper # noqa: PLC0415
onnx_model = onnx.load_model(onnx_model_path + ".dynamo.onnx", load_external_data=True)
if dynamic_batch_axes:
# Fix labels of dynamic axes since they can't be specified during Dynamo export currently
onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size"
for i in range(3):
onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size"
onnx_model_helper = DynamoOnnxHelper(onnx_model)
onnx_model_helper.convert_constants_to_initializers()
if clear_dynamo_metadata:
onnx_model_helper.clear_metadata()
import os # noqa: PLC0415
if os.path.exists(onnx_model_path):
os.remove(onnx_model_path)
if os.path.exists(onnx_model_path + ".data"):
os.remove(onnx_model_path + ".data")
onnx_model_helper.model.save_model_to_file(
onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True
)
print("encoder onnx model saved to", onnx_model_path)
def test_image_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
dynamic_batch_axes=False,
):
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
batch_sizes = [1, 2] if dynamic_batch_axes else [1]
for batch_size in batch_sizes:
image = random_sam2_input_image(batch_size)
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
logger.info("image.shape: %s", image.shape)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
outputs = ort_session.run(output_names, {"image": image.numpy()})
for i, output_name in enumerate(output_names):
logger.info("output %s shape %s", output_name, outputs[i].shape)
ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs
# ONNXRuntime and PyTorch has about 0.75% mismatched elements, but seems not impacting segmentation results.
if (
compare_tensors_with_tolerance(
"image_features_0",
image_features_0,
torch.tensor(ort_image_features_0),
mismatch_percentage_tolerance=1,
)
and compare_tensors_with_tolerance(
"image_features_1",
image_features_1,
torch.tensor(ort_image_features_1),
mismatch_percentage_tolerance=1,
)
and compare_tensors_with_tolerance(
"image_embeddings",
image_embeddings,
torch.tensor(ort_image_embeddings),
mismatch_percentage_tolerance=1,
)
):
print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
else:
print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")
@@ -0,0 +1,208 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings
import torch
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
from prompt_encoder import SAM2PromptEncoder
from sam2.modeling.sam2_base import SAM2Base
from torch import nn
logger = logging.getLogger(__name__)
class SAM2MaskDecoder(nn.Module):
def __init__(
self,
sam_model: SAM2Base,
multimask_output: bool,
dynamic_multimask_via_stability: bool = True,
) -> None:
super().__init__()
self.mask_decoder = sam_model.sam_mask_decoder
self.prompt_encoder = sam_model.sam_prompt_encoder
self.model = sam_model
self.multimask_output = multimask_output
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
@torch.no_grad()
def forward(
self,
image_features_0: torch.Tensor,
image_features_1: torch.Tensor,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_embeddings: torch.Tensor,
dense_embeddings: torch.Tensor,
):
"""
Decode masks from image and prompt embeddings. Only support H=W=1024.
Args:
image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
image_pe (torch.Tensor): [1, 256, H/16, W/16]. image positional encoding.
sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. embedding for input masks.
Returns:
low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
iou_predictions (torch.Tensor): [1, M]. scores for M masks.
"""
low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
repeat_image=sparse_embeddings.shape[0] > 1, # batch mode
high_res_features=[image_features_0, image_features_1],
)
if self.multimask_output:
low_res_masks = low_res_masks[:, 1:, :, :]
iou_predictions = iou_predictions[:, 1:]
elif self.dynamic_multimask_via_stability:
# When outputting a single mask, if the stability score from the current single-mask
# output (based on output token 0) falls below a threshold, we instead select from
# multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score.
low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
low_res_masks, iou_predictions
)
else:
low_res_masks = low_res_masks[:, 0:1, :, :]
iou_predictions = iou_predictions[:, 0:1]
return low_res_masks, iou_predictions
def export_mask_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output: bool,
dynamic_multimask_via_stability: bool = True,
verbose=False,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
image = random_sam2_input_image()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
logger.info("image_features_0.shape: %s", image_features_0.shape)
logger.info("image_features_1.shape: %s", image_features_1.shape)
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
# encode an random prompt
num_labels = 2
num_points = 3
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
logger.info("image_pe.shape: %s", image_pe.shape)
sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
logger.info("low_res_masks.shape: %s", low_res_masks.shape)
logger.info("iou_predictions.shape: %s", iou_predictions.shape)
with warnings.catch_warnings():
if not verbose:
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
torch.onnx.export(
sam2_mask_decoder,
inputs,
onnx_model_path,
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=[
"image_features_0",
"image_features_1",
"image_embeddings",
"image_pe",
"sparse_embeddings",
"dense_embeddings",
],
output_names=["low_res_masks", "iou_predictions"],
dynamic_axes={
"sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
"dense_embeddings": {0: "num_labels"},
"low_res_masks": {0: "num_labels"},
"iou_predictions": {0: "num_labels"},
},
)
print("mask decoder onnx model saved to", onnx_model_path)
def test_mask_decoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
multimask_output: bool,
dynamic_multimask_via_stability: bool,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
image = random_sam2_input_image()
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
num_labels = 1
num_points = 5
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
import onnxruntime # noqa: PLC0415
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
outputs = ort_session.run(
output_names,
{
"image_features_0": image_features_0.numpy(),
"image_features_1": image_features_1.numpy(),
"image_embeddings": image_embeddings.numpy(),
"image_pe": image_pe.numpy(),
"sparse_embeddings": sparse_embeddings.numpy(),
"dense_embeddings": dense_embeddings.numpy(),
},
)
for i, output_name in enumerate(output_names):
logger.info("output %s shape: %s", output_name, outputs[i].shape)
ort_low_res_masks, ort_iou_predictions = outputs
torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4)
torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4)
print(f"onnx model has been verified: {onnx_model_path}")
@@ -0,0 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import nvtx
from cuda import cudart
class NvtxHelper:
def __init__(self, stages):
self.stages = stages
self.events = {}
for stage in stages:
for marker in ["start", "stop"]:
self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
self.markers = {}
def start_profile(self, stage, color="blue"):
self.markers[stage] = nvtx.start_range(message=stage, color=color)
event_name = stage + "-start"
if event_name in self.events:
cudart.cudaEventRecord(self.events[event_name], 0)
def stop_profile(self, stage):
event_name = stage + "-stop"
if event_name in self.events:
cudart.cudaEventRecord(self.events[event_name], 0)
nvtx.end_range(self.markers[stage])
def print_latency(self):
for stage in self.stages:
latency = cudart.cudaEventElapsedTime(self.events[f"{stage}-start"], self.events[f"{stage}-stop"])[1]
print(f"{stage}: {latency:.2f} ms")
@@ -0,0 +1,189 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn
logger = logging.getLogger(__name__)
class SAM2PromptEncoder(nn.Module):
def __init__(self, sam_model: SAM2Base):
super().__init__()
self.prompt_encoder = sam_model.sam_prompt_encoder
self.model = sam_model
@torch.no_grad()
def forward(
self,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
input_masks: torch.Tensor,
has_input_masks: torch.Tensor,
):
"""Encode prompts.
Args:
point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
coordinate in (x, y) format of the P input points in image of size 1024x1024.
point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
positive (foreground), 0 means negative (background), -1 means padding,
2 (box left upper corner), 3 (box right bottom corner).
input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
Typically coming from a previous iteration.
has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
Returns:
sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
"""
sparse_embeddings = self._embed_points(point_coords, point_labels)
dense_embeddings = self._embed_masks(input_masks, has_input_masks)
image_pe = self.prompt_encoder.get_dense_pe()
return sparse_embeddings, dense_embeddings, image_pe
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
point_coords = point_coords + 0.5
padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
point_coords = torch.cat([point_coords, padding_point], dim=1)
point_labels = torch.cat([point_labels, padding_label], dim=1)
# Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
point_embedding = point_embedding * (point_labels != -1)
point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
for i in range(self.prompt_encoder.num_point_embeddings):
point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)
return point_embedding
def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
logger.info("mask_embedding.shape: %s", mask_embedding.shape)
return mask_embedding
def export_prompt_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
num_labels = 2
num_points = 3
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
logger.info("point_coords.shape: %s", point_coords.shape)
logger.info("point_labels.shape: %s", point_labels.shape)
logger.info("input_masks.shape: %s", input_masks.shape)
logger.info("has_input_masks.shape: %s", has_input_masks.shape)
logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
logger.info("image_pe.shape: %s", image_pe.shape)
torch.onnx.export(
sam2_prompt_encoder,
(point_coords, point_labels, input_masks, has_input_masks),
onnx_model_path,
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
dynamic_axes={
"point_coords": {0: "num_labels", 1: "num_points"},
"point_labels": {0: "num_labels", 1: "num_points"},
"input_masks": {0: "num_labels"},
"sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
"dense_embeddings": {0: "num_labels"},
},
)
print("prompt encoder onnx model saved to ", onnx_model_path)
def test_prompt_encoder_onnx(
sam2_model: SAM2Base,
onnx_model_path: str,
):
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
num_labels = 1
num_points = 5
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
has_input_masks = torch.ones(1, dtype=torch.float)
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
point_coords, point_labels, input_masks, has_input_masks
)
import onnxruntime # noqa: PLC0415
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
model_inputs = ort_session.get_inputs()
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
logger.info("input_names: %s", input_names)
model_outputs = ort_session.get_outputs()
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
logger.info("output_names: %s", output_names)
outputs = ort_session.run(
output_names,
{
"point_coords": point_coords.numpy(),
"point_labels": point_labels.numpy(),
"input_masks": input_masks.numpy(),
"has_input_masks": has_input_masks.numpy(),
},
)
for i, output_name in enumerate(output_names):
logger.info("output %s shape: %s", output_name, outputs[i].shape)
ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
if (
compare_tensors_with_tolerance(
"sparse_embeddings",
sparse_embeddings,
torch.tensor(ort_sparse_embeddings),
mismatch_percentage_tolerance=0.2,
)
and compare_tensors_with_tolerance(
"dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
)
and compare_tensors_with_tolerance(
"image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
)
):
print(f"onnx model has been verified: {onnx_model_path}")
else:
print(f"onnx model verification failed: {onnx_model_path}")
@@ -0,0 +1,321 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
from sam2_utils import load_sam2_model
import onnxruntime
def show_mask(mask, ax, random_color=False, borders=True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask = mask.astype(np.uint8)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
if borders:
import cv2 # noqa: PLC0415
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
# Try to smooth contours
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
)
ax.scatter(
neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
def show_masks(
image,
masks,
scores,
point_coords=None,
box_coords=None,
input_labels=None,
borders=True,
output_image_file_prefix=None,
image_files=None,
):
for i, (mask, score) in enumerate(zip(masks, scores, strict=False)):
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(mask, plt.gca(), borders=borders)
if point_coords is not None:
assert input_labels is not None
show_points(point_coords, input_labels, plt.gca())
if box_coords is not None:
show_box(box_coords, plt.gca())
if len(scores) > 1:
plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
plt.axis("off")
if output_image_file_prefix:
filename = f"{output_image_file_prefix}_{i}.png"
if os.path.exists(filename):
os.remove(filename)
plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
if isinstance(image_files, list):
image_files.append(filename)
plt.show(block=False)
plt.close()
def get_predictor(
sam2_dir: str,
device: str | torch.device,
dtype: torch.dtype,
model_type="sam2_hiera_large",
engine="torch",
image_encoder_onnx_path: str = "",
image_decoder_onnx_path: str = "",
image_decoder_multi_onnx_path: str = "",
provider: str = "CUDAExecutionProvider",
):
sam2_model = load_sam2_model(sam2_dir, model_type, device=device)
if engine == "torch":
predictor = SAM2ImagePredictor(sam2_model)
else:
predictor = SAM2ImageOnnxPredictor(
sam2_model,
image_encoder_onnx_path=image_encoder_onnx_path,
image_decoder_onnx_path=image_decoder_onnx_path,
image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
provider=provider,
device=device,
onnx_dtype=dtype,
)
return predictor
def run_demo(
sam2_dir: str,
model_type: str = "sam2_hiera_large",
engine: str = "torch",
dtype: torch.dtype = torch.float32,
image_encoder_onnx_path: str = "",
image_decoder_onnx_path: str = "",
image_decoder_multi_onnx_path: str = "",
use_gpu: bool = True,
enable_batch: bool = False,
):
if use_gpu:
assert torch.cuda.is_available()
assert "CUDAExecutionProvider" in onnxruntime.get_available_providers()
provider = "CUDAExecutionProvider"
else:
provider = "CPUExecutionProvider"
device = torch.device("cuda" if use_gpu else "cpu")
if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8:
# Turn on tfloat32 for Ampere GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
np.random.seed(3)
image = Image.open("truck.jpg")
image = np.array(image.convert("RGB"))
predictor = get_predictor(
sam2_dir,
device,
dtype,
model_type,
engine,
image_encoder_onnx_path,
image_decoder_onnx_path,
image_decoder_multi_onnx_path,
provider=provider,
)
predictor.set_image(image)
prefix = f"sam2_demo_{engine}_"
# The model returns masks, quality predictions for those masks,
# and low resolution mask logits that can be passed to the next iteration of prediction.
# With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
# scores gives the model's own estimation of the quality of these masks.
# For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
# even if only a single mask is desired;
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]
image_files = []
show_masks(
image,
masks,
scores,
point_coords=input_point,
input_labels=input_label,
borders=True,
output_image_file_prefix=prefix + "multimask",
image_files=image_files,
)
# Multiple points.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, scores, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
show_masks(
image,
masks,
scores,
point_coords=input_point,
input_labels=input_label,
output_image_file_prefix=prefix + "multi_points",
image_files=image_files,
)
# Specify a window and a background point.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, scores, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
show_masks(
image,
masks,
scores,
point_coords=input_point,
input_labels=input_label,
output_image_file_prefix=prefix + "background_point",
image_files=image_files,
)
# Take a box as input
input_box = np.array([425, 600, 700, 875])
masks, scores, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
show_masks(
image,
masks,
scores,
box_coords=input_box,
output_image_file_prefix=prefix + "box",
image_files=image_files,
)
# Combining points and boxes
input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
show_masks(
image,
masks,
scores,
box_coords=input_box,
point_coords=input_point,
input_labels=input_label,
output_image_file_prefix=prefix + "box_and_point",
image_files=image_files,
)
# TODO: support batched prompt inputs
if enable_batch:
input_boxes = np.array(
[
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
]
)
masks, scores, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
show_mask(mask.squeeze(0), plt.gca(), random_color=True)
for box in input_boxes:
show_box(box, plt.gca())
plt.axis("off")
plt.show()
plt.savefig(prefix + "batch_prompt.png")
image_files.append(prefix + "batch_prompt.png")
return image_files
def show_all_images(left_images, right_images, suffix=""):
# Show images in two rows since display screen is horizontal in most cases.
fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images, strict=False)):
left_img = mpimg.imread(left_img_path)
right_img = mpimg.imread(right_img_path)
axes[0, i].imshow(left_img)
axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
axes[0, i].axis("off")
axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])
axes[1, i].imshow(right_img)
axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
axes[1, i].axis("off")
axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])
plt.tight_layout()
plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000)
plt.show()
@@ -0,0 +1,279 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import numpy as np
import torch
from PIL.Image import Image
from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2_utils import decoder_shape_dict, encoder_shape_dict
from onnxruntime import InferenceSession
from onnxruntime.transformers.io_binding_helper import CudaSession
logger = logging.getLogger(__name__)
def create_ort_session(
onnx_path: str,
session_options=None,
provider="CUDAExecutionProvider",
enable_cuda_graph=False,
use_tf32=True,
) -> InferenceSession:
if provider == "CUDAExecutionProvider":
device_id = torch.cuda.current_device()
provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
provider_options["use_tf32"] = int(use_tf32)
providers = [(provider, provider_options), "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
logger.info("Using providers: %s", providers)
return InferenceSession(onnx_path, session_options, providers=providers)
def create_session(
onnx_path: str,
session_options=None,
provider="CUDAExecutionProvider",
device: str | torch.device = "cuda",
enable_cuda_graph=False,
) -> CudaSession:
ort_session = create_ort_session(
onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True
)
cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph)
return cuda_session
class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
def __init__(
self,
sam_model: SAM2Base,
image_encoder_onnx_path: str = "",
image_decoder_onnx_path: str = "",
image_decoder_multi_onnx_path: str = "",
provider: str = "CUDAExecutionProvider",
device: str | torch.device = "cuda",
onnx_dtype: torch.dtype = torch.float32,
mask_threshold=0.0,
max_hole_area=0.0,
max_sprinkle_area=0.0,
**kwargs,
) -> None:
"""
Uses SAM-2 to compute the image embedding for an image, and then allow mask prediction given prompts.
Arguments:
sam_model (SAM2Base): The model to use for mask prediction.
onnx_directory (str): The path of the directory that contains encoder and decoder onnx models.
onnx_dtype (torch.dtype): The data type to use for ONNX inputs.
mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0.
max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
the maximum area of max_hole_area in low_res_masks.
max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
the maximum area of max_sprinkle_area in low_res_masks.
"""
super().__init__(
sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area
)
logger.debug("self.device=%s, device=%s", self.device, device)
# This model is exported by image_encoder.py.
self.encoder_session = create_session(
image_encoder_onnx_path,
session_options=None,
provider=provider,
device=device,
enable_cuda_graph=False,
)
self.onnx_dtype = onnx_dtype
# This model is exported by image_decoder.py. It outputs only one mask.
self.decoder_session = create_session(
image_decoder_onnx_path,
session_options=None,
provider=provider,
device=device,
enable_cuda_graph=False,
)
# This model is exported by image_decoder.py. It outputs multiple (3) masks.
self.decoder_session_multi_out = create_session(
image_decoder_multi_onnx_path,
session_options=None,
provider=provider,
device=device,
enable_cuda_graph=False,
)
@torch.no_grad()
def set_image(self, image: np.ndarray | Image):
"""
Calculates the image embeddings for the provided image.
Arguments:
image (np.ndarray or PIL Image): The input image to embed in RGB format.
The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255].
"""
self.reset_predictor()
# Transform the image to the form expected by the model
if isinstance(image, np.ndarray):
# For numpy array image, we assume (HxWxC) format.
self._orig_hw = [image.shape[:2]]
elif isinstance(image, Image):
w, h = image.size
self._orig_hw = [(h, w)]
else:
raise NotImplementedError("Image format not supported")
input_image = self._transforms(image)
input_image = input_image[None, ...].to(self.device)
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
f"input_image must be of size 1x3xHxW, got {input_image.shape}"
)
# Computing image embeddings for the provided image
io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
self.encoder_session.allocate_buffers(io_shapes)
feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}
for key, value in feed_dict.items():
logger.debug(f"{key}: {value.shape}, {value.dtype}")
logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")
ort_outputs = self.encoder_session.infer(feed_dict)
self._features = {
"image_embed": ort_outputs["image_embeddings"],
"high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
}
self._is_image_set = True
logging.info("Image embeddings computed.")
@torch.no_grad()
def _predict(
self,
point_coords: torch.Tensor | None,
point_labels: torch.Tensor | None,
boxes: torch.Tensor | None = None,
mask_input: torch.Tensor | None = None,
multimask_output: bool = True,
return_logits: bool = False,
img_idx: int = -1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the currently set image.
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using SAM2Transforms.
Arguments:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor or None): A BxN array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
boxes (np.ndarray or None): A Bx4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form Bx1xHxW, where
for SAM, H=W=256. Masks returned by a previous iteration of the
predict method do not need further transformation.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
(torch.Tensor): An array of shape BxCxHxW, where C is the number
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
"""
assert not return_logits # onnx model is exported for returning bool masks.
if not self._is_image_set:
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
if point_coords is not None:
concat_points = (point_coords, point_labels)
else:
concat_points = None
# Embed prompts
if boxes is not None:
box_coords = boxes.reshape(-1, 2, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
box_labels = box_labels.repeat(boxes.size(0), 1)
# we merge "boxes" and "points" into a single "concat_points" input (where
# boxes are added at the beginning) to sam_prompt_encoder
if concat_points is not None:
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
concat_points = (concat_coords, concat_labels)
else:
concat_points = (box_coords, box_labels)
assert concat_points is not None
num_labels = concat_points[0].shape[0]
shape_dict = decoder_shape_dict(
original_image_height=self._orig_hw[img_idx][0],
original_image_width=self._orig_hw[img_idx][1],
num_labels=num_labels,
max_points=concat_points[0].shape[1],
num_masks=3 if multimask_output else 1,
)
if multimask_output:
decoder_session = self.decoder_session_multi_out
else:
decoder_session = self.decoder_session
decoder_session.allocate_buffers(shape_dict)
image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)
if mask_input is None:
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=self.onnx_dtype, device=self.device)
has_input_masks = torch.zeros(num_labels, dtype=self.onnx_dtype, device=self.device)
else:
input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
has_input_masks = torch.ones(num_labels, dtype=self.onnx_dtype, device=self.device)
feed_dict = {
"image_embeddings": image_embeddings.contiguous().to(dtype=self.onnx_dtype).to(self.device),
"image_features_0": image_features_0.contiguous().to(dtype=self.onnx_dtype).to(self.device),
"image_features_1": image_features_1.contiguous().to(dtype=self.onnx_dtype).to(self.device),
"point_coords": concat_points[0].to(dtype=self.onnx_dtype).to(self.device),
"point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
"input_masks": input_masks.to(dtype=self.onnx_dtype).to(self.device),
"has_input_masks": has_input_masks.to(dtype=self.onnx_dtype).to(self.device),
"original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
}
for key, value in feed_dict.items():
logger.debug(f"{key}: {value.shape}, {value.dtype}")
logger.debug(f"decoder onnx: {self.decoder_session.ort_session._model_path}")
ort_outputs = decoder_session.infer(feed_dict)
masks = ort_outputs["masks"]
iou_predictions = ort_outputs["iou_predictions"]
low_res_masks = ort_outputs["low_res_masks"]
return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
@@ -0,0 +1,147 @@
# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import os
import sys
from collections.abc import Mapping
import torch
from sam2.build_sam import build_sam2
from sam2.modeling.sam2_base import SAM2Base
logger = logging.getLogger(__name__)
def _get_model_cfg(model_type) -> str:
assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
if model_type == "sam2_hiera_tiny":
model_cfg = "sam2_hiera_t.yaml"
elif model_type == "sam2_hiera_small":
model_cfg = "sam2_hiera_s.yaml"
elif model_type == "sam2_hiera_base_plus":
model_cfg = "sam2_hiera_b+.yaml"
else:
model_cfg = "sam2_hiera_l.yaml"
return model_cfg
def load_sam2_model(sam2_dir, model_type, device: str | torch.device = "cpu") -> SAM2Base:
checkpoints_dir = os.path.join(sam2_dir, "checkpoints")
sam2_config_dir = os.path.join(sam2_dir, "sam2_configs")
if not os.path.exists(sam2_dir):
raise FileNotFoundError(f"{sam2_dir} does not exist. Please specify --sam2_dir correctly.")
if not os.path.exists(checkpoints_dir):
raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.")
if not os.path.exists(sam2_config_dir):
raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.")
checkpoint_path = os.path.join(checkpoints_dir, f"{model_type}.pt")
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"{checkpoint_path} does not exist. Please download checkpoints under the directory.")
if sam2_dir not in sys.path:
sys.path.append(sam2_dir)
model_cfg = _get_model_cfg(model_type)
sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
return sam2_model
def sam2_onnx_path(output_dir, model_type, component, multimask_output=False, suffix=""):
if component == "image_encoder":
return os.path.join(output_dir, f"{model_type}_image_encoder{suffix}.onnx")
elif component == "mask_decoder":
return os.path.join(output_dir, f"{model_type}_mask_decoder{suffix}.onnx")
elif component == "prompt_encoder":
return os.path.join(output_dir, f"{model_type}_prompt_encoder{suffix}.onnx")
else:
assert component == "image_decoder"
return os.path.join(
output_dir, f"{model_type}_image_decoder" + ("_multi" if multimask_output else "") + f"{suffix}.onnx"
)
def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, list[int]]:
assert height == 1024 and width == 1024, "Only 1024x1024 images are supported."
return {
"image": [batch_size, 3, height, width],
"image_features_0": [batch_size, 32, height // 4, width // 4],
"image_features_1": [batch_size, 64, height // 8, width // 8],
"image_embeddings": [batch_size, 256, height // 16, width // 16],
}
def decoder_shape_dict(
original_image_height: int,
original_image_width: int,
num_labels: int = 1,
max_points: int = 16,
num_masks: int = 1,
) -> dict:
height: int = 1024
width: int = 1024
return {
"image_features_0": [1, 32, height // 4, width // 4],
"image_features_1": [1, 64, height // 8, width // 8],
"image_embeddings": [1, 256, height // 16, width // 16],
"point_coords": [num_labels, max_points, 2],
"point_labels": [num_labels, max_points],
"input_masks": [num_labels, 1, height // 4, width // 4],
"has_input_masks": [num_labels],
"original_image_size": [2],
"masks": [num_labels, num_masks, original_image_height, original_image_width],
"iou_predictions": [num_labels, num_masks],
"low_res_masks": [num_labels, num_masks, height // 4, width // 4],
}
def compare_tensors_with_tolerance(
name: str,
tensor1: torch.Tensor,
tensor2: torch.Tensor,
atol=5e-3,
rtol=1e-4,
mismatch_percentage_tolerance=0.1,
) -> bool:
assert tensor1.shape == tensor2.shape
a = tensor1.clone().float()
b = tensor2.clone().float()
differences = torch.abs(a - b)
mismatch_count = (differences > (rtol * torch.max(torch.abs(a), torch.abs(b)) + atol)).sum().item()
total_elements = a.numel()
mismatch_percentage = (mismatch_count / total_elements) * 100
passed = mismatch_percentage < mismatch_percentage_tolerance
log_func = logger.error if not passed else logger.info
log_func(
"%s: mismatched elements percentage %.2f (%d/%d). Verification %s (threshold=%.2f).",
name,
mismatch_percentage,
mismatch_count,
total_elements,
"passed" if passed else "failed",
mismatch_percentage_tolerance,
)
return passed
def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor:
image = torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32).cpu()
return image
def setup_logger(verbose=True):
if verbose:
logging.basicConfig(format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")
logging.getLogger().setLevel(logging.INFO)
else:
logging.basicConfig(format="[%(message)s")
logging.getLogger().setLevel(logging.WARNING)
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,426 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import gc
import importlib.util
import time
from statistics import mean
import torch
from demo_utils import PipelineInfo
from diffusers import (
AutoencoderKL,
ControlNetModel,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
StableDiffusionXLControlNetPipeline,
)
from engine_builder import EngineType, get_engine_paths
from pipeline_stable_diffusion import StableDiffusionPipeline
"""
Benchmark script for SDXL-Turbo with control net for engines like PyTorch or Stable Fast.
Setup for Stable Fast (see https://github.com/chengzeyi/stable-fast/blob/main/README.md for more info):
git clone https://github.com/chengzeyi/stable-fast.git
cd stable-fast
git submodule update --init
pip3 install torch torchvision torchaudio ninja
pip3 install -e '.[dev,xformers,triton,transformers,diffusers]' -v
sudo apt install libgoogle-perftools-dev
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so
"""
def get_canny_image():
import cv2 # noqa: PLC0415
import numpy as np # noqa: PLC0415
from PIL import Image # noqa: PLC0415
# Test Image can be downloaded from https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png
image = Image.open("input_image_vermeer.png").convert("RGB")
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
return Image.fromarray(image)
def compile_stable_fast(pipeline, enable_cuda_graph=True):
from sfast.compilers.stable_diffusion_pipeline_compiler import CompilationConfig, compile # noqa: PLC0415
config = CompilationConfig.Default()
if importlib.util.find_spec("xformers") is not None:
config.enable_xformers = True
if importlib.util.find_spec("triton") is not None:
config.enable_triton = True
config.enable_cuda_graph = enable_cuda_graph
pipeline = compile(pipeline, config)
return pipeline
def compile_torch(pipeline, use_nhwc=False):
if use_nhwc:
pipeline.unet.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
if hasattr(pipeline, "controlnet"):
if use_nhwc:
pipeline.controlnet.to(memory_format=torch.channels_last)
pipeline.controlnet = torch.compile(pipeline.controlnet, mode="reduce-overhead", fullgraph=True)
return pipeline
def load_pipeline(name, engine, use_control_net=False, use_nhwc=False, enable_cuda_graph=True):
gc.collect()
torch.cuda.empty_cache()
before_memory = torch.cuda.memory_allocated()
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(name, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
if use_control_net:
assert "xl" in name
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
name,
controlnet=controlnet,
vae=vae,
scheduler=scheduler,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
).to("cuda")
else:
pipeline = DiffusionPipeline.from_pretrained(
name,
vae=vae,
scheduler=scheduler,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
).to("cuda")
pipeline.safety_checker = None
gc.collect()
after_memory = torch.cuda.memory_allocated()
print(f"Loaded model with {after_memory - before_memory} bytes allocated")
if engine == "stable_fast":
pipeline = compile_stable_fast(pipeline, enable_cuda_graph=enable_cuda_graph)
elif engine == "torch":
pipeline = compile_torch(pipeline, use_nhwc=use_nhwc)
pipeline.set_progress_bar_config(disable=True)
return pipeline
def get_prompt():
return "little cute gremlin wearing a jacket, cinematic, vivid colors, intricate masterpiece, golden ratio, highly detailed"
def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_graph=True, work_dir="."):
version = PipelineInfo.supported_models()[name]
guidance_scale = 0.0
pipeline_info = PipelineInfo(
version,
use_vae=True,
use_fp16_vae=True,
do_classifier_free_guidance=(guidance_scale > 1.0),
controlnet=["canny"] if use_control_net else [],
)
engine_type = EngineType.ORT_CUDA if engine == "ort_cuda" else EngineType.ORT_TRT
onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(
work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type
)
pipeline = StableDiffusionPipeline(
pipeline_info,
scheduler="EulerA",
max_batch_size=32,
use_cuda_graph=enable_cuda_graph,
framework_model_dir=framework_model_dir,
output_dir=output_dir,
engine_type=engine_type,
)
pipeline.backend.build_engines(
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
device_id=torch.cuda.current_device(),
)
return pipeline
def test_ort_cuda(
pipeline,
batch_size=1,
steps=4,
control_image=None,
warmup_runs=3,
test_runs=10,
seed=123,
verbose=False,
image_height=512,
image_width=512,
):
if batch_size > 4 and pipeline.pipeline_info.version == "xl-1.0":
pipeline.backend.enable_vae_slicing()
pipeline.load_resources(image_height, image_width, batch_size)
warmup_prompt = "warm up"
for _ in range(warmup_runs):
images, _ = pipeline.run(
[warmup_prompt] * batch_size,
[""] * batch_size,
image_height=image_height,
image_width=image_width,
denoising_steps=steps,
guidance=0.0,
seed=seed,
controlnet_images=[control_image],
controlnet_scales=torch.FloatTensor([0.5]),
output_type="image",
)
assert len(images) == batch_size
generator = torch.Generator(device="cuda")
generator.manual_seed(seed)
prompt = get_prompt()
latency_list = []
images = None
for _ in range(test_runs):
torch.cuda.synchronize()
start_time = time.perf_counter()
images, _ = pipeline.run(
[prompt] * batch_size,
[""] * batch_size,
image_height=image_height,
image_width=image_width,
denoising_steps=steps,
guidance=0.0,
seed=seed,
controlnet_images=[control_image],
controlnet_scales=torch.FloatTensor([0.5]),
output_type="pil",
)
torch.cuda.synchronize()
seconds = time.perf_counter() - start_time
latency_list.append(seconds)
if verbose:
print(latency_list)
return images, latency_list
def test(pipeline, batch_size=1, steps=4, control_image=None, warmup_runs=3, test_runs=10, seed=123, verbose=False):
control_net_args = {}
if hasattr(pipeline, "controlnet"):
control_net_args = {
"image": control_image,
"controlnet_conditioning_scale": 0.5,
}
warmup_prompt = "warm up"
for _ in range(warmup_runs):
images = pipeline(
prompt=warmup_prompt,
num_inference_steps=steps,
num_images_per_prompt=batch_size,
guidance_scale=0.0,
**control_net_args,
).images
assert len(images) == batch_size
generator = torch.Generator(device="cuda")
generator.manual_seed(seed)
prompt = get_prompt()
latency_list = []
images = None
for _ in range(test_runs):
torch.cuda.synchronize()
start_time = time.perf_counter()
images = pipeline(
prompt=prompt,
num_inference_steps=steps,
num_images_per_prompt=batch_size,
guidance_scale=0.0,
generator=generator,
**control_net_args,
).images
torch.cuda.synchronize()
seconds = time.perf_counter() - start_time
latency_list.append(seconds)
if verbose:
print(latency_list)
return images, latency_list
def arguments():
import argparse # noqa: PLC0415
parser = argparse.ArgumentParser(description="Benchmark Stable Diffusion pipeline (optional control net for SDXL)")
parser.add_argument(
"--engine",
type=str,
default="torch",
choices=["torch", "stable_fast", "ort_cuda", "ort_trt"],
help="Backend engine: torch, stable_fast or ort_cuda",
)
parser.add_argument(
"--name",
type=str,
choices=list(PipelineInfo.supported_models().keys()),
default="stabilityai/sdxl-turbo",
help="Stable diffusion model name. Default is stabilityai/sdxl-turbo",
)
parser.add_argument(
"--work-dir",
type=str,
default=".",
help="working directory for ort_cuda or ort_trt",
)
parser.add_argument(
"--use_control_net",
action="store_true",
help="Use control net diffusers/controlnet-canny-sdxl-1.0",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size",
)
parser.add_argument(
"--steps",
type=int,
default=1,
help="Denoising steps",
)
parser.add_argument(
"--warmup_runs",
type=int,
default=3,
help="Number of warmup runs before measurement",
)
parser.add_argument(
"--use_nhwc",
action="store_true",
help="use channel last format for torch compile",
)
parser.add_argument(
"--enable_cuda_graph",
action="store_true",
help="enable cuda graph for stable fast",
)
parser.add_argument(
"--verbose",
action="store_true",
help="print more information",
)
args = parser.parse_args()
return args
def main():
args = arguments()
with torch.no_grad():
if args.engine == "ort_cuda":
pipeline = load_ort_cuda_pipeline(
args.name,
args.engine,
use_control_net=args.use_control_net,
enable_cuda_graph=args.enable_cuda_graph,
work_dir=args.work_dir,
)
else:
pipeline = load_pipeline(
args.name,
args.engine,
use_control_net=args.use_control_net,
use_nhwc=args.use_nhwc,
enable_cuda_graph=args.enable_cuda_graph,
)
canny_image = get_canny_image()
if args.engine == "ort_cuda":
images, latency_list = test_ort_cuda(
pipeline,
args.batch_size,
args.steps,
control_image=canny_image,
warmup_runs=args.warmup_runs,
verbose=args.verbose,
)
elif args.engine == "stable_fast":
from sfast.utils.compute_precision import low_compute_precision # noqa: PLC0415
with low_compute_precision():
images, latency_list = test(
pipeline,
args.batch_size,
args.steps,
control_image=canny_image,
warmup_runs=args.warmup_runs,
verbose=args.verbose,
)
else:
images, latency_list = test(
pipeline,
args.batch_size,
args.steps,
control_image=canny_image,
warmup_runs=args.warmup_runs,
verbose=args.verbose,
)
# Save the first output image to inspect the result.
if images:
images[0].save(
f"{args.engine}_{args.name.replace('/', '_')}_{args.batch_size}_{args.steps}_c{int(args.use_control_net)}.png"
)
result = {
"engine": args.engine,
"batch_size": args.batch_size,
"steps": args.steps,
"control_net": args.use_control_net,
"nhwc": args.use_nhwc,
"enable_cuda_graph": args.enable_cuda_graph,
"average_latency_in_ms": mean(latency_list) * 1000,
}
print(result)
main()
@@ -0,0 +1,102 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import coloredlogs
from cuda import cudart
from demo_utils import (
add_controlnet_arguments,
arg_parser,
get_metadata,
load_pipelines,
parse_arguments,
process_controlnet_arguments,
repeat_prompt,
)
def main(args):
controlnet_images, controlnet_scale = process_controlnet_arguments(args)
pipeline, refiner = load_pipelines(args)
assert refiner is None
prompt, negative_prompt = repeat_prompt(args)
batch_size = len(prompt)
pipeline.load_resources(args.height, args.width, batch_size)
def run_inference(warmup=False):
return pipeline.run(
prompt,
negative_prompt,
args.height,
args.width,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
controlnet_images=controlnet_images,
controlnet_scales=controlnet_scale,
show_latency=not warmup,
output_type="pil",
deterministic=args.deterministic,
)
if not args.disable_cuda_graph:
# inference once to get cuda graph
_, _ = run_inference(warmup=True)
print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
_, _ = run_inference(warmup=True)
print("[I] Running StableDiffusion pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
images, perf_data = run_inference(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()
metadata = get_metadata(args, False)
metadata.update(pipeline.metadata())
if perf_data:
metadata.update(perf_data)
metadata["images"] = len(images)
print(metadata)
pipeline.save_images(images, prompt, negative_prompt, metadata)
pipeline.teardown()
if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
parser = arg_parser("Options for Stable Diffusion Demo")
add_controlnet_arguments(parser)
args = parse_arguments(is_xl=False, parser=parser)
if args.user_compute_stream:
import torch
s = torch.cuda.Stream()
with torch.cuda.stream(s):
main(args)
else:
main(args)
@@ -0,0 +1,268 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import coloredlogs
from cuda import cudart
from demo_utils import (
add_controlnet_arguments,
arg_parser,
get_metadata,
load_pipelines,
parse_arguments,
process_controlnet_arguments,
repeat_prompt,
)
def run_pipelines(
args, base, refiner, prompt, negative_prompt, controlnet_image=None, controlnet_scale=None, is_warm_up=False
):
image_height = args.height
image_width = args.width
batch_size = len(prompt)
base.load_resources(image_height, image_width, batch_size)
if refiner:
refiner.load_resources(image_height, image_width, batch_size)
def run_base_and_refiner(warmup=False):
images, base_perf = base.run(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=args.denoising_steps,
guidance=args.guidance,
seed=args.seed,
controlnet_images=controlnet_image,
controlnet_scales=controlnet_scale,
show_latency=not warmup,
output_type="latent" if refiner else "pil",
)
if refiner is None:
return images, base_perf
# Use same seed in base and refiner.
seed = base.get_current_seed()
images, refiner_perf = refiner.run(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=args.refiner_denoising_steps,
image=images,
strength=args.strength,
guidance=args.refiner_guidance,
seed=seed,
show_latency=not warmup,
)
perf_data = None
if base_perf and refiner_perf:
perf_data = {"latency": base_perf["latency"] + refiner_perf["latency"]}
perf_data.update({"base." + key: val for key, val in base_perf.items()})
perf_data.update({"refiner." + key: val for key, val in refiner_perf.items()})
return images, perf_data
if not args.disable_cuda_graph:
# inference once to get cuda graph
_, _ = run_base_and_refiner(warmup=True)
if args.num_warmup_runs > 0:
print("[I] Warming up ..")
for _ in range(args.num_warmup_runs):
_, _ = run_base_and_refiner(warmup=True)
if is_warm_up:
return
print("[I] Running StableDiffusion XL pipeline")
if args.nvtx_profile:
cudart.cudaProfilerStart()
images, perf_data = run_base_and_refiner(warmup=False)
if args.nvtx_profile:
cudart.cudaProfilerStop()
if refiner:
print("|----------------|--------------|")
print("| {:^14} | {:>9.2f} ms |".format("e2e", perf_data["latency"]))
print("|----------------|--------------|")
metadata = get_metadata(args, True)
metadata.update({"base." + key: val for key, val in base.metadata().items()})
if refiner:
metadata.update({"refiner." + key: val for key, val in refiner.metadata().items()})
if perf_data:
metadata.update(perf_data)
metadata["images"] = len(images)
print(metadata)
(refiner or base).save_images(images, prompt, negative_prompt, metadata)
def run_demo(args):
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""
controlnet_image, controlnet_scale = process_controlnet_arguments(args)
prompt, negative_prompt = repeat_prompt(args)
batch_size = len(prompt)
base, refiner = load_pipelines(args, batch_size)
run_pipelines(args, base, refiner, prompt, negative_prompt, controlnet_image, controlnet_scale)
base.teardown()
if refiner:
refiner.teardown()
def run_dynamic_shape_demo(args):
"""
Run demo of generating images with different settings with ORT CUDA provider.
Try "python demo_txt2img_xl.py --max-cuda-graphs 3 --user-compute-stream" to see the effect of multiple CUDA graphs.
"""
args.engine = "ORT_CUDA"
base, refiner = load_pipelines(args, 1)
prompts = [
"starry night over Golden Gate Bridge by van gogh",
"beautiful photograph of Mt. Fuji during cherry blossom",
"little cute gremlin sitting on a bed, cinematic",
"cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
"beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
"blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
"An astronaut riding a rainbow unicorn, cinematic, dramatic",
"close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm",
]
# batch size, height, width, scheduler, steps, prompt, seed, guidance, refiner scheduler, refiner steps, refiner strength
configs = [
(1, 832, 1216, "UniPC", 8, prompts[0], None, 5.0, "UniPC", 10, 0.3),
(1, 1024, 1024, "DDIM", 24, prompts[1], None, 5.0, "DDIM", 30, 0.3),
(1, 1216, 832, "EulerA", 16, prompts[2], 1716921396712843, 5.0, "EulerA", 10, 0.3),
(1, 1344, 768, "EulerA", 24, prompts[3], 123698071912362, 5.0, "EulerA", 20, 0.3),
(2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712, 5.0, "UniPC", 10, 0.3),
(2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906, 5.0, "UniPC", 20, 0.3),
]
# In testing LCM, refiner is disabled so the settings of refiner is not used.
if args.lcm:
configs = [
(1, 1024, 1024, "LCM", 8, prompts[6], None, 1.0, "UniPC", 20, 0.3),
(1, 1216, 832, "LCM", 6, prompts[7], 1337, 1.0, "UniPC", 20, 0.3),
]
# Warm up each combination of (batch size, height, width) once before serving.
args.prompt = ["warm up"]
args.num_warmup_runs = 1
for batch_size, height, width, _, _, _, _, _, _, _, _ in configs:
args.batch_size = batch_size
args.height = height
args.width = width
print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}")
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True)
# Run pipeline on a list of prompts.
args.num_warmup_runs = 0
for (
batch_size,
height,
width,
scheduler,
steps,
example_prompt,
seed,
guidance,
refiner_scheduler,
refiner_denoising_steps,
strength,
) in configs:
args.prompt = [example_prompt]
args.batch_size = batch_size
args.height = height
args.width = width
args.scheduler = scheduler
args.denoising_steps = steps
args.seed = seed
args.guidance = guidance
args.refiner_scheduler = refiner_scheduler
args.refiner_denoising_steps = refiner_denoising_steps
args.strength = strength
base.set_scheduler(scheduler)
if refiner:
refiner.set_scheduler(refiner_scheduler)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
base.teardown()
if refiner:
refiner.teardown()
def run_turbo_demo(args):
"""Run demo of generating images with test prompts with ORT CUDA provider."""
args.engine = "ORT_CUDA"
base, refiner = load_pipelines(args, 1)
from datasets import load_dataset # noqa: PLC0415
dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
num_rows = dataset["test"].num_rows
batch_size = args.batch_size
num_batch = int(num_rows / batch_size)
args.batch_size = 1
for i in range(num_batch):
args.prompt = [dataset["test"][i]["Prompt"] for i in range(i * batch_size, (i + 1) * batch_size)]
base.set_scheduler(args.scheduler)
if refiner:
refiner.set_scheduler(args.refiner_scheduler)
prompt, negative_prompt = repeat_prompt(args)
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
base.teardown()
if refiner:
refiner.teardown()
def main(args):
no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
if no_prompt:
if args.version == "xl-turbo":
run_turbo_demo(args)
else:
run_dynamic_shape_demo(args)
else:
run_demo(args)
if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
parser = arg_parser("Options for Stable Diffusion XL Demo")
add_controlnet_arguments(parser)
args = parse_arguments(is_xl=True, parser=parser)
if args.user_compute_stream:
import torch
s = torch.cuda.Stream()
with torch.cuda.stream(s):
main(args)
else:
main(args)
@@ -0,0 +1,778 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import argparse
import os
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Any
import controlnet_aux
import cv2
import numpy as np
import torch
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineType, get_engine_paths, get_engine_type
from PIL import Image
from pipeline_stable_diffusion import StableDiffusionPipeline
class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
pass
def arg_parser(description: str):
return argparse.ArgumentParser(
description=description,
formatter_class=RawTextArgumentDefaultsHelpFormatter,
)
def set_default_arguments(args):
# set default value for some arguments if not provided
if args.height is None:
args.height = PipelineInfo.default_resolution(args.version)
if args.width is None:
args.width = PipelineInfo.default_resolution(args.version)
is_lcm = (args.version == "xl-1.0" and args.lcm) or "lcm" in args.lora_weights
is_turbo = args.version in ["sd-turbo", "xl-turbo"]
if args.denoising_steps is None:
args.denoising_steps = 4 if is_turbo else 8 if is_lcm else (30 if args.version == "xl-1.0" else 50)
if args.scheduler is None:
args.scheduler = "LCM" if (is_lcm or is_turbo) else ("EulerA" if args.version == "xl-1.0" else "DDIM")
if args.guidance is None:
args.guidance = 0.0 if (is_lcm or is_turbo) else (5.0 if args.version == "xl-1.0" else 7.5)
def parse_arguments(is_xl: bool, parser):
engines = ["ORT_CUDA", "ORT_TRT", "TRT", "TORCH"]
parser.add_argument(
"-e",
"--engine",
type=str,
default=engines[0],
choices=engines,
help="Backend engine in {engines}. "
"ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT",
)
supported_versions = PipelineInfo.supported_versions(is_xl)
parser.add_argument(
"-v",
"--version",
type=str,
default="xl-1.0" if is_xl else "1.5",
choices=supported_versions,
help="Version of Stable Diffusion" + (" XL." if is_xl else "."),
)
parser.add_argument(
"-y",
"--height",
type=int,
default=None,
help="Height of image to generate (must be multiple of 8).",
)
parser.add_argument(
"-x", "--width", type=int, default=None, help="Height of image to generate (must be multiple of 8)."
)
parser.add_argument(
"-s",
"--scheduler",
type=str,
default=None,
choices=["DDIM", "EulerA", "UniPC", "LCM"],
help="Scheduler for diffusion process" + " of base" if is_xl else "",
)
parser.add_argument(
"-wd",
"--work-dir",
default=".",
help="Root Directory to store torch or ONNX models, built engines and output images etc.",
)
parser.add_argument(
"-i",
"--engine-dir",
default=None,
help="Root Directory to store built engines or optimized ONNX models etc.",
)
parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.")
parser.add_argument(
"-n",
"--negative-prompt",
nargs="*",
default=[""],
help="Optional negative prompt(s) to guide the image generation.",
)
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=1,
choices=[1, 2, 4, 8, 16],
help="Number of times to repeat the prompt (batch size multiplier).",
)
parser.add_argument(
"-d",
"--denoising-steps",
type=int,
default=None,
help="Number of denoising steps" + (" in base." if is_xl else "."),
)
parser.add_argument(
"-g",
"--guidance",
type=float,
default=None,
help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.",
)
parser.add_argument(
"-ls", "--lora-scale", type=float, default=1, help="Scale of LoRA weights, default 1 (must between 0 and 1)"
)
parser.add_argument("-lw", "--lora-weights", type=str, default="", help="LoRA weights to apply in the base model")
if is_xl:
parser.add_argument(
"--lcm",
action="store_true",
help="Use fine-tuned latent consistency model to replace the UNet in base.",
)
parser.add_argument(
"-rs",
"--refiner-scheduler",
type=str,
default="EulerA",
choices=["DDIM", "EulerA", "UniPC"],
help="Scheduler for diffusion process of refiner.",
)
parser.add_argument(
"-rg",
"--refiner-guidance",
type=float,
default=5.0,
help="Guidance scale used in refiner.",
)
parser.add_argument(
"-rd",
"--refiner-denoising-steps",
type=int,
default=30,
help="Number of denoising steps in refiner. Note that actual steps is refiner_denoising_steps * strength.",
)
parser.add_argument(
"--strength",
type=float,
default=0.3,
help="A value between 0 and 1. The higher the value less the final image similar to the seed image.",
)
parser.add_argument(
"-r",
"--enable-refiner",
action="store_true",
help="Enable SDXL refiner to refine image from base pipeline.",
)
# ONNX export
parser.add_argument(
"--onnx-opset",
type=int,
default=None,
choices=range(14, 18),
help="Select ONNX opset version to target for exported models.",
)
# Engine build options.
parser.add_argument(
"-db",
"--build-dynamic-batch",
action="store_true",
help="Build TensorRT engines to support dynamic batch size.",
)
parser.add_argument(
"-ds",
"--build-dynamic-shape",
action="store_true",
help="Build TensorRT engines to support dynamic image sizes.",
)
parser.add_argument("--max-batch-size", type=int, default=None, choices=[1, 2, 4, 8, 16, 32], help="Max batch size")
# Inference related options
parser.add_argument(
"-nw", "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance."
)
parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.")
parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.")
parser.add_argument("--deterministic", action="store_true", help="use deterministic algorithms.")
parser.add_argument("-dc", "--disable-cuda-graph", action="store_true", help="Disable cuda graph.")
parser.add_argument("--framework-model-dir", default=None, help="framework model directory")
group = parser.add_argument_group("Options for ORT_CUDA engine only")
group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")
group.add_argument("--max-cuda-graphs", type=int, default=1, help="Max number of cuda graphs to use. Default 1.")
group.add_argument("--user-compute-stream", action="store_true", help="Use user compute stream.")
# TensorRT only options
group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only")
group.add_argument(
"--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources."
)
args = parser.parse_args()
set_default_arguments(args)
# Validate image dimensions
if args.height % 64 != 0 or args.width % 64 != 0:
raise ValueError(
f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}."
)
if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
args.disable_cuda_graph = True
if args.onnx_opset is None:
args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17
if is_xl:
if args.version == "xl-turbo":
if args.lcm:
print("[I] sdxl-turbo cannot use with LCM.")
args.lcm = False
assert args.strength > 0.0 and args.strength < 1.0
assert not (args.lcm and args.lora_weights), "it is not supported to use both lcm unet and Lora together"
if args.scheduler == "LCM":
if args.guidance > 2.0:
print("[I] Use --guidance=0.0 (no more than 2.0) when LCM scheduler is used.")
args.guidance = 0.0
if args.denoising_steps > 16:
print("[I] Use --denoising_steps=8 (no more than 16) when LCM scheduler is used.")
args.denoising_steps = 8
print(args)
return args
def max_batch(args):
if args.max_batch_size:
max_batch_size = args.max_batch_size
else:
do_classifier_free_guidance = args.guidance > 1.0
batch_multiplier = 2 if do_classifier_free_guidance else 1
max_batch_size = 32 // batch_multiplier
if args.engine != "ORT_CUDA" and (args.build_dynamic_shape or args.height > 512 or args.width > 512):
max_batch_size = 8 // batch_multiplier
return max_batch_size
def get_metadata(args, is_xl: bool = False) -> dict[str, Any]:
metadata = {
"command": " ".join(['"' + x + '"' if " " in x else x for x in sys.argv]),
"args.prompt": args.prompt,
"args.negative_prompt": args.negative_prompt,
"args.batch_size": args.batch_size,
"height": args.height,
"width": args.width,
"cuda_graph": not args.disable_cuda_graph,
"vae_slicing": args.enable_vae_slicing,
"engine": args.engine,
}
if args.lora_weights:
metadata["lora_weights"] = args.lora_weights
metadata["lora_scale"] = args.lora_scale
if args.controlnet_type:
metadata["controlnet_type"] = args.controlnet_type
metadata["controlnet_scale"] = args.controlnet_scale
if is_xl and args.enable_refiner:
metadata["base.scheduler"] = args.scheduler
metadata["base.denoising_steps"] = args.denoising_steps
metadata["base.guidance"] = args.guidance
metadata["refiner.strength"] = args.strength
metadata["refiner.scheduler"] = args.refiner_scheduler
metadata["refiner.denoising_steps"] = args.refiner_denoising_steps
metadata["refiner.guidance"] = args.refiner_guidance
else:
metadata["scheduler"] = args.scheduler
metadata["denoising_steps"] = args.denoising_steps
metadata["guidance"] = args.guidance
# Version of installed python packages
packages = ""
for name in [
"onnxruntime-gpu",
"torch",
"tensorrt",
"transformers",
"diffusers",
"onnx",
"onnx-graphsurgeon",
"polygraphy",
"controlnet_aux",
]:
try:
packages += (" " if packages else "") + f"{name}=={version(name)}"
except PackageNotFoundError:
continue
metadata["packages"] = packages
metadata["device"] = torch.cuda.get_device_name()
metadata["torch.version.cuda"] = torch.version.cuda
return metadata
def repeat_prompt(args):
if not isinstance(args.prompt, list):
raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")
prompt = args.prompt * args.batch_size
if not isinstance(args.negative_prompt, list):
raise ValueError(
f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}"
)
if len(args.negative_prompt) == 1:
negative_prompt = args.negative_prompt * len(prompt)
else:
negative_prompt = args.negative_prompt
return prompt, negative_prompt
def initialize_pipeline(
version="xl-turbo",
is_refiner: bool = False,
is_inpaint: bool = False,
engine_type=EngineType.ORT_CUDA,
work_dir: str = ".",
engine_dir=None,
onnx_opset: int = 17,
scheduler="EulerA",
height=512,
width=512,
nvtx_profile=False,
use_cuda_graph=True,
build_dynamic_batch=False,
build_dynamic_shape=False,
min_image_size: int = 512,
max_image_size: int = 1024,
max_batch_size: int = 16,
opt_batch_size: int = 1,
build_all_tactics: bool = False,
do_classifier_free_guidance: bool = False,
lcm: bool = False,
controlnet=None,
lora_weights=None,
lora_scale: float = 1.0,
use_fp16_vae: bool = True,
use_vae: bool = True,
framework_model_dir: str | None = None,
max_cuda_graphs: int = 1,
):
pipeline_info = PipelineInfo(
version,
is_refiner=is_refiner,
is_inpaint=is_inpaint,
use_vae=use_vae,
min_image_size=min_image_size,
max_image_size=max_image_size,
use_fp16_vae=use_fp16_vae,
use_lcm=lcm,
do_classifier_free_guidance=do_classifier_free_guidance,
controlnet=controlnet,
lora_weights=lora_weights,
lora_scale=lora_scale,
)
input_engine_dir = engine_dir
onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type, framework_model_dir=framework_model_dir
)
pipeline = StableDiffusionPipeline(
pipeline_info,
scheduler=scheduler,
output_dir=output_dir,
verbose=False,
nvtx_profile=nvtx_profile,
max_batch_size=max_batch_size,
use_cuda_graph=use_cuda_graph,
framework_model_dir=framework_model_dir,
engine_type=engine_type,
)
import_engine_dir = None
if input_engine_dir:
if not os.path.exists(input_engine_dir):
raise RuntimeError(f"--engine_dir directory does not exist: {input_engine_dir}")
# Support importing from optimized diffusers onnx pipeline
if engine_type == EngineType.ORT_CUDA and os.path.exists(os.path.join(input_engine_dir, "model_index.json")):
import_engine_dir = input_engine_dir
else:
engine_dir = input_engine_dir
opt_image_height = pipeline_info.default_image_size() if build_dynamic_shape else height
opt_image_width = pipeline_info.default_image_size() if build_dynamic_shape else width
if engine_type == EngineType.ORT_CUDA:
pipeline.backend.build_engines(
engine_dir=engine_dir,
framework_model_dir=framework_model_dir,
onnx_dir=onnx_dir,
tmp_dir=os.path.join(work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"),
device_id=torch.cuda.current_device(),
import_engine_dir=import_engine_dir,
max_cuda_graphs=max_cuda_graphs,
)
elif engine_type == EngineType.ORT_TRT:
pipeline.backend.build_engines(
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
opt_batch_size=opt_batch_size,
static_batch=not build_dynamic_batch,
static_image_shape=not build_dynamic_shape,
max_workspace_size=0,
device_id=torch.cuda.current_device(),
timing_cache=timing_cache,
)
elif engine_type == EngineType.TRT:
pipeline.backend.load_engines(
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_batch_size=opt_batch_size,
opt_image_height=opt_image_height,
opt_image_width=opt_image_width,
static_batch=not build_dynamic_batch,
static_shape=not build_dynamic_shape,
enable_all_tactics=build_all_tactics,
timing_cache=timing_cache,
)
elif engine_type == EngineType.TORCH:
pipeline.backend.build_engines(framework_model_dir)
else:
raise RuntimeError("invalid engine type")
return pipeline
def load_pipelines(args, batch_size=None):
engine_type = get_engine_type(args.engine)
# Register TensorRT plugins
if engine_type == EngineType.TRT:
from trt_utilities import init_trt_plugins # noqa: PLC0415
init_trt_plugins()
max_batch_size = max_batch(args)
if batch_size is None:
assert isinstance(args.prompt, list)
batch_size = len(args.prompt) * args.batch_size
if batch_size > max_batch_size:
raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")
# For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size.
# Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance.
# This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024).
if args.version == "xl-turbo":
min_image_size = 512
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
elif args.version == "xl-1.0":
min_image_size = 832 if args.engine != "ORT_CUDA" else 512
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048
else:
# This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768.
min_image_size = 512 if args.engine != "ORT_CUDA" else 256
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
params = {
"version": args.version,
"is_refiner": False,
"is_inpaint": False,
"engine_type": engine_type,
"work_dir": args.work_dir,
"engine_dir": args.engine_dir,
"onnx_opset": args.onnx_opset,
"scheduler": args.scheduler,
"height": args.height,
"width": args.width,
"nvtx_profile": args.nvtx_profile,
"use_cuda_graph": not args.disable_cuda_graph,
"build_dynamic_batch": args.build_dynamic_batch,
"build_dynamic_shape": args.build_dynamic_shape,
"min_image_size": min_image_size,
"max_image_size": max_image_size,
"max_batch_size": max_batch_size,
"opt_batch_size": 1 if args.build_dynamic_batch else batch_size,
"build_all_tactics": args.build_all_tactics,
"do_classifier_free_guidance": args.guidance > 1.0,
"controlnet": args.controlnet_type,
"lora_weights": args.lora_weights,
"lora_scale": args.lora_scale,
"use_fp16_vae": "xl" in args.version,
"use_vae": True,
"framework_model_dir": args.framework_model_dir,
"max_cuda_graphs": args.max_cuda_graphs,
}
if "xl" in args.version:
params["lcm"] = args.lcm
params["use_vae"] = not args.enable_refiner
base = initialize_pipeline(**params)
refiner = None
if "xl" in args.version and args.enable_refiner:
params["version"] = "xl-1.0" # Allow SDXL Turbo to use refiner.
params["is_refiner"] = True
params["scheduler"] = args.refiner_scheduler
params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0
params["lcm"] = False
params["controlnet"] = None
params["lora_weights"] = None
params["use_vae"] = True
params["use_fp16_vae"] = True
refiner = initialize_pipeline(**params)
if engine_type == EngineType.TRT:
max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
base.backend.activate_engines(shared_device_memory)
if refiner:
refiner.backend.activate_engines(shared_device_memory)
if engine_type == EngineType.ORT_CUDA:
enable_vae_slicing = args.enable_vae_slicing
if batch_size > 4 and not enable_vae_slicing and (args.height >= 1024 and args.width >= 1024):
print(
"Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4 and resolution >= 1024."
)
enable_vae_slicing = True
if enable_vae_slicing:
(refiner or base).backend.enable_vae_slicing()
return base, refiner
def get_depth_image(image):
"""
Create depth map for SDXL depth control net.
"""
from transformers import DPTFeatureExtractor, DPTForDepthEstimation # noqa: PLC0415
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
with torch.no_grad(), torch.autocast("cuda"):
depth_map = depth_estimator(image).predicted_depth
# The depth map is 384x384 by default, here we interpolate to the default output size.
# Note that it will be resized to output image size later. May change the size here to avoid interpolate twice.
depth_map = torch.nn.functional.interpolate(
depth_map.unsqueeze(1),
size=(1024, 1024),
mode="bicubic",
align_corners=False,
)
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
image = torch.cat([depth_map] * 3, dim=1)
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
return image
def get_canny_image(image) -> Image.Image:
"""
Create canny image for SDXL control net.
"""
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)
return image
def process_controlnet_images_xl(args) -> list[Image.Image]:
"""
Process control image for SDXL control net.
"""
assert len(args.controlnet_image) == 1
image = Image.open(args.controlnet_image[0]).convert("RGB")
controlnet_images = []
if args.controlnet_type[0] == "canny":
controlnet_images.append(get_canny_image(image))
elif args.controlnet_type[0] == "depth":
controlnet_images.append(get_depth_image(image))
else:
raise ValueError(f"This controlnet type is not supported for SDXL or Turbo: {args.controlnet_type}.")
return controlnet_images
def add_controlnet_arguments(parser, is_xl: bool = False):
"""
Add control net related arguments.
"""
group = parser.add_argument_group("Options for ControlNet (supports 1.5, sd-turbo, xl-turbo, xl-1.0).")
group.add_argument(
"-ci",
"--controlnet-image",
nargs="*",
type=str,
default=[],
help="Path to the input regular RGB image/images for controlnet",
)
group.add_argument(
"-ct",
"--controlnet-type",
nargs="*",
type=str,
default=[],
choices=list(PipelineInfo.supported_controlnet("xl-1.0" if is_xl else "1.5").keys()),
help="A list of controlnet type",
)
group.add_argument(
"-cs",
"--controlnet-scale",
nargs="*",
type=float,
default=[],
help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original unet. Default is 0.5 for SDXL, or 1.0 for SD 1.5",
)
def process_controlnet_image(controlnet_type: str, image: Image.Image, height, width):
"""
Process control images of control net v1.1 for Stable Diffusion 1.5.
"""
control_image = None
shape = (height, width)
image = image.convert("RGB")
if controlnet_type == "canny":
canny_image = controlnet_aux.CannyDetector()(image)
control_image = canny_image.resize(shape)
elif controlnet_type == "normalbae":
normal_image = controlnet_aux.NormalBaeDetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = normal_image.resize(shape)
elif controlnet_type == "depth":
depth_image = controlnet_aux.LeresDetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = depth_image.resize(shape)
elif controlnet_type == "mlsd":
mlsd_image = controlnet_aux.MLSDdetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = mlsd_image.resize(shape)
elif controlnet_type == "openpose":
openpose_image = controlnet_aux.OpenposeDetector.from_pretrained("lllyasviel/Annotators")(image)
control_image = openpose_image.resize(shape)
elif controlnet_type == "scribble":
scribble_image = controlnet_aux.HEDdetector.from_pretrained("lllyasviel/Annotators")(image, scribble=True)
control_image = scribble_image.resize(shape)
elif controlnet_type == "seg":
seg_image = controlnet_aux.SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")(
image
)
control_image = seg_image.resize(shape)
else:
raise ValueError(f"There is no demo image of this controlnet_type: {controlnet_type}")
return control_image
def process_controlnet_arguments(args):
"""
Process control net arguments, and returns a list of control images and a tensor of control net scales.
"""
assert isinstance(args.controlnet_type, list)
assert isinstance(args.controlnet_scale, list)
assert isinstance(args.controlnet_image, list)
if len(args.controlnet_image) != len(args.controlnet_type):
raise ValueError(
f"Numbers of controlnet_image {len(args.controlnet_image)} should be equal to number of controlnet_type {len(args.controlnet_type)}."
)
if len(args.controlnet_type) == 0:
return None, None
if args.version not in ["1.5", "xl-1.0", "xl-turbo", "sd-turbo"]:
raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.")
is_xl = "xl" in args.version
if is_xl and len(args.controlnet_type) > 1:
raise ValueError("This demo only support one ControlNet for Stable Diffusion XL or Turbo.")
if len(args.controlnet_scale) == 0:
args.controlnet_scale = [0.5 if is_xl else 1.0] * len(args.controlnet_type)
elif len(args.controlnet_type) != len(args.controlnet_scale):
raise ValueError(
f"Numbers of controlnet_type {len(args.controlnet_type)} should be equal to number of controlnet_scale {len(args.controlnet_scale)}."
)
# Convert controlnet scales to tensor
controlnet_scale = torch.FloatTensor(args.controlnet_scale)
if is_xl:
images = process_controlnet_images_xl(args)
else:
images = []
for i, image in enumerate(args.controlnet_image):
images.append(process_controlnet_image(args.controlnet_type[i], Image.open(image), args.height, args.width))
return images, controlnet_scale
@@ -0,0 +1,295 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import hashlib
import os
from enum import Enum
import torch
from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL
class EngineType(Enum):
ORT_CUDA = 0 # ONNX Runtime CUDA Execution Provider
ORT_TRT = 1 # ONNX Runtime TensorRT Execution Provider
TRT = 2 # TensorRT
TORCH = 3 # PyTorch
def get_engine_type(name: str) -> EngineType:
name_to_type = {
"ORT_CUDA": EngineType.ORT_CUDA,
"ORT_TRT": EngineType.ORT_TRT,
"TRT": EngineType.TRT,
"TORCH": EngineType.TORCH,
}
return name_to_type[name]
class EngineBuilder:
def __init__(
self,
engine_type: EngineType,
pipeline_info: PipelineInfo,
device="cuda",
max_batch_size=16,
use_cuda_graph=False,
):
"""
Initializes the Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
device (str | torch.device):
device to run engine
max_batch_size (int):
Maximum batch size for dynamic batch engine.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
self.engine_type = engine_type
self.pipeline_info = pipeline_info
self.max_batch_size = max_batch_size
self.use_cuda_graph = use_cuda_graph
self.device = torch.device(device)
self.torch_device = torch.device(device, torch.cuda.current_device())
self.stages = pipeline_info.stages()
self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() and self.engine_type != EngineType.TORCH
self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae()
self.models = {}
self.engines = {}
self.torch_models = {}
self.use_vae_slicing = False
self.torch_sdpa = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
def enable_vae_slicing(self):
self.use_vae_slicing = True
def disable_torch_spda(self):
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
delattr(torch.nn.functional, "scaled_dot_product_attention")
def enable_torch_spda(self):
if (not hasattr(torch.nn.functional, "scaled_dot_product_attention")) and self.torch_sdpa:
torch.nn.functional.scaled_dot_product_attention = self.torch_sdpa
def teardown(self):
for engine in self.engines.values():
del engine
self.engines = {}
def get_diffusers_module_name(self, model_name):
name_mapping = {
"clip": "text_encoder",
"clip2": "text_encoder_2",
"unet": "unet",
"unetxl": "unet",
"vae": "vae_decoder",
}
return name_mapping.get(model_name, model_name)
def get_cached_model_name(self, model_name):
model_name = self.get_diffusers_module_name(model_name)
is_unet = model_name == "unet"
hash_source = []
if model_name in ["text_encoder", "text_encoder_2", "unet"] and self.pipeline_info.lora_weights:
if self.pipeline_info.lora_weights in [
"latent-consistency/lcm-lora-sdxl",
"latent-consistency/lcm-lora-sdv1-5",
]:
if is_unet:
model_name = "unet_lcm-lora"
else:
model_name = model_name + "_lora"
hash_source.append(self.pipeline_info.lora_weights)
# TODO(tianleiwu): save custom model to a directory named by its original model.
if is_unet and self.pipeline_info.custom_unet():
model_name = model_name + "_lcm"
if model_name in ["unet"] and self.pipeline_info.controlnet:
model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet)
if hash_source:
model_name += "_" + hashlib.sha256("\t".join(hash_source).encode("utf-8")).hexdigest()[:8]
# TODO: When we support original VAE, we shall save custom VAE to another directory.
if self.pipeline_info.is_inpaint():
model_name += "_inpaint"
return model_name
def get_model_dir(self, model_name, root_dir, opt=True, suffix="", create=True):
engine_name = self.engine_type.name.lower()
if engine_name != "ort_cuda" and not suffix:
suffix = f".{engine_name}" if opt else ""
directory_name = self.get_cached_model_name(model_name) + suffix
onnx_model_dir = os.path.join(root_dir, directory_name)
if create:
os.makedirs(onnx_model_dir, exist_ok=True)
return onnx_model_dir
def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""):
onnx_model_dir = self.get_model_dir(model_name, onnx_dir, opt=opt, suffix=suffix)
return os.path.join(onnx_model_dir, "model.onnx")
def get_engine_path(self, engine_dir, model_name, profile_id):
return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id)
def load_pipeline_with_lora(self):
"""Load text encoders and UNet with diffusers pipeline"""
from diffusers import DiffusionPipeline # noqa: PLC0415
pipeline = DiffusionPipeline.from_pretrained(
self.pipeline_info.name(),
variant="fp16",
torch_dtype=torch.float16,
)
pipeline.load_lora_weights(self.pipeline_info.lora_weights)
pipeline.fuse_lora(lora_scale=self.pipeline_info.lora_scale)
del pipeline.vae
pipeline.vae = None
return pipeline
def get_or_load_model(self, pipeline, model_name, model_obj, framework_model_dir):
if model_name in ["clip", "clip2", "unet", "unetxl"] and pipeline:
if model_name == "clip":
model = pipeline.text_encoder
pipeline.text_encoder = None
elif model_name == "clip2":
model = pipeline.text_encoder_2
pipeline.text_encoder_2 = None
else:
model = pipeline.unet
pipeline.unet = None
else:
model = model_obj.load_model(framework_model_dir)
return model.to(self.torch_device)
def load_models(self, framework_model_dir: str):
# For TRT or ORT_TRT, we will export fp16 torch model for UNet and VAE
# For ORT_CUDA, we export fp32 model first, then optimize to fp16.
export_fp16 = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT]
if "clip" in self.stages:
self.models["clip"] = CLIP(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
max_batch_size=self.max_batch_size,
clip_skip=0,
)
if "clip2" in self.stages:
self.models["clip2"] = CLIPWithProj(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
max_batch_size=self.max_batch_size,
clip_skip=0,
)
if "unet" in self.stages:
self.models["unet"] = UNet(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
fp16=export_fp16,
max_batch_size=self.max_batch_size,
unet_dim=(9 if self.pipeline_info.is_inpaint() else 4),
)
if "unetxl" in self.stages:
self.models["unetxl"] = UNetXL(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
fp16=export_fp16,
max_batch_size=self.max_batch_size,
unet_dim=4,
time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6),
)
# VAE Decoder
if "vae" in self.stages:
self.models["vae"] = VAE(
self.pipeline_info,
None, # not loaded yet
device=self.torch_device,
max_batch_size=self.max_batch_size,
fp16=export_fp16,
custom_fp16_vae=self.custom_fp16_vae,
)
if self.vae_torch_fallback:
self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir)
def load_resources(self, image_height, image_width, batch_size):
if self.engine_type == EngineType.TORCH:
return
# Allocate buffers for I/O bindings
for model_name, obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size
self.engines[model_name].allocate_buffers(
shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device
)
def _vae_decode(self, latents):
if self.engine_type == EngineType.TORCH:
if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast
latents = latents.to(dtype=torch.float32)
images = self.engines["vae"](latents)["sample"]
else:
images = self.engines["vae"](latents)["sample"]
elif self.vae_torch_fallback:
if not self.custom_fp16_vae:
latents = latents.to(dtype=torch.float32)
self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32)
images = self.torch_models["vae"](latents)["sample"]
else:
if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast
images = self.run_engine("vae", {"latent": latents.to(dtype=torch.float32)})["images"]
else:
images = self.run_engine("vae", {"latent": latents})["images"]
return images
def vae_decode(self, latents):
if self.use_vae_slicing:
# The output tensor points to same buffer. Need clone it to avoid overwritten.
decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)]
return torch.cat(decoded_slices)
return self._vae_decode(latents)
def get_engine_paths(
work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: str | None = None
):
root_dir = work_dir or "."
short_name = pipeline_info.short_name()
# When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since
# ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model.
onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx")
engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine")
output_dir = os.path.join(root_dir, engine_type.name, short_name, "output")
timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache")
# Shared among ORT_CUDA, ORT_TRT and TRT engines, and need use load_model(..., always_download_fp16=True)
# So that the shared model is always fp16.
if framework_model_dir is None:
framework_model_dir = os.path.join(root_dir, "torch_model")
return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache
@@ -0,0 +1,387 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import gc
import logging
import os
import onnx
import torch
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from packaging import version
import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
from onnxruntime.transformers.onnx_model import OnnxModel
logger = logging.getLogger(__name__)
class OrtCudaEngine:
def __init__(
self,
onnx_path,
device_id: int = 0,
enable_cuda_graph: bool = False,
disable_optimization: bool = False,
max_cuda_graphs: int = 1,
):
self.onnx_path = onnx_path
self.provider = "CUDAExecutionProvider"
self.stream = torch.cuda.current_stream().cuda_stream
self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph, self.stream)
session_options = ort.SessionOptions()
# When the model has been optimized by onnxruntime, we can disable optimization to save session creation time.
if disable_optimization:
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
logger.info("creating CUDA EP session for %s", onnx_path)
ort_session = ort.InferenceSession(
onnx_path,
session_options,
providers=[
(self.provider, self.provider_options),
"CPUExecutionProvider",
],
)
logger.info("created CUDA EP session for %s", onnx_path)
device = torch.device("cuda", device_id)
self.enable_cuda_graph = enable_cuda_graph
# Support multiple CUDA graphs for different input shapes.
# For clip2 model that disabled cuda graph, max_cuda_graphs is updated to 0 here.
self.gpu_binding_manager = GpuBindingManager(
ort_session=ort_session,
device=device,
stream=self.stream,
max_cuda_graphs=max_cuda_graphs if enable_cuda_graph else 0,
)
self.current_gpu_binding = None
def metadata(self, name: str):
data = {}
if self.current_gpu_binding is not None:
if self.current_gpu_binding.last_run_gpu_graph_id >= 0:
data[f"{name}.gpu_graph_id"] = self.current_gpu_binding.last_run_gpu_graph_id
return data
def infer(self, feed_dict: dict[str, torch.Tensor]):
return self.current_gpu_binding.infer(feed_dict=feed_dict, disable_cuda_graph_in_run=not self.enable_cuda_graph)
def allocate_buffers(self, shape_dict, device):
self.current_gpu_binding = self.gpu_binding_manager.get_binding(
shape_dict=shape_dict, use_cuda_graph=self.enable_cuda_graph
)
class _ModelConfig:
"""
Configuration of one model (like Clip, UNet etc) on ONNX export and optimization for CUDA provider.
For example, if you want to use fp32 in layer normalization, set the following:
force_fp32_ops=["SkipLayerNormalization", "LayerNormalization"]
"""
def __init__(
self,
onnx_opset_version: int,
use_cuda_graph: bool,
fp16: bool = True,
force_fp32_ops: list[str] | None = None,
optimize_by_ort: bool = True,
):
self.onnx_opset_version = onnx_opset_version
self.use_cuda_graph = use_cuda_graph
self.fp16 = fp16
self.force_fp32_ops = force_fp32_ops
self.optimize_by_ort = optimize_by_ort
class OrtCudaEngineBuilder(EngineBuilder):
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.ORT_CUDA,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
self.model_config = {}
def _configure(
self,
model_name: str,
onnx_opset_version: int,
use_cuda_graph: bool,
fp16: bool = True,
force_fp32_ops: list[str] | None = None,
optimize_by_ort: bool = True,
):
self.model_config[model_name] = _ModelConfig(
onnx_opset_version,
use_cuda_graph,
fp16=fp16,
force_fp32_ops=force_fp32_ops,
optimize_by_ort=optimize_by_ort,
)
def configure_xl(self, onnx_opset_version: int):
self._configure(
"clip",
onnx_opset_version=onnx_opset_version,
use_cuda_graph=self.use_cuda_graph,
)
self._configure(
"clip2",
onnx_opset_version=onnx_opset_version, # TODO: ArgMax-12 is not implemented in CUDA
use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph
)
self._configure(
"unetxl",
onnx_opset_version=onnx_opset_version,
use_cuda_graph=self.use_cuda_graph,
)
self._configure(
"vae",
onnx_opset_version=onnx_opset_version,
use_cuda_graph=self.use_cuda_graph,
)
def optimized_onnx_path(self, engine_dir, model_name):
suffix = "" if self.model_config[model_name].fp16 else ".fp32"
return self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix)
def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str):
"""Import optimized onnx models for diffusers from Olive or optimize_pipeline tools.
Args:
diffusers_onnx_dir (str): optimized onnx directory of Olive
engine_dir (str): the directory to store imported onnx
"""
if version.parse(ort.__version__) < version.parse("1.17.0"):
print("Skip importing since onnxruntime-gpu version < 1.17.0.")
return
for model_name, model_obj in self.models.items():
onnx_import_path = self.optimized_onnx_path(diffusers_onnx_dir, model_name)
if not os.path.exists(onnx_import_path):
print(f"{onnx_import_path} not existed. Skip importing.")
continue
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
if os.path.exists(onnx_opt_path):
print(f"{onnx_opt_path} existed. Skip importing.")
continue
if model_name == "vae" and self.pipeline_info.is_xl():
print(f"Skip importing VAE since it is not fully compatible with float16: {onnx_import_path}.")
continue
model = OnnxModel(onnx.load(onnx_import_path, load_external_data=True))
if model_name in ["clip", "clip2"]:
hidden_states_per_layer = []
for output in model.graph().output:
if output.name.startswith("hidden_states."):
hidden_states_per_layer.append(output.name)
if hidden_states_per_layer:
kept_hidden_states = hidden_states_per_layer[-2 - model_obj.clip_skip]
model.rename_graph_output(kept_hidden_states, "hidden_states")
model.rename_graph_output(
"last_hidden_state" if model_name == "clip" else "text_embeds", "text_embeddings"
)
model.prune_graph(
["text_embeddings", "hidden_states"] if hidden_states_per_layer else ["text_embeddings"]
)
if model_name == "clip2":
model.change_graph_input_type(model.find_graph_input("input_ids"), onnx.TensorProto.INT32)
model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2"))
elif model_name in ["unet", "unetxl"]:
model.rename_graph_output("out_sample", "latent")
model.save_model_to_file(onnx_opt_path, use_external_data_format=True)
del model
continue
def build_engines(
self,
engine_dir: str,
framework_model_dir: str,
onnx_dir: str,
tmp_dir: str | None = None,
onnx_opset_version: int = 17,
device_id: int = 0,
save_fp32_intermediate_model: bool = False,
import_engine_dir: str | None = None,
max_cuda_graphs: int = 1,
):
self.torch_device = torch.device("cuda", device_id)
self.load_models(framework_model_dir)
if not os.path.isdir(engine_dir):
os.makedirs(engine_dir)
if not os.path.isdir(onnx_dir):
os.makedirs(onnx_dir)
# Add default configuration if missing
if self.pipeline_info.is_xl():
self.configure_xl(onnx_opset_version)
for model_name in self.models:
if model_name not in self.model_config:
self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph)
# Import Engine
if import_engine_dir:
if self.pipeline_info.is_xl():
self.import_diffusers_engine(import_engine_dir, engine_dir)
else:
print(f"Only support importing SDXL onnx. Ignore --engine-dir {import_engine_dir}")
# Load lora only when we need export text encoder or UNet to ONNX.
load_lora = False
if self.pipeline_info.lora_weights:
for model_name in self.models:
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
continue
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
load_lora = True
break
# Export models to ONNX
self.disable_torch_spda()
pipe = self.load_pipeline_with_lora() if load_lora else None
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
print("----")
logger.info("Exporting model: %s", onnx_path)
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
model = model.to(torch.float32)
with torch.inference_mode():
# For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern.
# Export model with sample of batch size 1, image size 512 x 512
inputs = model_obj.get_sample_input(1, 512, 512)
torch.onnx.export(
model,
inputs,
onnx_path,
export_params=True,
opset_version=self.model_config[model_name].onnx_opset_version,
do_constant_folding=True,
input_names=model_obj.get_input_names(),
output_names=model_obj.get_output_names(),
dynamic_axes=model_obj.get_dynamic_axes(),
)
del model
torch.cuda.empty_cache()
gc.collect()
else:
logger.info("Found cached model: %s", onnx_path)
# Generate fp32 optimized model.
# If final target is fp16 model, we save fp32 optimized model so that it is easy to tune
# fp16 conversion. That could save a lot of time in developing.
use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16
onnx_fp32_path = onnx_path
if use_fp32_intermediate:
onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32")
if not os.path.exists(onnx_fp32_path):
print("------")
logger.info("Generating optimized model: %s", onnx_fp32_path)
model_obj.optimize_ort(
onnx_path,
onnx_fp32_path,
to_fp16=False,
fp32_op_list=self.model_config[model_name].force_fp32_ops,
optimize_by_ort=self.model_config[model_name].optimize_by_ort,
tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp32", create=False),
)
else:
logger.info("Found cached optimized model: %s", onnx_fp32_path)
# Generate the final optimized model.
if not os.path.exists(onnx_opt_path):
print("------")
logger.info("Generating optimized model: %s", onnx_opt_path)
# When there is fp32 intermediate optimized model, this will just convert model from fp32 to fp16.
optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort
model_obj.optimize_ort(
onnx_fp32_path,
onnx_opt_path,
to_fp16=self.model_config[model_name].fp16,
fp32_op_list=self.model_config[model_name].force_fp32_ops,
optimize_by_ort=optimize_by_ort,
optimize_by_fusion=not use_fp32_intermediate,
tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".ort", create=False),
)
else:
logger.info("Found cached optimized model: %s", onnx_opt_path)
self.enable_torch_spda()
built_engines = {}
for model_name in self.models:
if model_name == "vae" and self.vae_torch_fallback:
continue
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
use_cuda_graph = self.model_config[model_name].use_cuda_graph
engine = OrtCudaEngine(
onnx_opt_path,
device_id=device_id,
enable_cuda_graph=use_cuda_graph,
disable_optimization=False,
max_cuda_graphs=max_cuda_graphs,
)
logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options)
built_engines[model_name] = engine
self.engines = built_engines
def run_engine(self, model_name, feed_dict):
return self.engines[model_name].infer(feed_dict)
@@ -0,0 +1,288 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import gc
import logging
import os
import torch
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from packaging import version
import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession
logger = logging.getLogger(__name__)
class OrtTensorrtEngine(CudaSession):
def __init__(
self,
engine_path,
device_id,
onnx_path,
fp16,
input_profile,
workspace_size,
enable_cuda_graph,
timing_cache_path=None,
):
self.engine_path = engine_path
self.ort_trt_provider_options = self.get_tensorrt_provider_options(
input_profile,
workspace_size,
fp16,
device_id,
enable_cuda_graph,
timing_cache_path=timing_cache_path,
)
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
logger.info("creating TRT EP session for %s", onnx_path)
ort_session = ort.InferenceSession(
onnx_path,
session_options,
providers=[
("TensorrtExecutionProvider", self.ort_trt_provider_options),
],
)
logger.info("created TRT EP session for %s", onnx_path)
device = torch.device("cuda", device_id)
super().__init__(ort_session, device, enable_cuda_graph)
def get_tensorrt_provider_options(
self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph, timing_cache_path=None
):
trt_ep_options = {
"device_id": device_id,
"trt_fp16_enable": fp16,
"trt_engine_cache_enable": True,
"trt_timing_cache_enable": True,
"trt_detailed_build_log": True,
"trt_engine_cache_path": self.engine_path,
}
if version.parse(ort.__version__) > version.parse("1.16.2") and timing_cache_path is not None:
trt_ep_options["trt_timing_cache_path"] = timing_cache_path
if enable_cuda_graph:
trt_ep_options["trt_cuda_graph_enable"] = True
if workspace_size > 0:
trt_ep_options["trt_max_workspace_size"] = workspace_size
if input_profile:
min_shapes = []
max_shapes = []
opt_shapes = []
for name, profile in input_profile.items():
assert isinstance(profile, list) and len(profile) == 3
min_shape = profile[0]
opt_shape = profile[1]
max_shape = profile[2]
assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape)
min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape]))
opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape]))
max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape]))
trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes)
trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes)
trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes)
logger.info("trt_ep_options=%s", trt_ep_options)
return trt_ep_options
def allocate_buffers(self, shape_dict, device):
super().allocate_buffers(shape_dict)
class OrtTensorrtEngineBuilder(EngineBuilder):
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.ORT_TRT,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
def has_engine_file(self, engine_path):
if os.path.isdir(engine_path):
children = os.scandir(engine_path)
for entry in children:
if entry.is_file() and entry.name.endswith(".engine"):
return True
return False
def get_work_space_size(self, model_name, max_workspace_size):
gibibyte = 2**30
workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size
if workspace_size == 0:
_, free_mem, _ = cudart.cudaMemGetInfo()
# The following logic are adopted from TensorRT demo diffusion.
if free_mem > 6 * gibibyte:
workspace_size = free_mem - 4 * gibibyte
return workspace_size
def build_engines(
self,
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_image_height,
opt_image_width,
opt_batch_size=1,
static_batch=False,
static_image_shape=True,
max_workspace_size=0,
device_id=0,
timing_cache=None,
):
self.torch_device = torch.device("cuda", device_id)
self.load_models(framework_model_dir)
if not os.path.isdir(engine_dir):
os.makedirs(engine_dir)
if not os.path.isdir(onnx_dir):
os.makedirs(onnx_dir)
# Load lora only when we need export text encoder or UNet to ONNX.
load_lora = False
if self.pipeline_info.lora_weights:
for model_name, model_obj in self.models.items():
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not self.has_engine_file(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
load_lora = True
break
# Export models to ONNX
self.disable_torch_spda()
pipe = self.load_pipeline_with_lora() if load_lora else None
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not self.has_engine_file(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
logger.info(f"Exporting model: {onnx_path}")
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
with torch.inference_mode(), torch.autocast("cuda"):
inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
torch.onnx.export(
model,
inputs,
onnx_path,
export_params=True,
opset_version=onnx_opset,
do_constant_folding=True,
input_names=model_obj.get_input_names(),
output_names=model_obj.get_output_names(),
dynamic_axes=model_obj.get_dynamic_axes(),
)
del model
torch.cuda.empty_cache()
gc.collect()
else:
logger.info("Found cached model: %s", onnx_path)
# Optimize onnx
if not os.path.exists(onnx_opt_path):
logger.info("Generating optimizing model: %s", onnx_opt_path)
model_obj.optimize_trt(onnx_path, onnx_opt_path)
else:
logger.info("Found cached optimized model: %s", onnx_opt_path)
self.enable_torch_spda()
built_engines = {}
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not self.has_engine_file(engine_path):
logger.info(
"Building TensorRT engine for %s from %s to %s. It can take a while to complete...",
model_name,
onnx_opt_path,
engine_path,
)
else:
logger.info("Reuse cached TensorRT engine in directory %s", engine_path)
input_profile = model_obj.get_input_profile(
opt_batch_size,
opt_image_height,
opt_image_width,
static_batch=static_batch,
static_image_shape=static_image_shape,
)
engine = OrtTensorrtEngine(
engine_path,
device_id,
onnx_opt_path,
fp16=True,
input_profile=input_profile,
workspace_size=self.get_work_space_size(model_name, max_workspace_size),
enable_cuda_graph=self.use_cuda_graph,
timing_cache_path=timing_cache,
)
built_engines[model_name] = engine
self.engines = built_engines
def run_engine(self, model_name, feed_dict):
return self.engines[model_name].infer(feed_dict)
@@ -0,0 +1,395 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import gc
import os
import pathlib
from collections import OrderedDict
import numpy as np
import tensorrt as trt
import torch
from cuda import cudart
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import (
CreateConfig,
ModifyNetworkOutputs,
Profile,
engine_from_bytes,
engine_from_network,
network_from_onnx_path,
save_engine,
)
# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
}
def _cuda_assert(cuda_ret):
err = cuda_ret[0]
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError(
f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
)
if len(cuda_ret) > 1:
return cuda_ret[1]
return None
class TensorrtEngine:
def __init__(
self,
engine_path,
):
self.engine_path = engine_path
self.engine = None
self.context = None
self.buffers = OrderedDict()
self.tensors = OrderedDict()
self.cuda_graph_instance = None
def __del__(self):
del self.engine
del self.context
del self.buffers
del self.tensors
def build(
self,
onnx_path,
fp16,
input_profile=None,
enable_all_tactics=False,
timing_cache=None,
update_output_names=None,
):
print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
if input_profile:
for name, dims in input_profile.items():
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
config_kwargs = {}
if not enable_all_tactics:
config_kwargs["tactic_sources"] = []
network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
if update_output_names:
print(f"Updating network outputs to {update_output_names}")
network = ModifyNetworkOutputs(network, update_output_names)
engine = engine_from_network(
network,
config=CreateConfig(
fp16=fp16, refittable=False, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
),
save_timing_cache=timing_cache,
)
save_engine(engine, path=self.engine_path)
def load(self):
print(f"Loading TensorRT engine: {self.engine_path}")
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
def activate(self, reuse_device_memory=None):
if reuse_device_memory:
self.context = self.engine.create_execution_context_without_device_memory()
self.context.device_memory = reuse_device_memory
else:
self.context = self.engine.create_execution_context()
def allocate_buffers(self, shape_dict=None, device="cuda"):
for idx in range(self.engine.num_io_tensors):
binding = self.engine[idx]
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
else:
shape = self.engine.get_binding_shape(binding)
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[binding] = tensor
def infer(self, feed_dict, stream, use_cuda_graph=False):
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
if use_cuda_graph:
if self.cuda_graph_instance is not None:
_cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
_cuda_assert(cudart.cudaStreamSynchronize(stream))
else:
# do inference before CUDA graph capture
noerror = self.context.execute_async_v3(stream)
if not noerror:
raise ValueError("ERROR: inference failed.")
# capture cuda graph
_cuda_assert(
cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
)
self.context.execute_async_v3(stream)
self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream))
from cuda import nvrtc # noqa: PLC0415
result, major, minor = nvrtc.nvrtcVersion()
assert result == nvrtc.nvrtcResult(0)
if major < 12:
self.cuda_graph_instance = _cuda_assert(
cudart.cudaGraphInstantiate(self.graph, b"", 0)
) # cuda < 12
else:
self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0)) # cuda >= 12
else:
noerror = self.context.execute_async_v3(stream)
if not noerror:
raise ValueError("ERROR: inference failed.")
return self.tensors
class TensorrtEngineBuilder(EngineBuilder):
"""
Helper class to hide the detail of TensorRT Engine from pipeline.
"""
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.TRT,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
self.stream = None
self.shared_device_memory = None
def load_resources(self, image_height, image_width, batch_size):
super().load_resources(image_height, image_width, batch_size)
self.stream = _cuda_assert(cudart.cudaStreamCreate())
def teardown(self):
super().teardown()
if self.shared_device_memory:
cudart.cudaFree(self.shared_device_memory)
cudart.cudaStreamDestroy(self.stream)
del self.stream
def load_engines(
self,
engine_dir,
framework_model_dir,
onnx_dir,
onnx_opset,
opt_batch_size,
opt_image_height,
opt_image_width,
static_batch=False,
static_shape=True,
enable_all_tactics=False,
timing_cache=None,
):
"""
Build and load engines for TensorRT accelerated inference.
Export ONNX models first, if applicable.
Args:
engine_dir (str):
Directory to write the TensorRT engines.
framework_model_dir (str):
Directory to write the framework model ckpt.
onnx_dir (str):
Directory to write the ONNX models.
onnx_opset (int):
ONNX opset version to export the models.
opt_batch_size (int):
Batch size to optimize for during engine building.
opt_image_height (int):
Image height to optimize for during engine building. Must be a multiple of 8.
opt_image_width (int):
Image width to optimize for during engine building. Must be a multiple of 8.
static_batch (bool):
Build engine only for specified opt_batch_size.
static_shape (bool):
Build engine only for specified opt_image_height & opt_image_width. Default = True.
enable_all_tactics (bool):
Enable all tactic sources during TensorRT engine builds.
timing_cache (str):
Path to the timing cache to accelerate build or None
"""
# Create directory
for directory in [engine_dir, onnx_dir]:
if not os.path.exists(directory):
print(f"[I] Create directory: {directory}")
pathlib.Path(directory).mkdir(parents=True)
self.load_models(framework_model_dir)
# Load lora only when we need export text encoder or UNet to ONNX.
load_lora = False
if self.pipeline_info.lora_weights:
for model_name, model_obj in self.models.items():
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not os.path.exists(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
load_lora = True
break
# Export models to ONNX
self.disable_torch_spda()
pipe = self.load_pipeline_with_lora() if load_lora else None
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
if not os.path.exists(engine_path):
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(onnx_opt_path):
if not os.path.exists(onnx_path):
print(f"Exporting model: {onnx_path}")
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
with torch.inference_mode(), torch.autocast("cuda"):
inputs = model_obj.get_sample_input(1, opt_image_height, opt_image_width)
torch.onnx.export(
model,
inputs,
onnx_path,
export_params=True,
opset_version=onnx_opset,
do_constant_folding=True,
input_names=model_obj.get_input_names(),
output_names=model_obj.get_output_names(),
dynamic_axes=model_obj.get_dynamic_axes(),
)
del model
torch.cuda.empty_cache()
gc.collect()
else:
print(f"Found cached model: {onnx_path}")
# Optimize onnx
if not os.path.exists(onnx_opt_path):
print(f"Generating optimizing model: {onnx_opt_path}")
model_obj.optimize_trt(onnx_path, onnx_opt_path)
else:
print(f"Found cached optimized model: {onnx_opt_path} ")
self.enable_torch_spda()
# Build TensorRT engines
for model_name, model_obj in self.models.items():
if model_name == "vae" and self.vae_torch_fallback:
continue
profile_id = model_obj.get_profile_id(
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
)
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
engine = TensorrtEngine(engine_path)
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
if not os.path.exists(engine.engine_path):
engine.build(
onnx_opt_path,
fp16=True,
input_profile=model_obj.get_input_profile(
opt_batch_size,
opt_image_height,
opt_image_width,
static_batch,
static_shape,
),
enable_all_tactics=enable_all_tactics,
timing_cache=timing_cache,
update_output_names=None,
)
self.engines[model_name] = engine
# Load TensorRT engines
for model_name in self.models:
if model_name == "vae" and self.vae_torch_fallback:
continue
self.engines[model_name].load()
def max_device_memory(self):
max_device_memory = 0
for engine in self.engines.values():
max_device_memory = max(max_device_memory, engine.engine.device_memory_size)
return max_device_memory
def activate_engines(self, shared_device_memory=None):
if shared_device_memory is None:
max_device_memory = self.max_device_memory()
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
self.shared_device_memory = shared_device_memory
# Load and activate TensorRT engines
for engine in self.engines.values():
engine.activate(reuse_device_memory=self.shared_device_memory)
def run_engine(self, model_name, feed_dict):
return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph)
@@ -0,0 +1,108 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
logger = logging.getLogger(__name__)
class TorchEngineBuilder(EngineBuilder):
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
device="cuda",
use_cuda_graph=False,
):
"""
Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
device (str):
device to run.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
"""
super().__init__(
EngineType.TORCH,
pipeline_info,
max_batch_size=max_batch_size,
device=device,
use_cuda_graph=use_cuda_graph,
)
self.compile_config = {}
if use_cuda_graph:
self.compile_config = {
"clip": {"mode": "reduce-overhead", "dynamic": False},
"clip2": {"mode": "reduce-overhead", "dynamic": False},
"unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
"unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
"vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
}
def build_engines(
self,
framework_model_dir: str,
):
import torch # noqa: PLC0415
self.torch_device = torch.device("cuda", torch.cuda.current_device())
self.load_models(framework_model_dir)
pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
built_engines = {}
for model_name, model_obj in self.models.items():
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
model = model.to(device=self.torch_device, dtype=torch.float32)
else:
model = model.to(device=self.torch_device, dtype=torch.float16)
if model_name in self.compile_config:
compile_config = self.compile_config[model_name]
if model_name in ["unet", "unetxl"]:
model.to(memory_format=torch.channels_last)
engine = torch.compile(model, **compile_config)
built_engines[model_name] = engine
else: # eager mode
built_engines[model_name] = model
self.engines = built_engines
def run_engine(self, model_name, feed_dict):
if model_name in ["unet", "unetxl"]:
if "controlnet_images" in feed_dict:
return {"latent": self.engines[model_name](**feed_dict)}
if model_name == "unetxl":
added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
return {
"latent": self.engines[model_name](
feed_dict["sample"],
feed_dict["timestep"],
feed_dict["encoder_hidden_states"],
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
}
return {
"latent": self.engines[model_name](
feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
)[0]
}
if model_name in ["vae_encoder"]:
return {"latent": self.engines[model_name](feed_dict["images"])}
raise RuntimeError(f"Shall not reach here: {model_name}")
@@ -0,0 +1,584 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
#
# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
#
# Before running this script, follow README.md to setup python environment and convert stable diffusion checkpoint
# to float32 onnx models.
#
# For example, the float32 ONNX pipeline is saved to ./sd-v1-5 directory, you can optimize and convert it to float16
# like the following:
# python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
#
# Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support
# for the fused operators. The users could disable the operator fusion manually to workaround.
import argparse
import logging
import os
import shutil
import tempfile
from pathlib import Path
import coloredlogs
import onnx
from fusion_options import FusionOptions
from onnx_model_clip import ClipOnnxModel
from onnx_model_mmdit import MmditOnnxModel
from onnx_model_t5 import T5OnnxModel
from onnx_model_unet import UnetOnnxModel
from onnx_model_vae import VaeOnnxModel
from optimizer import optimize_by_onnxruntime, optimize_model
from packaging import version
import onnxruntime
logger = logging.getLogger(__name__)
def has_external_data(onnx_model_path):
original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
for initializer in original_model.graph.initializer:
if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
return True
return False
def is_sd_3(source_dir: Path):
return (source_dir / "text_encoder_3").exists()
def is_sdxl(source_dir: Path):
return (
(source_dir / "text_encoder_2").exists()
and not (source_dir / "text_encoder_3").exists()
and not (source_dir / "transformer").exists()
)
def is_flux(source_dir: Path):
return (
(source_dir / "text_encoder_2").exists()
and not (source_dir / "text_encoder_3").exists()
and (source_dir / "transformer").exists()
)
def _classify_pipeline_type(source_dir: Path):
# May also check _class_name in model_index.json like `StableDiffusion3Pipeline` or `FluxPipeline` etc to classify.
if is_sd_3(source_dir):
return "sd3"
if is_flux(source_dir):
return "flux"
if is_sdxl(source_dir):
return "sdxl"
# sd 1.x and 2.x
return "sd"
def _get_model_list(pipeline_type: str):
if pipeline_type == "sd3":
return ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"]
if pipeline_type == "flux":
return ["text_encoder", "text_encoder_2", "transformer", "vae_encoder", "vae_decoder"]
if pipeline_type == "sdxl":
return ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"]
assert pipeline_type == "sd"
return ["text_encoder", "unet", "vae_encoder", "vae_decoder"]
def _optimize_sd_pipeline(
source_dir: Path,
target_dir: Path,
pipeline_type: str,
model_list: list[str],
use_external_data_format: bool | None,
float16: bool,
bfloat16: bool,
force_fp32_ops: list[str],
enable_runtime_optimization: bool,
args,
):
"""Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16.
Args:
source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
model_list (List[str]): list of directory names with onnx model.
use_external_data_format (Optional[bool]): use external data format.
float16 (bool): use half precision
bfloat16 (bool): use bfloat16 as fallback if float16 is also provided.
force_fp32_ops(List[str]): operators that are forced to run in float32.
enable_runtime_optimization(bool): run graph optimization using Onnx Runtime.
Raises:
RuntimeError: input onnx model does not exist
RuntimeError: output onnx model path existed
"""
is_flux_pipeline = pipeline_type == "flux"
model_type_mapping = {
"transformer": "mmdit",
"unet": "unet",
"vae_encoder": "vae",
"vae_decoder": "vae",
"text_encoder": "clip",
"text_encoder_2": "t5" if is_flux_pipeline else "clip",
"text_encoder_3": "t5", # t5-v1_1-xxl is used in SD 3.x text_encoder_3 and Flux text_encoder_2.
"safety_checker": "unet",
}
model_type_class_mapping = {
"unet": UnetOnnxModel,
"vae": VaeOnnxModel,
"clip": ClipOnnxModel,
"t5": T5OnnxModel,
"mmdit": MmditOnnxModel,
}
force_fp32_operators = {
"unet": [],
"vae_encoder": [],
"vae_decoder": [],
"text_encoder": [],
"text_encoder_2": [],
"safety_checker": [],
"text_encoder_3": [],
"transformer": [],
}
# The node block list is generated by running the fp32 model and get statistics of node inputs and outputs.
# Nodes with any input or output of float or double data type, but value ouf of range of float16 are candidates.
# python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp32_opt
# export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1
# export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1
# export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1
# python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >stdout.txt 2>stderr.txt
# Warning: The node name might change in different export settings. See benchmark_flux.sh for the settings.
flux_node_block_list = {
"text_encoder_2": [
"/encoder/block.10/layer.1/DenseReluDense/wo/MatMul",
"SkipLayerNorm_20",
"SkipLayerNorm_21",
"SkipLayerNorm_22",
"SkipLayerNorm_23",
"SkipLayerNorm_24",
"SkipLayerNorm_25",
"SkipLayerNorm_26",
"SkipLayerNorm_27",
"SkipLayerNorm_28",
"SkipLayerNorm_29",
"SkipLayerNorm_30",
"SkipLayerNorm_31",
"SkipLayerNorm_32",
"SkipLayerNorm_33",
"SkipLayerNorm_34",
"SkipLayerNorm_35",
"SkipLayerNorm_36",
"SkipLayerNorm_37",
"SkipLayerNorm_38",
"SkipLayerNorm_39",
"SkipLayerNorm_40",
"SkipLayerNorm_41",
"SkipLayerNorm_42",
"SkipLayerNorm_43",
"SkipLayerNorm_44",
"SkipLayerNorm_45",
"/encoder/block.23/layer.1/DenseReluDense/wo/MatMul",
"SkipLayerNorm_46",
],
"vae_decoder": [
"/decoder/mid_block/attentions.0/MatMul",
"/decoder/mid_block/attentions.0/Softmax",
],
"transformer": [
"/transformer_blocks.18/Mul_5",
"/transformer_blocks.18/Add_7",
"/Concat_1",
"LayerNorm_76",
"/single_transformer_blocks.0/Add",
"LayerNorm_77",
"/single_transformer_blocks.1/Add",
"LayerNorm_78",
"/single_transformer_blocks.2/Add",
"LayerNorm_79",
"/single_transformer_blocks.3/Add",
"LayerNorm_80",
"/single_transformer_blocks.4/Add",
"LayerNorm_81",
"/single_transformer_blocks.5/Add",
"LayerNorm_82",
"/single_transformer_blocks.6/Add",
"LayerNorm_83",
"/single_transformer_blocks.7/Add",
"LayerNorm_84",
"/single_transformer_blocks.8/Add",
"LayerNorm_85",
"/single_transformer_blocks.9/Add",
"LayerNorm_86",
"/single_transformer_blocks.10/Add",
"LayerNorm_87",
"/single_transformer_blocks.11/Add",
"LayerNorm_88",
"/single_transformer_blocks.12/Add",
"LayerNorm_89",
"/single_transformer_blocks.13/Add",
"LayerNorm_90",
"/single_transformer_blocks.14/Add",
"LayerNorm_91",
"/single_transformer_blocks.15/Add",
"LayerNorm_92",
"/single_transformer_blocks.16/Add",
"LayerNorm_93",
"/single_transformer_blocks.17/Add",
"LayerNorm_94",
"/single_transformer_blocks.18/Add",
"LayerNorm_95",
"/single_transformer_blocks.19/Add",
"LayerNorm_96",
"/single_transformer_blocks.20/Add",
"LayerNorm_97",
"/single_transformer_blocks.21/Add",
"LayerNorm_98",
"/single_transformer_blocks.22/Add",
"LayerNorm_99",
"/single_transformer_blocks.23/Add",
"LayerNorm_100",
"/single_transformer_blocks.24/Add",
"LayerNorm_101",
"/single_transformer_blocks.25/Add",
"LayerNorm_102",
"/single_transformer_blocks.26/Add",
"LayerNorm_103",
"/single_transformer_blocks.27/Add",
"LayerNorm_104",
"/single_transformer_blocks.28/Add",
"LayerNorm_105",
"/single_transformer_blocks.29/Add",
"LayerNorm_106",
"/single_transformer_blocks.30/Add",
"LayerNorm_107",
"/single_transformer_blocks.31/Add",
"LayerNorm_108",
"/single_transformer_blocks.32/Add",
"LayerNorm_109",
"/single_transformer_blocks.33/Add",
"LayerNorm_110",
"/single_transformer_blocks.34/Add",
"LayerNorm_111",
"/single_transformer_blocks.35/Add",
"LayerNorm_112",
"/single_transformer_blocks.36/Add",
"LayerNorm_113",
"/single_transformer_blocks.37/Add",
"/Shape",
"/Slice",
],
}
sd3_node_block_list = {"text_encoder_3": flux_node_block_list["text_encoder_2"]}
if force_fp32_ops:
for fp32_operator in force_fp32_ops:
parts = fp32_operator.split(":")
if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
force_fp32_operators[parts[0]].append(parts[1])
else:
raise ValueError(
f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}"
)
op_counters = {}
for name, model_type in model_type_mapping.items():
onnx_model_path = source_dir / name / "model.onnx"
if not os.path.exists(onnx_model_path):
if name != "safety_checker" and name in model_list:
logger.warning("input onnx model does not exist: %s", onnx_model_path)
# some model are optional so we do not raise error here.
continue
# Prepare output directory
optimized_model_path = target_dir / name / "model.onnx"
if os.path.exists(optimized_model_path):
if not args.overwrite:
logger.warning("Skipped optimization since the target file existed: %s", optimized_model_path)
continue
output_dir = optimized_model_path.parent
output_dir.mkdir(parents=True, exist_ok=True)
if use_external_data_format is None:
use_external_data_format = has_external_data(onnx_model_path)
# Graph fusion before fp16 conversion, otherwise they cannot be fused later.
logger.info("Optimize %s ...", onnx_model_path)
args.model_type = model_type
fusion_options = FusionOptions.parse(args)
if model_type in ["unet"]:
# Some optimizations are not available in v1.14 or older version: packed QKV and BiasAdd
has_all_optimizations = version.parse(onnxruntime.__version__) >= version.parse("1.15.0")
fusion_options.enable_packed_kv = float16 and fusion_options.enable_packed_kv
fusion_options.enable_packed_qkv = float16 and has_all_optimizations and fusion_options.enable_packed_qkv
fusion_options.enable_bias_add = has_all_optimizations and fusion_options.enable_bias_add
m = optimize_model(
str(onnx_model_path),
model_type=model_type,
num_heads=0, # will be deduced from graph
hidden_size=0, # will be deduced from graph
opt_level=0,
optimization_options=fusion_options,
use_gpu=True,
provider=args.provider,
)
if float16:
model_node_block_list = (
flux_node_block_list if is_flux_pipeline else sd3_node_block_list if pipeline_type == "sd3" else {}
)
if name in model_node_block_list:
# Opset 12 does not support bfloat16.
# By default, optimum exports T5 model with opset 12. So we need to check the opset version.
use_bfloat16 = bfloat16
if use_bfloat16:
for opset in m.model.opset_import:
if opset.domain in ["", "ai.onnx"] and opset.version < 13:
logger.warning(
"onnx model requires opset 13 or higher to use bfloat16. Fall back to float32."
)
use_bfloat16 = False
m.convert_float_to_float16(
keep_io_types=False,
node_block_list=model_node_block_list[name],
use_bfloat16_as_blocked_nodes_dtype=use_bfloat16,
)
# For SD-XL, use FP16 in VAE decoder will cause NaN and black image so we keep it in FP32.
elif pipeline_type in ["sdxl"] and name in ["vae_decoder"]:
logger.info("Skip converting %s to float16 to avoid NaN", name)
else:
logger.info("Convert %s to float16 ...", name)
m.convert_float_to_float16(
keep_io_types=False,
op_block_list=force_fp32_operators[name],
)
if enable_runtime_optimization:
# Use this step to see the final graph that executed by Onnx Runtime.
with tempfile.TemporaryDirectory() as tmp_dir:
# Save to a temporary file so that we can load it with Onnx Runtime.
logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
tmp_model_path = Path(tmp_dir) / "model.onnx"
m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
optimize_by_onnxruntime(
str(tmp_model_path),
use_gpu=True,
provider=args.provider,
optimized_model_path=str(ort_optimized_model_path),
save_as_external_data=use_external_data_format,
)
model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
m = model_type_class_mapping[model_type](model)
m.get_operator_statistics()
op_counters[name] = m.get_fused_operator_statistics()
m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
logger.info("%s is optimized", name)
logger.info("*" * 20)
return op_counters
def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: list[str]):
"""Copy extra directory that does not have onnx model
Args:
source_dir (Path): source directory
target_dir (Path): target directory
model_list (List[str]): list of directory names with onnx model.
Raises:
RuntimeError: source path does not exist
"""
extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"]
for name in extra_dirs:
source_path = source_dir / name
if not os.path.exists(source_path):
continue
target_path = target_dir / name
if target_path.exists():
shutil.rmtree(target_path)
shutil.copytree(source_path, target_path)
logger.info("%s => %s", source_path, target_path)
extra_files = ["model_index.json"]
for name in extra_files:
source_path = source_dir / name
if not os.path.exists(source_path):
raise RuntimeError(f"source path does not exist: {source_path}")
target_path = target_dir / name
shutil.copyfile(source_path, target_path)
logger.info("%s => %s", source_path, target_path)
# Some directory are optional
for onnx_model_dir in model_list:
source_path = source_dir / onnx_model_dir / "config.json"
target_path = target_dir / onnx_model_dir / "config.json"
if source_path.exists():
target_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(source_path, target_path)
logger.info("%s => %s", source_path, target_path)
def optimize_stable_diffusion_pipeline(
input_dir: str,
output_dir: str,
overwrite: bool,
use_external_data_format: bool | None,
float16: bool,
enable_runtime_optimization: bool,
args,
):
if os.path.exists(output_dir):
if overwrite:
shutil.rmtree(output_dir, ignore_errors=True)
source_dir = Path(input_dir)
target_dir = Path(output_dir)
target_dir.mkdir(parents=True, exist_ok=True)
pipeline_type = _classify_pipeline_type(source_dir)
model_list = _get_model_list(pipeline_type)
_copy_extra_directory(source_dir, target_dir, model_list)
return _optimize_sd_pipeline(
source_dir,
target_dir,
pipeline_type,
model_list,
use_external_data_format,
float16,
args.bfloat16,
args.force_fp32_ops,
enable_runtime_optimization,
args,
)
def parse_arguments(argv: list[str] | None = None):
"""Parse arguments
Returns:
Namespace: arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input",
required=True,
type=str,
help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
)
parser.add_argument(
"-o",
"--output",
required=True,
type=str,
help="Root of output directory of stable diffusion onnx pipeline with optimized models.",
)
parser.add_argument(
"--float16",
required=False,
action="store_true",
help="Output models of float16, except some nodes falls back to float32 or bfloat16 to avoid overflow.",
)
parser.set_defaults(float16=False)
parser.add_argument(
"--bfloat16",
required=False,
action="store_true",
help="Allow bfloat16 as fallback if --float16 is also provided.",
)
parser.set_defaults(bfloat16=False)
parser.add_argument(
"--force_fp32_ops",
required=False,
nargs="+",
type=str,
help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!",
)
parser.add_argument(
"--inspect",
required=False,
action="store_true",
help="Save the optimized graph from Onnx Runtime. "
"This option has no impact on inference performance except it might reduce session creation time.",
)
parser.set_defaults(inspect=False)
parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite exists files.",
)
parser.set_defaults(overwrite=False)
parser.add_argument(
"-e",
"--use_external_data_format",
required=False,
action="store_true",
help="Onnx model larger than 2GB need to use external data format. "
"If specified, save each onnx model to two files: one for onnx graph, another for weights. "
"If not specified, use same format as original model by default. ",
)
parser.set_defaults(use_external_data_format=None)
parser.add_argument(
"--provider",
required=False,
type=str,
default=None,
help="Execution provider to use.",
)
FusionOptions.add_arguments(parser)
args = parser.parse_args(argv)
return args
def main(argv: list[str] | None = None):
args = parse_arguments(argv)
logger.info("Arguments: %s", str(args))
# Return op counters for testing purpose.
return optimize_stable_diffusion_pipeline(
args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
)
if __name__ == "__main__":
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
main()
@@ -0,0 +1,136 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
ONNX Model Optimizer for Stable Diffusion
"""
import gc
import logging
import os
import shutil
import tempfile
from pathlib import Path
import onnx
from packaging import version
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
logger = logging.getLogger(__name__)
class OrtStableDiffusionOptimizer:
def __init__(self, model_type: str):
assert model_type in ["vae", "unet", "clip"]
self.model_type = model_type
self.model_type_class_mapping = {
"unet": UnetOnnxModel,
"vae": VaeOnnxModel,
"clip": ClipOnnxModel,
}
def _optimize_by_ort(self, onnx_model, use_external_data_format, tmp_dir):
# Save to a temporary file so that we can load it with Onnx Runtime.
logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
tmp_model_path = Path(tmp_dir) / "model.onnx"
onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
del onnx_model
gc.collect()
ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
optimize_by_onnxruntime(
str(tmp_model_path),
use_gpu=True,
optimized_model_path=str(ort_optimized_model_path),
save_as_external_data=use_external_data_format,
external_data_filename="optimized.onnx_data",
)
model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
return self.model_type_class_mapping[self.model_type](model)
def optimize_by_ort(self, onnx_model, use_external_data_format=False, tmp_dir=None):
# Use this step to see the final graph that executed by Onnx Runtime.
if tmp_dir is None:
with tempfile.TemporaryDirectory() as temp_dir:
return self._optimize_by_ort(onnx_model, use_external_data_format, temp_dir)
else:
os.makedirs(tmp_dir, exist_ok=True)
model = self._optimize_by_ort(onnx_model, use_external_data_format, tmp_dir)
shutil.rmtree(tmp_dir)
return model
def optimize(
self,
input_fp32_onnx_path,
optimized_onnx_path,
float16=True,
keep_io_types=False,
fp32_op_list=None,
keep_outputs=None,
optimize_by_ort=True,
optimize_by_fusion=True,
final_target_float16=True,
tmp_dir=None,
):
"""Optimize onnx model using ONNX Runtime transformers optimizer"""
logger.info(f"Optimize {input_fp32_onnx_path}...")
if optimize_by_fusion:
fusion_options = FusionOptions(self.model_type)
# It is allowed float16=False and final_target_float16=True, for using fp32 as intermediate optimization step.
# For rare fp32 use case, we can disable packed kv/qkv since there is no fp32 TRT fused attention kernel.
if self.model_type in ["unet"] and not final_target_float16:
fusion_options.enable_packed_kv = False
fusion_options.enable_packed_qkv = False
m = optimize_model(
input_fp32_onnx_path,
model_type=self.model_type,
num_heads=0, # will be deduced from graph
hidden_size=0, # will be deduced from graph
opt_level=0,
optimization_options=fusion_options,
use_gpu=True,
)
else:
model = onnx.load_model(input_fp32_onnx_path, load_external_data=True)
m = self.model_type_class_mapping[self.model_type](model)
if keep_outputs:
m.prune_graph(outputs=keep_outputs)
model_size = m.model.ByteSize()
# model size might be negative (overflow?) in Windows.
use_external_data_format = model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF
# Note that ORT < 1.16 could not save model larger than 2GB.
# This step is is optional since it has no impact on inference latency.
# The optimized model is not portable. It could only run in the same execution provider (CUDA EP in this case).
# When the model has been optimized by onnxruntime, we can disable optimization in SessionOption
# to save session creation time. Another benefit is to inspect the final graph for developing purpose.
from onnxruntime import __version__ as ort_version # noqa: PLC0415
if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format):
m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format, tmp_dir=tmp_dir)
if float16:
logger.info("Convert to float16 ...")
m.convert_float_to_float16(
keep_io_types=keep_io_types,
op_block_list=fp32_op_list,
)
m.get_operator_statistics()
m.get_fused_operator_statistics()
m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)
@@ -0,0 +1,831 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
import os
import pathlib
import random
import time
from typing import Any
import numpy as np
import nvtx
import torch
from cuda import cudart
from diffusion_models import PipelineInfo, get_tokenizer
from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler
from engine_builder import EngineType
from engine_builder_ort_cuda import OrtCudaEngineBuilder
from engine_builder_ort_trt import OrtTensorrtEngineBuilder
from engine_builder_tensorrt import TensorrtEngineBuilder
from engine_builder_torch import TorchEngineBuilder
from PIL import Image
class StableDiffusionPipeline:
"""
Stable Diffusion pipeline using TensorRT.
"""
def __init__(
self,
pipeline_info: PipelineInfo,
max_batch_size=16,
scheduler="DDIM",
device="cuda",
output_dir=".",
verbose=False,
nvtx_profile=False,
use_cuda_graph=False,
framework_model_dir="pytorch_model",
engine_type: EngineType = EngineType.ORT_CUDA,
):
"""
Initializes the Diffusion pipeline.
Args:
pipeline_info (PipelineInfo):
Version and Type of pipeline.
max_batch_size (int):
Maximum batch size for dynamic batch engine.
scheduler (str):
The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM].
device (str):
PyTorch device to run inference. Default: 'cuda'
output_dir (str):
Output directory for log files and image artifacts
verbose (bool):
Enable verbose logging.
nvtx_profile (bool):
Insert NVTX profiling markers.
use_cuda_graph (bool):
Use CUDA graph to capture engine execution and then launch inference
framework_model_dir (str):
cache directory for framework checkpoints
engine_type (EngineType)
backend engine type like ORT_TRT or TRT
"""
self.pipeline_info = pipeline_info
self.version = pipeline_info.version
self.vae_scaling_factor = pipeline_info.vae_scaling_factor()
self.max_batch_size = max_batch_size
self.framework_model_dir = framework_model_dir
self.output_dir = output_dir
for directory in [self.framework_model_dir, self.output_dir]:
if not os.path.exists(directory):
print(f"[I] Create directory: {directory}")
pathlib.Path(directory).mkdir(parents=True)
self.device = device
self.torch_device = torch.device(device, torch.cuda.current_device())
self.verbose = verbose
self.nvtx_profile = nvtx_profile
self.use_cuda_graph = use_cuda_graph
self.tokenizer = None
self.tokenizer2 = None
self.generator = torch.Generator(device="cuda")
self.actual_steps = None
self.current_scheduler = None
self.set_scheduler(scheduler)
# backend engine
self.engine_type = engine_type
if engine_type == EngineType.TRT:
self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
elif engine_type == EngineType.ORT_TRT:
self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
elif engine_type == EngineType.ORT_CUDA:
self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
elif engine_type == EngineType.TORCH:
self.backend = TorchEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
else:
raise RuntimeError(f"Backend engine type {engine_type.name} is not supported")
# Load text tokenizer
if not self.pipeline_info.is_xl_refiner():
self.tokenizer = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer")
if self.pipeline_info.is_xl():
self.tokenizer2 = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer_2")
self.control_image_processor = None
if self.pipeline_info.is_xl() and self.pipeline_info.controlnet:
from diffusers.image_processor import VaeImageProcessor # noqa: PLC0415
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=8, do_convert_rgb=True, do_normalize=False
)
# Create CUDA events
self.events = {}
for stage in ["clip", "denoise", "vae", "vae_encoder", "pil"]:
for marker in ["start", "stop"]:
self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
self.markers = {}
def is_backend_tensorrt(self):
return self.engine_type == EngineType.TRT
def set_scheduler(self, scheduler: str):
if scheduler == self.current_scheduler:
return
# Scheduler options
sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012}
if self.version in ("2.0", "2.1"):
sched_opts["prediction_type"] = "v_prediction"
else:
sched_opts["prediction_type"] = "epsilon"
if scheduler == "DDIM":
self.scheduler = DDIMScheduler(device=self.device, **sched_opts)
elif scheduler == "EulerA":
self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts)
elif scheduler == "UniPC":
self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts)
elif scheduler == "LCM":
self.scheduler = LCMScheduler(device=self.device, **sched_opts)
else:
raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM")
self.current_scheduler = scheduler
self.denoising_steps = None
def set_denoising_steps(self, denoising_steps: int):
if not (self.denoising_steps == denoising_steps and isinstance(self.scheduler, DDIMScheduler)):
self.scheduler.set_timesteps(denoising_steps)
self.scheduler.configure()
self.denoising_steps = denoising_steps
def load_resources(self, image_height, image_width, batch_size):
# If engine is built with static input shape, call this only once after engine build.
# Otherwise, it need be called before every inference run.
self.backend.load_resources(image_height, image_width, batch_size)
def set_random_seed(self, seed):
if isinstance(seed, int):
self.generator.manual_seed(seed)
else:
self.generator.seed()
def get_current_seed(self):
return self.generator.initial_seed()
def teardown(self):
for e in self.events.values():
cudart.cudaEventDestroy(e)
if self.backend:
self.backend.teardown()
def run_engine(self, model_name, feed_dict):
return self.backend.run_engine(model_name, feed_dict)
def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width):
latents_dtype = torch.float16
latents_shape = (batch_size, unet_channels, latent_height, latent_width)
latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator)
# Scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def initialize_timesteps(self, timesteps, strength):
"""Initialize timesteps for refiner."""
self.scheduler.set_timesteps(timesteps)
offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
init_timestep = int(timesteps * strength) + offset
init_timestep = min(init_timestep, timesteps)
t_start = max(timesteps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
return timesteps, t_start
def initialize_refiner(self, batch_size, image, strength):
"""Add noise to a reference image."""
# Initialize timesteps
timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength)
latent_timestep = timesteps[:1].repeat(batch_size)
# Pre-process input image
image = self.preprocess_images(batch_size, (image,))[0]
# VAE encode init image
if image.shape[1] == 4:
init_latents = image
else:
init_latents = self.encode_image(image)
# Add noise to latents using timesteps
noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float16, generator=self.generator)
latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep)
return timesteps, t_start, latents
def _get_add_time_ids(
self,
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype,
requires_aesthetics_score,
):
if requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
else:
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
return add_time_ids, add_neg_time_ids
def start_profile(self, name, color="blue"):
if self.nvtx_profile:
self.markers[name] = nvtx.start_range(message=name, color=color)
event_name = name + "-start"
if event_name in self.events:
cudart.cudaEventRecord(self.events[event_name], 0)
def stop_profile(self, name):
event_name = name + "-stop"
if event_name in self.events:
cudart.cudaEventRecord(self.events[event_name], 0)
if self.nvtx_profile:
nvtx.end_range(self.markers[name])
def preprocess_images(self, batch_size, images=()):
self.start_profile("preprocess", color="pink")
init_images = []
for i in images:
image = i.to(self.device)
if image.shape[0] != batch_size:
image = image.repeat(batch_size, 1, 1, 1)
init_images.append(image)
self.stop_profile("preprocess")
return tuple(init_images)
def preprocess_controlnet_images(
self, batch_size, images=None, do_classifier_free_guidance=True, height=1024, width=1024
):
"""
Process a list of PIL.Image.Image as control images, and return a torch tensor.
"""
if images is None:
return None
self.start_profile("preprocess", color="pink")
if not self.pipeline_info.is_xl():
images = [
torch.from_numpy(
(np.array(image.convert("RGB")).astype(np.float32) / 255.0)[..., None].transpose(3, 2, 0, 1)
)
.to(device=self.device, dtype=torch.float16)
.repeat_interleave(batch_size, dim=0)
for image in images
]
else:
images = [
self.control_image_processor.preprocess(image, height=height, width=width)
.to(device=self.device, dtype=torch.float16)
.repeat_interleave(batch_size, dim=0)
for image in images
]
if do_classifier_free_guidance:
images = [torch.cat([i] * 2) for i in images]
images = torch.cat([image[None, ...] for image in images], dim=0)
self.stop_profile("preprocess")
return images
def encode_prompt(
self,
prompt,
negative_prompt,
encoder="clip",
tokenizer=None,
pooled_outputs=False,
output_hidden_states=False,
force_zeros_for_empty_prompt=False,
do_classifier_free_guidance=True,
dtype=torch.float16,
):
if tokenizer is None:
tokenizer = self.tokenizer
self.start_profile("clip", color="green")
def tokenize(prompt, output_hidden_states):
text_input_ids = (
tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
.input_ids.type(torch.int32)
.to(self.device)
)
hidden_states = None
if self.engine_type == EngineType.TORCH:
outputs = self.backend.engines[encoder](text_input_ids)
text_embeddings = outputs[0]
if output_hidden_states:
hidden_states = outputs["last_hidden_state"]
else:
outputs = self.run_engine(encoder, {"input_ids": text_input_ids})
text_embeddings = outputs["text_embeddings"]
if output_hidden_states:
hidden_states = outputs["hidden_states"]
return text_embeddings, hidden_states
# Tokenize prompt
text_embeddings, hidden_states = tokenize(prompt, output_hidden_states)
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
text_embeddings = text_embeddings.clone()
if hidden_states is not None:
hidden_states = hidden_states.clone()
# Note: negative prompt embedding is not needed for SD XL when guidance <= 1
if do_classifier_free_guidance:
# For SD XL base, handle force_zeros_for_empty_prompt
is_empty_negative_prompt = all(not i for i in negative_prompt)
if force_zeros_for_empty_prompt and is_empty_negative_prompt:
uncond_embeddings = torch.zeros_like(text_embeddings)
if output_hidden_states:
uncond_hidden_states = torch.zeros_like(hidden_states)
else:
# Tokenize negative prompt
uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states)
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
if output_hidden_states:
hidden_states = torch.cat([uncond_hidden_states, hidden_states])
self.stop_profile("clip")
if pooled_outputs:
# For text encoder in sdxl base
return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)
if output_hidden_states:
# For text encoder 2 in sdxl base or refiner
return hidden_states.to(dtype=dtype)
# For text encoder in sd 1.5
return text_embeddings.to(dtype=dtype)
def denoise_latent(
self,
latents,
text_embeddings,
denoiser="unet",
timesteps=None,
step_offset=0,
guidance=7.5,
add_kwargs=None,
):
do_classifier_free_guidance = guidance > 1.0
self.start_profile("denoise", color="blue")
if not isinstance(timesteps, torch.Tensor):
timesteps = self.scheduler.timesteps
for step_index, timestep in enumerate(timesteps):
# Expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, step_offset + step_index, timestep
)
# Predict the noise residual
if self.nvtx_profile:
nvtx_unet = nvtx.start_range(message="unet", color="blue")
params = {
"sample": latent_model_input,
"timestep": timestep.to(latents.dtype),
"encoder_hidden_states": text_embeddings,
}
if add_kwargs:
params.update(add_kwargs)
noise_pred = self.run_engine(denoiser, params)["latent"]
if self.nvtx_profile:
nvtx.end_range(nvtx_unet)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
if type(self.scheduler) is UniPCMultistepScheduler:
latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
elif type(self.scheduler) is LCMScheduler:
latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0]
else:
latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep)
# The actual number of steps. It might be different from denoising_steps.
self.actual_steps = len(timesteps)
self.stop_profile("denoise")
return latents
def encode_image(self, image):
self.start_profile("vae_encoder", color="red")
init_latents = self.run_engine("vae_encoder", {"images": image})["latent"]
init_latents = self.vae_scaling_factor * init_latents
self.stop_profile("vae_encoder")
return init_latents
def decode_latent(self, latents):
self.start_profile("vae", color="red")
images = self.backend.vae_decode(latents)
self.stop_profile("vae")
return images
def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> dict[str, Any]:
throughput = batch_size / (toc - tic)
latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1]
latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1]
latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1]
latency_vae_encoder = (
cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1]
if vae_enc
else None
)
latency_pil = cudart.cudaEventElapsedTime(self.events["pil-start"], self.events["pil-stop"])[1] if pil else None
latency = (toc - tic) * 1000.0
print("|----------------|--------------|")
print("| {:^14} | {:^12} |".format("Module", "Latency"))
print("|----------------|--------------|")
if vae_enc:
print("| {:^14} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder))
print("| {:^14} | {:>9.2f} ms |".format("CLIP", latency_clip))
print(
"| {:^14} | {:>9.2f} ms |".format(
"UNet" + ("+CNet" if self.pipeline_info.controlnet else "") + " x " + str(self.actual_steps),
latency_unet,
)
)
print("| {:^14} | {:>9.2f} ms |".format("VAE-Dec", latency_vae))
pipeline = "Refiner" if self.pipeline_info.is_xl_refiner() else "Pipeline"
if pil:
print("| {:^14} | {:>9.2f} ms |".format("PIL", latency_pil))
print("|----------------|--------------|")
print(f"| {pipeline:^14} | {latency:>9.2f} ms |")
print("|----------------|--------------|")
print(f"Throughput: {throughput:.2f} image/s")
perf_data = {
"latency_clip": latency_clip,
"latency_unet": latency_unet,
"latency_vae": latency_vae,
"latency_pil": latency_pil,
"latency": latency,
"throughput": throughput,
}
if vae_enc:
perf_data["latency_vae_encoder"] = latency_vae_encoder
return perf_data
@staticmethod
def pt_to_pil(images):
images = (
((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
)
return [Image.fromarray(images[i]) for i in range(images.shape[0])]
@staticmethod
def pt_to_numpy(images: torch.FloatTensor):
"""
Convert a PyTorch tensor to a NumPy image.
"""
return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy()
def metadata(self) -> dict[str, Any]:
data = {
"actual_steps": self.actual_steps,
"seed": self.get_current_seed(),
"name": self.pipeline_info.name(),
"custom_vae": self.pipeline_info.custom_fp16_vae(),
"custom_unet": self.pipeline_info.custom_unet(),
}
if self.engine_type == EngineType.ORT_CUDA:
for engine_name, engine in self.backend.engines.items():
data.update(engine.metadata(engine_name))
return data
def save_images(self, images: list, prompt: list[str], negative_prompt: list[str], metadata: dict[str, Any]):
session_id = str(random.randint(1000, 9999))
for i, image in enumerate(images):
seed = str(self.get_current_seed())
prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20]
parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)]
image_path = os.path.join(self.output_dir, "-".join(parts) + ".png")
print(f"Saving image {i + 1} / {len(images)} to: {image_path}")
from PIL import PngImagePlugin # noqa: PLC0415
info = PngImagePlugin.PngInfo()
for k, v in metadata.items():
info.add_text(k, str(v))
info.add_text("prompt", prompt[i])
info.add_text("negative_prompt", negative_prompt[i])
image.save(image_path, "PNG", pnginfo=info)
def _infer(
self,
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=30,
guidance=5.0,
seed=None,
image=None,
strength=0.3,
controlnet_images=None,
controlnet_scales=None,
show_latency=False,
output_type="pil",
):
if show_latency:
torch.cuda.synchronize()
start_time = time.perf_counter()
assert len(prompt) == len(negative_prompt)
batch_size = len(prompt)
self.set_denoising_steps(denoising_steps)
self.set_random_seed(seed)
timesteps = None
step_offset = 0
with torch.inference_mode(), torch.autocast("cuda"):
if image is not None:
timesteps, step_offset, latents = self.initialize_refiner(
batch_size=batch_size,
image=image,
strength=strength,
)
else:
# Pre-initialize latents
latents = self.initialize_latents(
batch_size=batch_size,
unet_channels=4,
latent_height=(image_height // 8),
latent_width=(image_width // 8),
)
do_classifier_free_guidance = guidance > 1.0
if not self.pipeline_info.is_xl():
denoiser = "unet"
text_embeddings = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
dtype=latents.dtype,
)
add_kwargs = {}
else:
denoiser = "unetxl"
# Time embeddings
original_size = (image_height, image_width)
crops_coords_top_left = (0, 0)
target_size = (image_height, image_width)
aesthetic_score = 6.0
negative_aesthetic_score = 2.5
add_time_ids, add_negative_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype=latents.dtype,
requires_aesthetics_score=self.pipeline_info.is_xl_refiner(),
)
if do_classifier_free_guidance:
add_time_ids = torch.cat([add_negative_time_ids, add_time_ids], dim=0)
add_time_ids = add_time_ids.to(device=self.device).repeat(batch_size, 1)
if self.pipeline_info.is_xl_refiner():
# CLIP text encoder 2
text_embeddings, pooled_embeddings2 = self.encode_prompt(
prompt,
negative_prompt,
encoder="clip2",
tokenizer=self.tokenizer2,
pooled_outputs=True,
output_hidden_states=True,
dtype=latents.dtype,
)
add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
else: # XL Base
# CLIP text encoder
text_embeddings = self.encode_prompt(
prompt,
negative_prompt,
encoder="clip",
tokenizer=self.tokenizer,
output_hidden_states=True,
force_zeros_for_empty_prompt=True,
do_classifier_free_guidance=do_classifier_free_guidance,
dtype=latents.dtype,
)
# CLIP text encoder 2
text_embeddings2, pooled_embeddings2 = self.encode_prompt(
prompt,
negative_prompt,
encoder="clip2",
tokenizer=self.tokenizer2,
pooled_outputs=True,
output_hidden_states=True,
force_zeros_for_empty_prompt=True,
do_classifier_free_guidance=do_classifier_free_guidance,
dtype=latents.dtype,
)
# Merged text embeddings
text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1)
add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
if self.pipeline_info.controlnet:
controlnet_images = self.preprocess_controlnet_images(
latents.shape[0],
controlnet_images,
do_classifier_free_guidance=do_classifier_free_guidance,
height=image_height,
width=image_width,
)
add_kwargs.update(
{
"controlnet_images": controlnet_images,
"controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device),
}
)
# UNet denoiser
latents = self.denoise_latent(
latents,
text_embeddings,
timesteps=timesteps,
step_offset=step_offset,
denoiser=denoiser,
guidance=guidance,
add_kwargs=add_kwargs,
)
with torch.inference_mode():
# VAE decode latent
if output_type == "latent":
images = latents
else:
images = self.decode_latent(latents / self.vae_scaling_factor)
if output_type == "pil":
self.start_profile("pil", color="green")
images = self.pt_to_pil(images)
self.stop_profile("pil")
perf_data = None
if show_latency:
torch.cuda.synchronize()
end_time = time.perf_counter()
perf_data = self.print_summary(
start_time, end_time, batch_size, vae_enc=self.pipeline_info.is_xl_refiner(), pil=(output_type == "pil")
)
return images, perf_data
def run(
self,
prompt: list[str],
negative_prompt: list[str],
image_height: int,
image_width: int,
denoising_steps: int = 30,
guidance: float = 5.0,
seed: int | None = None,
image: torch.Tensor | None = None,
strength: float = 0.3,
controlnet_images: torch.Tensor | None = None,
controlnet_scales: torch.Tensor | None = None,
show_latency: bool = False,
output_type: str = "pil",
deterministic: bool = False,
):
"""
Run the diffusion pipeline.
Args:
prompt (List[str]):
The text prompt to guide image generation.
negative_prompt (List[str]):
The prompt not to guide the image generation.
image_height (int):
Height (in pixels) of the image to be generated. Must be a multiple of 8.
image_width (int):
Width (in pixels) of the image to be generated. Must be a multiple of 8.
denoising_steps (int):
Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference.
guidance (float):
Higher guidance scale encourages to generate images that are closely linked to the text prompt.
seed (int):
Seed for the random generator
image (tuple[torch.Tensor]):
Reference image.
strength (float):
Indicates extent to transform the reference image, which is used as a starting point,
and more noise is added the higher the strength.
show_latency (bool):
Whether return latency data.
output_type (str):
It can be "latent", "pt" or "pil".
"""
if deterministic:
torch.use_deterministic_algorithms(True)
if self.is_backend_tensorrt():
import tensorrt as trt # noqa: PLC0415
from trt_utilities import TRT_LOGGER # noqa: PLC0415
with trt.Runtime(TRT_LOGGER):
return self._infer(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=denoising_steps,
guidance=guidance,
seed=seed,
image=image,
strength=strength,
controlnet_images=controlnet_images,
controlnet_scales=controlnet_scales,
show_latency=show_latency,
output_type=output_type,
)
else:
return self._infer(
prompt,
negative_prompt,
image_height,
image_width,
denoising_steps=denoising_steps,
guidance=guidance,
seed=seed,
image=image,
strength=strength,
controlnet_images=controlnet_images,
controlnet_scales=controlnet_scales,
show_latency=show_latency,
output_type=output_type,
)
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
def init_trt_plugins():
# Register TensorRT plugins
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,318 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import copy
import logging
import os
import torch
from benchmark_helper import (
Precision,
create_onnxruntime_session,
prepare_environment,
setup_logger,
)
from onnx.shape_inference import infer_shapes_path
from t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS, T5Helper
from transformers import MT5Config, T5Config
logger = logging.getLogger("")
def parse_arguments():
parser = argparse.ArgumentParser()
pretrained_models = PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS
parser.add_argument(
"-m",
"--model_name_or_path",
required=False,
default=PRETRAINED_T5_MODELS[0],
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
)
parser.add_argument(
"--model_type",
required=False,
type=str,
default="t5",
choices=["t5", "mt5"],
help="Model type: either t5 (default) or mt5",
)
parser.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
parser.add_argument(
"--output",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Output directory",
)
parser.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
parser.set_defaults(optimize_onnx=False)
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument(
"-p",
"--precision",
required=False,
type=str,
default=Precision.FLOAT32.value,
choices=[Precision.FLOAT32.value, Precision.FLOAT16.value],
help="Precision of model to run. fp32 for full precision, fp16 for half precision",
)
parser.add_argument("--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)
parser.add_argument(
"-s",
"--use_decoder_start_token",
required=False,
action="store_true",
help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.",
)
parser.set_defaults(use_decoder_start_token=False)
parser.add_argument(
"-w",
"--overwrite",
required=False,
action="store_true",
help="overwrite existing ONNX model",
)
parser.set_defaults(overwrite=False)
parser.add_argument(
"--disable_auto_mixed_precision",
required=False,
action="store_true",
help="do not use auto mixed precision conversion",
)
parser.set_defaults(disable_auto_mixed_precision=False)
parser.add_argument(
"--force_fp16_io",
required=False,
action="store_true",
help="Force to convert all float inputs and outputs to fp16 when precision is fp16.",
)
parser.set_defaults(force_fp16_io=False)
parser.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
)
parser.set_defaults(use_int64_inputs=False)
parser.add_argument(
"--state_dict_path",
type=str,
default="",
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)
parser.add_argument(
"--encoder_decoder_init",
required=False,
action="store_true",
help="Combine encoder and decoder kv cache initialization into one model. It is legacy format that will be deprecated.",
)
parser.set_defaults(encoder_decoder_init=False)
args = parser.parse_args()
return args
def export_onnx_models(
model_name_or_path: str,
cache_dir: str,
output_dir: str,
use_gpu: bool = False,
use_external_data_format: bool = False,
optimize_onnx: bool = False,
precision: str = Precision.FLOAT32.value,
verbose: bool = False,
use_decoder_start_token: bool = False,
overwrite: bool = False,
disable_auto_mixed_precision: bool = False,
use_int32_inputs: bool = True,
model_type: str = "t5",
state_dict_path: str = "",
encoder_decoder_init: bool = False,
force_fp16_io: bool = False,
shape_infer_before_optimization: bool = False,
):
assert precision in [Precision.FLOAT32.value, Precision.FLOAT16.value], (
f"Invalid precision: {precision}. Use 'fp32' or 'fp16'."
)
device = torch.device("cuda:0" if use_gpu else "cpu")
models = T5Helper.load_model(
model_name_or_path,
cache_dir,
device,
model_type,
state_dict_path,
encoder_decoder_init=encoder_decoder_init,
)
config: T5Config | MT5Config = models["decoder"].config
if (not use_external_data_format) and (config.num_layers > 24):
logger.info("Try use_external_data_format when model size > 2GB")
output_paths = []
for name, model in models.items():
model.to(device)
filename_suffix = "_" + name
onnx_path = T5Helper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix,
new_folder=False,
)
if overwrite or not os.path.exists(onnx_path):
logger.info(f"Exporting ONNX model to {onnx_path}")
# We have to clone model before exporting onnx, otherwise verify_onnx will report large difference.
cloned_model = copy.deepcopy(model).to(device)
T5Helper.export_onnx(
cloned_model,
device,
onnx_path,
verbose,
use_external_data_format,
use_decoder_input_ids=not use_decoder_start_token,
use_int32_inputs=use_int32_inputs,
)
else:
logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
# Optimize ONNX graph.
# The precision shall be compared with string value. It is because the Precision enum loaded from local file
# (like by transformers test in CI pipeline) are not same as Precision enum from package.
if optimize_onnx or precision != Precision.FLOAT32.value:
onnx_shape_path = None
if shape_infer_before_optimization:
onnx_shape_path = T5Helper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix + "_shape",
new_folder=False,
)
infer_shapes_path(onnx_path, onnx_shape_path)
output_path = T5Helper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix + "_" + str(precision),
new_folder=False,
)
if overwrite or not os.path.exists(output_path):
logger.info(f"Optimizing model to {output_path}")
T5Helper.optimize_onnx(
onnx_shape_path or onnx_path,
output_path,
precision == Precision.FLOAT16.value,
config.num_heads,
config.hidden_size,
use_external_data_format,
auto_mixed_precision=not disable_auto_mixed_precision,
use_gpu=use_gpu,
force_fp16_io=force_fp16_io,
)
else:
logger.info(f"Skip optimizing: existed ONNX model {output_path}")
else:
output_path = onnx_path
ort_session = create_onnxruntime_session(
output_path,
use_gpu=use_gpu,
verbose=verbose,
)
if ort_session is None:
break
with torch.no_grad():
max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
# The threshold cannot apply to fp16 model, which need a larger threshold.
if precision == Precision.FLOAT32.value and max_diff > 1e-4:
logger.warning("PyTorch and OnnxRuntime results are NOT close")
output_paths.append(output_path)
return output_paths
def main():
args = parse_arguments()
setup_logger(args.verbose)
logger.info(f"Arguments:{args}")
cache_dir = args.cache_dir
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
prepare_environment(cache_dir, output_dir, args.use_gpu)
if args.precision != Precision.FLOAT32.value:
assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
if args.precision == Precision.FLOAT16.value:
assert args.use_gpu, "fp16 requires --use_gpu"
output_paths = export_onnx_models(
args.model_name_or_path,
cache_dir,
output_dir,
args.use_gpu,
args.use_external_data_format,
args.optimize_onnx,
args.precision,
args.verbose,
args.use_decoder_start_token,
args.overwrite,
args.disable_auto_mixed_precision,
not args.use_int64_inputs,
args.model_type,
encoder_decoder_init=args.encoder_decoder_init,
force_fp16_io=args.force_fp16_io,
)
logger.info(f"Done! Outputs: {output_paths}")
if __name__ == "__main__":
main()
@@ -0,0 +1,437 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
import numpy
import onnx
import torch
from io_binding_helper import TypeHelper
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from t5_encoder import T5EncoderInputs
from torch_onnx_export_helper import torch_onnx_export
from transformers import MT5Config, T5Config
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class T5DecoderInit(torch.nn.Module):
"""A T5 decoder with LM head to create initial past key values.
This model is only called once during starting decoding.
"""
def __init__(
self,
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: T5Config | MT5Config,
decoder_start_token_id: int | None = None,
):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config
self.decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)
self.tie_word_embeddings = (
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
)
def forward(
self,
decoder_input_ids: torch.Tensor,
encoder_attention_mask: torch.Tensor,
encoder_hidden_states: torch.FloatTensor,
):
if decoder_input_ids is None:
batch_size = encoder_attention_mask.shape[0]
decoder_input_ids = (
torch.ones(
(batch_size, 1),
dtype=torch.long,
device=encoder_attention_mask.device,
)
* self.decoder_start_token_id
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
return_dict=True,
)
sequence_output = decoder_outputs.last_hidden_state
present_key_values = decoder_outputs.past_key_values
if self.tie_word_embeddings:
sequence_output = sequence_output * (self.config.d_model**-0.5)
lm_logits = self.lm_head(sequence_output)
past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
return lm_logits, past_self, past_cross
class T5Decoder(torch.nn.Module):
"""A T5 decoder with LM head and past key values"""
def __init__(self, decoder, lm_head, config):
super().__init__()
self.decoder = decoder
self.lm_head = lm_head
self.config = config
self.tie_word_embeddings = (
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
)
def forward(self, decoder_input_ids, encoder_attention_mask, *past):
num_decoder_layers = self.config.num_decoder_layers
past_key_values = PastKeyValuesHelper.group_by_layer(past, num_decoder_layers)
# This is a hack since only the third dimension of encoder_hidden_states is used here
dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
past_key_values=past_key_values,
encoder_hidden_states=dummy_encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
return_dict=True,
)
sequence_output = decoder_outputs.last_hidden_state
present_key_values = decoder_outputs.past_key_values
if self.tie_word_embeddings:
sequence_output = sequence_output * (self.config.d_model**-0.5)
lm_logits = self.lm_head(sequence_output)
present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
# Do not return present_cross since they are identical to corresponding past_cross input
return lm_logits, present_self
class T5DecoderInputs:
def __init__(
self,
decoder_input_ids,
encoder_attention_mask,
past_key_values=None,
):
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
self.past_key_values: list[torch.FloatTensor] | list[torch.HalfTensor] | None = past_key_values
@staticmethod
def create_dummy(
config: T5Config | MT5Config,
batch_size: int,
encode_sequence_length: int,
past_decode_sequence_length: int,
device: torch.device,
float16: bool = False,
use_int32_inputs: bool = False,
): # -> T5DecoderInputs:
"""Create dummy inputs for T5Decoder.
Args:
decoder: decoder
batch_size (int): batch size
encode_sequence_length (int): sequence length of input_ids for encoder
past_decode_sequence_length (int): past sequence length of input_ids for decoder
device (torch.device): device of output tensors
float16 (bool): whether the model uses float32 or float16 in input
use_int32_inputs(bool): whether use int32 instead of int64 for some inputs
Returns:
T5DecoderInputs: dummy inputs for decoder
"""
num_attention_heads: int = config.num_heads
num_layers: int = config.num_decoder_layers
vocab_size: int = config.vocab_size
# Do not use head_size = hidden_size / num_attention_heads here.
# For example, mt5-small, d_model=512 and num_heads=6
head_size: int = config.d_kv
sequence_length: int = 1 # fixed for decoding
decoder_input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=(torch.int32 if use_int32_inputs else torch.int64),
device=device,
)
encoder_inputs = T5EncoderInputs.create_dummy(
batch_size,
encode_sequence_length,
vocab_size,
device,
use_int32_inputs=use_int32_inputs,
)
float_type = torch.float16 if float16 else torch.float32
if past_decode_sequence_length > 0:
self_attention_past_shape = [
batch_size,
num_attention_heads,
past_decode_sequence_length,
head_size,
]
cross_attention_past_shape = [
batch_size,
num_attention_heads,
encode_sequence_length,
head_size,
]
past = []
for _ in range(2 * num_layers):
past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))
for _ in range(2 * num_layers):
past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
else:
past = None
return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)
def to_list(self) -> list:
input_list = [
self.decoder_input_ids,
self.encoder_attention_mask,
]
if self.past_key_values:
input_list.extend(self.past_key_values)
return input_list
def to_fp32(self):
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
return T5DecoderInputs(
self.decoder_input_ids.clone(),
self.encoder_attention_mask.clone(),
past,
)
class T5DecoderHelper:
@staticmethod
def export_onnx(
decoder: T5Decoder | T5DecoderInit,
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
Args:
decoder (Union[T5Decoder, T5DecoderNoPastState]): decoder object
device (torch.device): device of decoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs
"""
assert isinstance(decoder, (T5Decoder, T5DecoderInit))
inputs = T5DecoderInputs.create_dummy(
decoder.config,
batch_size=2,
encode_sequence_length=3,
past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
num_decoder_layers = decoder.config.num_decoder_layers
past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
present_self_names = present_names[: 2 * num_decoder_layers]
input_past_names = past_names if isinstance(decoder, T5Decoder) else []
output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
output_names = ["logits", *output_present_names]
# Shape of input tensors (sequence_length==1):
# input_ids: (batch_size, sequence_length)
# encoder_attention_mask: (batch_size, encode_sequence_length)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
# Shape of output tensors:
# logits: (batch_size, sequence_length, vocab_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
input_names = ["input_ids"]
input_names.append("encoder_attention_mask")
input_names.extend(input_past_names)
dynamic_axes = {
"input_ids": {
0: "batch_size",
# 1: 'sequence_length'
},
"encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
"encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
"logits": {
0: "batch_size",
# 1: 'sequence_length'
},
}
for name in input_past_names:
dynamic_axes[name] = {
0: "batch_size",
2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
}
for name in output_present_names:
if "cross" in name:
dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
else: # self attention past state
if isinstance(decoder, T5Decoder):
dynamic_axes[name] = {
0: "batch_size",
2: "past_decode_sequence_length + 1",
}
else:
dynamic_axes[name] = {
0: "batch_size",
# 2: 'sequence_length'
}
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
decoder,
args=tuple(input_list),
f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=12,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
if use_external_data_format:
model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
"""Run inference of ONNX model."""
logger.debug("start onnxruntime_inference")
ort_inputs = {
"input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
}
if inputs.past_key_values:
assert len(inputs.past_key_values) % 4 == 0
num_layers = int(len(inputs.past_key_values) / 4)
past_names = PastKeyValuesHelper.get_past_names(num_layers)
for i, past_tensor in enumerate(inputs.past_key_values):
ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs
@staticmethod
def verify_onnx(
model: T5Decoder | T5DecoderInit,
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"
test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
test_cases_max_diff = []
for (
batch_size,
encode_sequence_length,
past_decode_sequence_length,
) in test_cases[:max_cases]:
if isinstance(model, T5DecoderInit):
past_decode_sequence_length = 0 # noqa: PLW2901
inputs = T5DecoderInputs.create_dummy(
model.config,
batch_size,
encode_sequence_length,
past_decode_sequence_length,
device=device,
float16=float16,
use_int32_inputs=use_int32_inputs,
)
# We use fp32 PyTroch model as baseline even when ONNX model is fp16
input_list = inputs.to_fp32().to_list()
# Run inference of PyTorch model
with torch.no_grad():
torch_outputs = model(*input_list)
ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)
num_decoder_layers = model.config.num_decoder_layers
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
max_diff_all = max_diff
logger.debug(f"logits max_diff={max_diff}")
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
logger.debug(f"self attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
if isinstance(model, T5DecoderInit):
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(
numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * num_decoder_layers + i])
)
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
test_cases_max_diff.append(max_diff_all)
logger.info(
"batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
batch_size,
encode_sequence_length,
past_decode_sequence_length,
max_diff_all,
)
return max_diff_all
@@ -0,0 +1,70 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
import logging
import random
import torch
from transformers import MT5Config, T5Config
logger = logging.getLogger(__name__)
class T5Encoder(torch.nn.Module):
"""T5 encoder outputs only the last hidden state"""
def __init__(self, encoder, config: T5Config | MT5Config):
super().__init__()
self.encoder = encoder
self.config = config
def forward(self, input_ids, attention_mask):
return self.encoder(input_ids, attention_mask)[0]
class T5EncoderInputs:
def __init__(self, input_ids, attention_mask):
self.input_ids: torch.LongTensor = input_ids
self.attention_mask: torch.LongTensor = attention_mask
@staticmethod
def create_dummy(
batch_size: int,
sequence_length: int,
vocab_size: int,
device: torch.device,
use_int32_inputs: bool = False,
): # -> T5EncoderInputs
"""Create dummy inputs for T5 encoder.
Args:
batch_size (int): batch size
sequence_length (int): sequence length
vocab_size (int): vocabulary size
device (torch.device): device of output tensors
Returns:
T5EncoderInputs: dummy inputs for encoder
"""
dtype = torch.int32 if use_int32_inputs else torch.int64
input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=dtype,
device=device,
)
attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
if sequence_length >= 2:
for i in range(batch_size):
padding_position = random.randint(0, sequence_length - 1)
attention_mask[i, :padding_position] = 0
return T5EncoderInputs(input_ids, attention_mask)
def to_list(self) -> list:
input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None]
return input_list
@@ -0,0 +1,361 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
import numpy
import onnx
import torch
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from t5_decoder import T5DecoderInit
from t5_encoder import T5Encoder, T5EncoderInputs
from torch_onnx_export_helper import torch_onnx_export
from transformers import MT5Config, T5Config
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class T5EncoderDecoderInit(torch.nn.Module):
"""A combination of T5Encoder and T5DecoderInit."""
def __init__(
self,
encoder: torch.nn.Module,
decoder: torch.nn.Module,
lm_head: torch.nn.Linear,
config: T5Config | MT5Config,
decoder_start_token_id: int | None = None,
output_cross_only: bool = False,
):
super().__init__()
self.config: T5Config | MT5Config = config
self.t5_encoder = T5Encoder(encoder, config)
self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
self.output_cross_only = output_cross_only
def forward(
self,
encoder_input_ids: torch.Tensor,
encoder_attention_mask: torch.Tensor,
decoder_input_ids: torch.Tensor | None = None,
):
encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)
lm_logits, past_self, past_cross = self.t5_decoder_init(
decoder_input_ids, encoder_attention_mask, encoder_hidden_states
)
if self.output_cross_only:
return past_cross
else:
return lm_logits, encoder_hidden_states, past_self, past_cross
class T5EncoderDecoderInitInputs:
def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
self.encoder_input_ids: torch.LongTensor = encoder_input_ids
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
self.decoder_input_ids: torch.LongTensor | None = decoder_input_ids
@staticmethod
def create_dummy(
config: T5Config | MT5Config,
batch_size: int,
encode_sequence_length: int,
use_decoder_input_ids: int,
device: torch.device,
use_int32_inputs: bool = False,
): # -> T5EncoderDecoderInitInputs:
encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
batch_size,
encode_sequence_length,
config.vocab_size,
device,
use_int32_inputs=use_int32_inputs,
)
decoder_input_ids = None
if use_decoder_input_ids:
dtype = torch.int32 if use_int32_inputs else torch.int64
decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
def to_list(self) -> list:
input_list = [self.encoder_input_ids, self.encoder_attention_mask]
if self.decoder_input_ids is not None:
input_list.append(self.decoder_input_ids)
return input_list
class T5EncoderDecoderInitHelper:
@staticmethod
def export_onnx(
model: T5EncoderDecoderInit,
device: torch.device,
onnx_model_path: str,
use_decoder_input_ids: bool = True,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
Args:
model (T5EncoderDecoderInit): the model to export
device (torch.device): device of decoder object
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_int32_inputs (bool, optional): use int32 instead of int64 for integer inputs. Defaults to False.
"""
assert isinstance(model, T5EncoderDecoderInit)
# Do not exclude decoder in torch onnx export so that cross can show up.
output_cross_only = model.output_cross_only
model.output_cross_only = False
inputs = T5EncoderDecoderInitInputs.create_dummy(
model.config,
batch_size=2,
encode_sequence_length=3,
use_decoder_input_ids=use_decoder_input_ids,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
output_names = ["logits", "encoder_hidden_states", *present_names]
# Shape of input tensors (sequence_length==1):
# input_ids: (batch_size, sequence_length)
# encoder_attention_mask: (batch_size, encode_sequence_length)
# encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
# Shape of output tensors:
# logits: (batch_size, sequence_length, vocab_size)
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
input_names = ["encoder_input_ids", "encoder_attention_mask"]
# ONNX exporter might mark dimension like 'present_value_self_1_dim_2' in shape inference.
# We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
sequence_length = "1"
num_heads = str(model.config.num_heads)
hidden_size = str(model.config.d_model)
head_size = str(model.config.d_kv)
dynamic_axes = {
"encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
"encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
"encoder_hidden_states": {
0: "batch_size",
1: "encode_sequence_length",
2: hidden_size,
},
"logits": {
0: "batch_size",
1: sequence_length,
},
}
if use_decoder_input_ids:
input_names.append("decoder_input_ids")
dynamic_axes["decoder_input_ids"] = {
0: "batch_size",
1: sequence_length,
}
for name in present_names:
if "cross" in name:
dynamic_axes[name] = {
0: "batch_size",
1: num_heads,
2: "encode_sequence_length",
3: head_size,
}
else: # self attention past state
dynamic_axes[name] = {
0: "batch_size",
1: num_heads,
2: sequence_length,
3: head_size,
}
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
args=tuple(input_list),
f=temp_onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=12,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
# Restore output_cross_only setting.
model.output_cross_only = output_cross_only
# Workaround as mentioned earlier: change numeric dim_param to dim_value
exported_model: onnx.ModelProto = onnx.load(temp_onnx_model_path)
for tensor in exported_model.graph.output:
for dim_proto in tensor.type.tensor_type.shape.dim:
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
sequence_length,
num_heads,
hidden_size,
head_size,
]:
dim_value = int(dim_proto.dim_param)
dim_proto.Clear()
dim_proto.dim_value = dim_value
if output_cross_only:
# Rewrite onnx graph to only keep present_[key|value]_cross_* outputs.
onnx_model = OnnxModel(exported_model)
output_name_to_node = onnx_model.output_name_to_node()
for output in exported_model.graph.output:
if "cross" in output.name:
assert output.name in output_name_to_node
transpose_node = output_name_to_node[output.name]
assert transpose_node and transpose_node.op_type == "Transpose"
permutation = OnnxModel.get_node_attribute(transpose_node, "perm")
assert isinstance(permutation, list)
assert permutation == [0, 2, 1, 3]
matched_nodes = onnx_model.match_parent_path(
transpose_node,
["Reshape", "MatMul"],
[0, 0],
output_name_to_node,
)
assert matched_nodes is not None
reshape_node, matmul_node = matched_nodes
assert "encoder_hidden_states" in matmul_node.input
if not onnx_model.get_initializer("cross_reshape_shape"):
shape_tensor = onnx.helper.make_tensor(
name="cross_reshape_shape",
data_type=onnx.TensorProto.INT64,
dims=[4],
vals=[0, 0, int(num_heads), int(head_size)],
raw=False,
)
onnx_model.add_initializer(shape_tensor)
reshape_node.input[1] = "cross_reshape_shape"
cross_outputs = [output.name for output in exported_model.graph.output if "cross" in output.name]
onnx_model.prune_graph(cross_outputs, allow_remove_graph_inputs=True)
OnnxModel.save(
exported_model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
"""Run inference of ONNX model."""
logger.debug("start onnxruntime_inference")
ort_inputs = {
"encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
}
if inputs.decoder_input_ids is not None:
ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
ort_outputs = ort_session.run(None, ort_inputs)
return ort_outputs
@staticmethod
def verify_onnx(
model: T5EncoderDecoderInit,
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
ort_inputs = ort_session.get_inputs()
use_decoder_input_ids = len(ort_inputs) == 3
test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
test_cases_max_diff = []
for batch_size, encode_sequence_length in test_cases[:max_cases]:
inputs = T5EncoderDecoderInitInputs.create_dummy(
model.config,
batch_size,
encode_sequence_length,
use_decoder_input_ids=use_decoder_input_ids,
device=device,
use_int32_inputs=use_int32_inputs,
)
ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
# Run inference of PyTorch model
input_list = inputs.to_list()
torch_outputs = model(*input_list)
num_decoder_layers = model.config.num_decoder_layers
if not model.output_cross_only:
assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
logger.debug(f"logits max_diff={max_diff}")
max_diff_all = max_diff
assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
logger.debug(f"encoder_hidden_states max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
logger.debug(f"self attention past state {i} max_diff={max_diff}")
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(
numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
)
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
else:
max_diff_all = -float("inf")
for i in range(2 * num_decoder_layers):
max_diff = numpy.amax(numpy.abs(torch_outputs[i].cpu().numpy() - ort_outputs[i]))
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
max_diff_all = max(max_diff_all, max_diff)
test_cases_max_diff.append(max_diff_all)
logger.info(
f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
)
return max(test_cases_max_diff)
@@ -0,0 +1,302 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
import logging
import os
from pathlib import Path
import torch
from float16 import float_to_float16_max_diff
from onnx_model import OnnxModel
from optimizer import optimize_model
from t5_decoder import T5Decoder, T5DecoderHelper
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
PRETRAINED_MT5_MODELS = [
"google/mt5-small",
"google/mt5-base",
"google/mt5-large",
"google/mt5-xl",
"google/mt5-xxl",
]
class T5Helper:
@staticmethod
def get_onnx_path(
output_dir: str,
model_name_or_path: str,
suffix: str = "",
new_folder: bool = False,
) -> str:
"""Build onnx path
Args:
output_dir (str): output directory
model_name_or_path (str): pretrained model name, or path to the model checkpoint
suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to file name. Defaults to None.
new_folder (bool, optional): create a new directory for the model. Defaults to False.
Returns:
str: path of onnx model
"""
model_name = model_name_or_path
if os.path.isdir(model_name_or_path):
model_name = Path(model_name_or_path).parts[-1]
else:
model_name.split("/")[-1]
model_name += suffix
directory = os.path.join(output_dir, model_name) if new_folder else output_dir
return os.path.join(directory, model_name + ".onnx")
@staticmethod
def load_model(
model_name_or_path: str,
cache_dir: str,
device: torch.device,
model_type: str = "t5",
state_dict_path: str = "",
encoder_decoder_init: bool = False,
) -> dict[str, T5EncoderDecoderInit | T5Decoder]:
"""Load model given a pretrained name or path, then build models for ONNX conversion.
Args:
model_name_or_path (str): pretrained model name or path
cache_dir (str): cache directory
device (torch.device): device to run the model
model_type (str, optional): model type "t5" or "mt5"
state_dict_path(str, optional): state dictionary path
encoder_decoder_init (bool, optional): combine encoder and decoder kv cache initialization into one model.
Returns:
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
"""
if model_type == "t5":
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
elif model_type == "mt5":
model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
else:
raise ValueError("only support mode_type=t5 or mt5")
if state_dict_path:
model.load_state_dict(torch.load(state_dict_path))
decoder = T5Decoder(model.decoder, model.lm_head, model.config)
decoder.eval().to(device)
encoder = T5EncoderDecoderInit(
model.encoder,
model.decoder,
model.lm_head,
model.config,
decoder_start_token_id=None,
output_cross_only=not encoder_decoder_init,
)
encoder_name = "encoder_decoder_init" if encoder_decoder_init else "encoder"
return {encoder_name: encoder, "decoder": decoder}
@staticmethod
def export_onnx(
model: T5Decoder | T5EncoderDecoderInit,
device: torch.device,
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_decoder_input_ids: bool = True,
use_int32_inputs: bool = False,
):
if isinstance(model, T5EncoderDecoderInit):
T5EncoderDecoderInitHelper.export_onnx(
model,
device,
onnx_model_path,
use_decoder_input_ids,
verbose,
use_external_data_format,
use_int32_inputs,
)
else:
T5DecoderHelper.export_onnx(
model,
device,
onnx_model_path,
verbose,
use_external_data_format,
use_int32_inputs,
)
@staticmethod
def auto_mixed_precision(
onnx_model: OnnxModel,
op_block_list: list[str] | None = None,
force_fp16_logits: bool = False,
use_symbolic_shape_infer: bool = True,
):
"""Convert model to mixed precision.
It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically.
Args:
onnx_model (OnnxModel): optimized ONNX model
op_block_list (List[str], optional): operators need to run in fp32.
force_fp16_logits (bool, optional): force logits and last MatMul node to be in float16. Defaults to False.
use_symbolic_shape_infer (bool, optional): use symbolic shape inference to convert float to float16. Defaults to True.
Returns:
parameters(dict): a dictionary of parameters used in float16 conversion
"""
if op_block_list is None:
op_block_list = [
"SimplifiedLayerNormalization",
"SkipSimplifiedLayerNormalization",
"Relu",
"Add",
]
op_full_set = {node.op_type for node in onnx_model.nodes()}
fp32_op_set = set(op_block_list)
fp16_op_set = op_full_set.difference(fp32_op_set)
logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
# logits is the first output
logits_output_name = onnx_model.graph().output[0].name
# We use the weight in last MatMul node to detect whether the model is stored with float16 weights from training.
is_weight_fp16_precision = False
output_name_to_node = onnx_model.output_name_to_node()
assert logits_output_name in output_name_to_node
node = output_name_to_node[logits_output_name]
last_matmul_node = None
if node.op_type == "MatMul":
last_matmul_node = node
logger.info(f"Found last MatMul node for logits: {node.name}")
initializer = None
for input in node.input:
initializer = onnx_model.get_initializer(input)
if initializer is not None:
break
# when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
# we can deduce that the weights are stored in float16 precision.
max_diff = float_to_float16_max_diff(initializer)
logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
is_weight_fp16_precision = max_diff < 1e-6
else:
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
keep_io_types = []
node_block_list = []
if (not is_weight_fp16_precision) and (last_matmul_node is not None) and not force_fp16_logits:
# When original weight is float32 precision, keep logits and last MatMul in float32 could get better precision.
keep_io_types = [logits_output_name]
node_block_list = [last_matmul_node.name]
if "Add" not in op_block_list:
input_name_to_nodes = onnx_model.input_name_to_nodes()
fp32_add = 0
changed = True
add_nodes = onnx_model.get_nodes_by_op_type("Add")
while changed:
changed = False
for node in add_nodes:
if node.name not in node_block_list:
parents = onnx_model.get_parents(node, output_name_to_node)
children = onnx_model.get_children(node, input_name_to_nodes)
blocked_children = [
child for child in children if child.op_type in op_block_list or child in node_block_list
]
blocked_parents = [
parent for parent in parents if parent.op_type in op_block_list or parent in node_block_list
]
# If any child or parent is in fp32, we place the Add node to fp32.
if (len(blocked_children) + len(blocked_parents)) > 0:
node_block_list.append(node.name)
fp32_add += 1
changed = True
fp16_add = len(add_nodes) - fp32_add
logger.info(f"node counter of Add operator: fp32={fp32_add} fp16={fp16_add}")
logger.info(f"node_block_list: {node_block_list}")
parameters = {
"keep_io_types": keep_io_types,
"op_block_list": op_block_list,
"node_block_list": node_block_list,
"force_fp16_initializers": is_weight_fp16_precision,
}
logger.info(f"auto_mixed_precision parameters: {parameters}")
if use_symbolic_shape_infer:
onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
else:
# Workaround when symbolic shape inference fails.
# Need enable shape_infer_before_optimization in convert_to_onnx.py as well.
from float16 import convert_float_to_float16 # noqa: PLC0415
convert_float_to_float16(
onnx_model.model,
disable_shape_infer=True,
**parameters,
)
return parameters
@staticmethod
def optimize_onnx(
onnx_model_path: str,
optimized_model_path: str,
is_float16: bool,
num_attention_heads: int,
hidden_size: int,
use_external_data_format: bool = False,
auto_mixed_precision: bool = True,
use_gpu: bool = False,
force_fp16_io: bool = False,
):
"""Optimize ONNX model with an option to convert it to use mixed precision."""
from fusion_options import FusionOptions # noqa: PLC0415
optimization_options = None
if is_float16:
optimization_options = FusionOptions("t5")
# SkipLayerNormalization is faster but might bring accuracy drop since it uses fp16 accumulation.
optimization_options.enable_skip_layer_norm = not auto_mixed_precision
m = optimize_model(
onnx_model_path,
model_type="t5",
num_heads=num_attention_heads,
hidden_size=hidden_size,
opt_level=0,
optimization_options=optimization_options,
use_gpu=use_gpu,
)
if is_float16:
if auto_mixed_precision:
T5Helper.auto_mixed_precision(m, force_fp16_logits=force_fp16_io)
else:
m.convert_model_float32_to_float16(cast_input_output=force_fp16_io)
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
@staticmethod
def verify_onnx(
model: T5Decoder | T5EncoderDecoderInit,
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
if isinstance(model, T5EncoderDecoderInit):
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(__file__))
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
sys.path.append(transformers_dir)
@@ -0,0 +1,610 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import ast
import datetime
import gc
import logging
import os
import sys
import time
import numpy as np
import psutil
import torch
import whisper
from benchmark_helper import measure_memory, setup_logger
from onnxruntime_extensions import get_library_path
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from torch.profiler import ProfilerActivity, profile, record_function
from tqdm import trange
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
import onnxruntime as ort
logger = logging.getLogger(__name__)
def get_inputs(args: argparse.Namespace):
if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
raise Exception("Unable to auto-detect inputs for provided model")
def load_via_ffmpeg():
audio = whisper.load_audio(args.audio_path)
audio = whisper.pad_or_trim(audio)
return audio
def load_via_numpy():
with open(args.audio_path, "rb") as f:
audio = np.asarray(list(f.read()), dtype=np.uint8)
audio = np.array([audio])
return audio
inputs = {
"max_length": args.max_length,
"min_length": args.min_length,
"num_beams": args.num_beams,
"num_return_sequences": args.num_return_sequences,
"length_penalty": args.length_penalty,
"repetition_penalty": args.repetition_penalty,
}
if args.benchmark_type == "ort":
# convert_to_onnx export or ONNX E2E solution created by Olive
for k, v in inputs.items():
inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
if args.has_decoder_input_ids:
inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
if args.has_logits_processor:
inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
if args.has_temperature:
inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
# Measure time taken to load audio file
logger.info(f"Load audio: {args.audio_path}")
load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg() # noqa: E731
time_fn(args, load_audio_fn, args.has_audio_stream)
audio_data = load_audio_fn(args.has_audio_stream)
if args.has_audio_stream:
# ONNX E2E solution created by Olive
inputs["audio_stream"] = audio_data
return inputs
# Measure time taken to get input features
logger.info("Feature extraction: ")
return_type = "np" if args.benchmark_type == "ort" else "pt"
processor_fn = lambda audio: args.processor.feature_extractor( # noqa: E731
[audio], return_tensors=return_type, sampling_rate=args.sampling_rate
).input_features
time_fn(args, processor_fn, audio_data)
input_features = processor_fn(audio_data)
if args.benchmark_type == "ort":
# convert_to_onnx export
inputs["input_features"] = input_features
return inputs
inputs["inputs"] = input_features.to(
dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
)
inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
inputs["early_stopping"] = True
inputs["use_cache"] = True
if args.decoder_input_ids:
inputs["forced_decoder_ids"] = args.decoder_input_ids
return inputs
def get_model(args: argparse.Namespace):
model, sess_options = None, None
start_time, end_time = None, None
# There are multiple sources that the model could come from:
# 1) Benchmark Whisper from Hugging Face
# 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
# 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
start_time = time.time()
model = AutoModelForSpeechSeq2Seq.from_pretrained(
source,
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
use_cache=True,
).to(args.target_device)
end_time = time.time()
if args.benchmark_type == "hf-pt-compile":
model = torch.compile(model)
elif args.benchmark_type in {"hf-ort", "ort"}:
sess_options = ort.SessionOptions()
sess_options.enable_profiling = args.profile
sess_options.register_custom_ops_library(get_library_path())
if args.verbose:
sess_options.log_verbosity_level = 1
sess_options.log_severity_level = 1
if args.tune:
ort.set_default_logger_severity(0)
ort.set_default_logger_verbosity(0)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
if args.benchmark_type == "hf-ort":
# Optimum export
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
start_time = time.time()
model = ORTModelForSpeechSeq2Seq.from_pretrained(
args.hf_ort_dir_path,
provider=provider,
provider_options=provider_options,
session_options=sess_options,
use_io_binding=True, # Avoid memory copy overhead
)
end_time = time.time()
if args.benchmark_type == "ort":
# convert_to_onnx.py export
logger.info(f"Loading model from {args.ort_model_path}")
start_time = time.time()
model = ort.InferenceSession(
args.ort_model_path,
sess_options,
providers=[args.execution_provider],
)
end_time = time.time()
logger.info(f"Loaded model in {end_time - start_time} s")
return model
def time_fn(args, fn, inputs):
warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
torch_device = torch.device(args.target_device)
# Warm up
warmup_range = (
range(args.warmup_runs)
if args.benchmark_type == "ort"
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
)
if args.verbose:
outputs = fn(warmup_inputs)
logger.info(outputs)
for _ in warmup_range:
fn(warmup_inputs)
# Benchmark
if args.device != "cpu":
torch.cuda.synchronize(torch_device)
start_time = time.time()
bench_range = (
range(args.num_runs)
if args.benchmark_type == "ort"
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
)
for _ in bench_range:
fn(benchmark_inputs)
if args.device != "cpu":
torch.cuda.synchronize(torch_device)
end_time = time.time()
# Newline print after trange in order to print metrics on new lines without progress bar on same line
if args.benchmark_type != "ort":
logger.info("")
batch_size = 1
latency = (end_time - start_time) / args.num_runs
throughput = batch_size / latency
logger.info(f"Latency: {latency} s")
logger.info(f"Throughput: {throughput} qps")
return
def profile_fn(args, fn, inputs, inputs_type):
# Filename prefix format:
# "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
filename = None
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
# Profile PyTorch kernels
with profile( # noqa: SIM117
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
) as prof:
with record_function("model_inference"):
fn(inputs)
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
filename = os.path.join(args.log_folder, f"{prefix}.log")
with open(filename, "w") as f:
f.write(prof_data)
else:
# Profile ORT kernels
fn(inputs)
# Set new log name for ORT profile log generated
filename = f"{prefix}.json"
return filename
def measure_fn(args, fn, inputs):
# Measure CPU usage
pid = os.getpid()
process = psutil.Process(pid)
process.cpu_percent(interval=0.1)
fn(inputs)
logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
# Measure memory usage
gc.collect()
torch.cuda.empty_cache()
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
# Flush output so memory usage is printed
sys.stdout.flush()
def run_hf_inference(args, inputs, model):
# Inference steps to measure
def get_pred_ids(inputs):
# Inference pass with predicted token ids generation
predicted_ids = model.generate(**inputs)
return predicted_ids
def gen_and_dec(inputs):
# Inference pass with generation and decoding
predicted_ids = get_pred_ids(inputs)
transcription = []
for _ in range(args.num_return_sequences):
transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
return predicted_ids, transcription
# Examples of other inference steps that can be measured:
# To use, uncomment the function and assign it to `generate_fn`
# def get_logits(inputs):
# # Inference pass without decoding
# outputs = model(**inputs)
# return outputs
generate_fn = gen_and_dec
if args.benchmark_type == "hf-pt-compile":
# Run forward pass once with each set of inputs to process through Dynamo
generate_fn(inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
if args.benchmark_type == "hf-ort":
# Rename log files per model component and turn profiling off to stop appending to log
new_prefix = new_logname[: -len(".json")]
old_logname = model.encoder.session.end_profiling()
new_logname = new_prefix + "-encoder.json"
if os.path.isfile(old_logname):
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
old_logname = model.decoder.session.end_profiling()
new_logname = new_prefix + "-decoder.json"
if os.path.isfile(old_logname):
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
old_logname = model.decoder_with_past.session.end_profiling()
new_logname = new_prefix + "-decoder-with-past.json"
if os.path.isfile(old_logname):
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# PyTorch evaluations
logger.info("\nEvaluating PyTorch...")
time_fn(args, generate_fn, inputs)
predicted_ids, transcription = generate_fn(inputs)
logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
logger.info(f"Transcription: {transcription[0]}")
measure_fn(args, generate_fn, inputs)
def run_ort_inference(args, inputs, model):
def prepare_ort_inputs(inputs, warmup=False):
# Check that all model inputs will be provided
model_inputs = {model_input.name for model_input in model.get_inputs()}
user_inputs = set(inputs.keys())
missing_inputs = model_inputs - user_inputs
if len(missing_inputs):
logger.error(f"The following model inputs are missing: {missing_inputs}")
raise Exception("There are missing inputs to the model. Please add them and try again.")
if warmup and args.tune:
inputs["min_length"] = inputs["max_length"]
# Remove unnecessary inputs from model inputs
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
del inputs[unnecessary_input]
# Add IO bindings for non-CPU execution providers
if args.device != "cpu":
io_binding = model.io_binding()
for k, v in inputs.items():
io_binding.bind_cpu_input(k, v)
for output in model.get_outputs():
io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
return io_binding
return inputs
def with_io_binding(io_binding):
# Inference pass with IO binding
model.run_with_iobinding(io_binding)
return io_binding
def without_io_binding(inputs):
# Inference pass without IO binding
outputs = model.run(None, inputs)
return outputs
def handle_output(output):
if args.eos_token_id in output:
first_end = np.where(output == args.eos_token_id)[0][0]
return output[: first_end + 1]
return output
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
ort_inputs = prepare_ort_inputs(inputs)
if args.profile:
new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")
# Turn profiling off to stop appending to log file
old_logname = model.end_profiling()
logger.warning(f"Renaming {old_logname} to {new_logname}")
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
return
# ORT evaluation
logger.info("\nEvaluating ONNX Runtime...")
ort_evaluate_inputs = ort_inputs
if args.tune:
ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True)
ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs)
time_fn(args, generate_fn, ort_evaluate_inputs)
ort_outputs = generate_fn(ort_inputs)
if args.device != "cpu":
ort_outputs = ort_outputs.copy_outputs_to_cpu()
ort_outputs = ort_outputs[0]
if args.has_audio_stream:
# ONNX E2E model from Olive produces transcribed output
logger.info(f"Transcription: {ort_outputs[0][0]}")
else:
# convert_to_onnx model produces generated ids
actual_output = handle_output(ort_outputs[0][0])
logger.info(f"Generated token length: {len(actual_output)} tokens")
transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
# print to stdout as the output for comparison
print(f"{transcription}")
measure_fn(args, generate_fn, ort_inputs)
def run_inference(args, inputs, model):
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
run_hf_inference(args, inputs, model)
elif args.benchmark_type == "ort":
run_ort_inference(args, inputs, model)
else:
raise Exception(f"Cannot recognize {args.benchmark_type}")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-bt",
"--benchmark-type",
type=str,
required=True,
choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
)
parser.add_argument(
"-m",
"--model-name",
type=str,
required=True,
help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
)
parser.add_argument(
"-p",
"--precision",
type=str,
required=True,
default="fp32",
choices=["int8", "fp16", "fp32"],
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
)
parser.add_argument(
"--hf-pt-model-path",
type=str,
default="",
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
default="",
help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
)
parser.add_argument(
"--ort-model-path",
type=str,
default="",
help="Path to ONNX model",
)
# Args for running and evaluating the model
parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
parser.add_argument("-n", "--num-runs", type=int, default=10)
parser.add_argument("--seed", type=int, default=2)
# Optional args:
parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")
# Args for decoding logic
# Required args:
parser.add_argument("--max-length", type=int, default=448)
parser.add_argument("--min-length", type=int, default=0)
parser.add_argument("--num-beams", type=int, default=1)
parser.add_argument("--num-return-sequences", type=int, default=1)
parser.add_argument("--length-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--no-repeat-ngram-size", type=int, default=3)
# Optional args for E2E solution:
parser.add_argument(
"--decoder-input-ids",
type=str,
default="[]",
help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
)
parser.add_argument(
"--logits-processor",
type=int,
default=1,
help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Temperature value for generation.",
)
# Args for accessing detailed info
parser.add_argument("--profile", default=False, action="store_true")
parser.add_argument(
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
)
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
parser.add_argument("--verbose", default=False, action="store_true")
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
parser.add_argument(
"--tune",
default=False,
action="store_true",
help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel",
)
args = parser.parse_args()
# Set seed properties
np.random.seed(args.seed)
torch.manual_seed(args.seed)
args.monitor_type = args.device
# Set runtime properties
if "ort" in args.benchmark_type:
args.execution_provider = f"{args.device.upper()}ExecutionProvider"
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
elif args.execution_provider == "ROCMExecutionProvider":
args.execution_provider = (
args.execution_provider,
{
"device_id": args.device_id,
"tunable_op_enable": 1,
"tunable_op_tuning_enable": 1 if args.tune else 0,
},
)
args.device = "cuda"
# Check that model paths have been specified for any benchmarking with ORT
if args.benchmark_type == "hf-ort":
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
if args.benchmark_type == "ort":
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
# Convert decoder_input_ids string to list of ids
# (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
return args
def main():
args = parse_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
config = WhisperConfig.from_pretrained(args.model_name)
processor = WhisperProcessor.from_pretrained(args.model_name)
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
use_fp16 = args.precision == "fp16"
setattr(args, "processor", processor) # noqa: B010
setattr(args, "target_device", target_device) # noqa: B010
setattr(args, "use_fp16", use_fp16) # noqa: B010
setattr(args, "has_audio_stream", False) # noqa: B010
setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010
logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
# Measure cost to transcribe audio
model = get_model(args)
if args.benchmark_type == "ort":
# Check for optional inputs that could have been added during export
ort_model_inputs = {model_input.name for model_input in model.get_inputs()}
args.has_audio_stream = "audio_stream" in ort_model_inputs
setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010
setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010
setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010
if args.decoder_input_ids == []:
args.decoder_input_ids = [config.decoder_start_token_id]
inputs = get_inputs(args)
run_inference(args, inputs, model)
if __name__ == "__main__":
main()
@@ -0,0 +1,526 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import datetime
import json
import logging
import os
import subprocess
import librosa
import torch
from benchmark_helper import setup_logger
from metrics import BenchmarkRecord
from transformers import WhisperConfig, WhisperProcessor
logger = logging.getLogger(__name__)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-a",
"--audio-path",
type=str,
required=True,
help="Path to folder of audio files for E2E evaluation",
)
parser.add_argument(
"-l",
"--language",
default=None,
help="Language of audio file",
)
parser.add_argument(
"-t",
"--task",
default=None,
choices=["transcribe", "translate"],
help="Task to complete",
)
parser.add_argument(
"-w",
"--warmup-runs",
type=int,
default=5,
)
parser.add_argument(
"-n",
"--num-runs",
type=int,
default=10,
)
parser.add_argument(
"--hf-pt-eager",
default=False,
action="store_true",
help="Benchmark in PyTorch without `torch.compile`",
)
parser.add_argument(
"--hf-pt-compile",
default=False,
action="store_true",
help="Benchmark in PyTorch with `torch.compile`",
)
parser.add_argument(
"--hf-ort-dir-path",
type=str,
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
)
parser.add_argument(
"--ort-model-path",
type=str,
help="Path to ONNX model for ORT benchmarking",
)
parser.add_argument(
"--model-name",
type=str,
required=True,
help="Model name in Hugging Face (e.g. openai/whisper-large-v2)",
)
parser.add_argument(
"--precision",
type=str,
required=True,
choices=["int8", "fp16", "fp32"],
help="Precision to run model",
)
parser.add_argument(
"--device",
type=str,
required=True,
choices=["cpu", "cuda", "rocm"],
help="Device to benchmark models",
)
parser.add_argument(
"--device-id",
type=int,
default=0,
help="GPU device ID",
)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="Print detailed logs",
)
parser.add_argument(
"--timeout",
type=int,
default=5,
help="Number of mins to attempt the benchmark before moving on",
)
parser.add_argument(
"--log-folder",
type=str,
default=None,
help="Path to folder to save logs and results",
)
parser.add_argument("--tune", default=False, action="store_true")
args = parser.parse_args()
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
log_folder_name = f"./{args.model_size}-{args.precision}"
if not args.log_folder:
args.log_folder = log_folder_name
os.makedirs(args.log_folder, exist_ok=True)
# Convert timeout value to secs
args.timeout *= 60
return args
def process_log_file(device_id, log_file, base_results):
entries = []
# Detect steps in speech pipeline
step = None
load_audio_pattern = "Load audio: "
feat_ext_pattern = "Feature extraction: "
pytorch_pattern = "Evaluating PyTorch..."
onnxruntime_pattern = "Evaluating ONNX Runtime..."
load_audio_latency_s, load_audio_throughput_s = None, None
feat_ext_latency_s, feat_ext_throughput_s = None, None
token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None
throughput, memory = None, None
# Detect metrics
latency_pattern = "Latency: "
throughput_pattern = "Throughput: "
token_length_pattern = "Generated token length: "
memory_pattern = "peak="
with open(log_file) as f:
for input_line in f:
line = input_line.replace("\n", "")
# Get step in speech recognition pipeline
if load_audio_pattern in line:
step = "load-audio"
elif feat_ext_pattern in line:
step = "feature-extraction"
elif pytorch_pattern in line or onnxruntime_pattern in line:
step = "process"
# Check metrics
if latency_pattern in line:
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
elif throughput_pattern in line:
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
if step == "load-audio":
load_audio_latency_s, load_audio_throughput_s = latency_s, throughput
step = None
if step == "feature-extraction":
feat_ext_latency_s, feat_ext_throughput_s = latency_s, throughput
step = None
elif token_length_pattern in line:
token_length = int(line[len(token_length_pattern) : line.rfind(" ")])
per_token_latency_s = latency_s / token_length
per_token_latency_ms = per_token_latency_s * 1000
elif memory_pattern in line:
if "CPU" in line:
# Example format for log entry:
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
else:
# Example format for log entry:
# GPU memory usage: before=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1638.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}, peak=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1780.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}]
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
usage = json.loads(peak)[device_id]["max_used_MB"]
memory = float(usage) / 1000
# Calculate real-time factor (RTF):
# RTF = total latency / audio duration
total_latency = (
(load_audio_latency_s if load_audio_latency_s else 0)
+ (feat_ext_latency_s if feat_ext_latency_s else 0)
+ (latency_s if latency_s else 0)
)
audio_duration = base_results[-1]
rtf = (total_latency / audio_duration) if audio_duration else -1
logger.info(f"Total latency: {total_latency} s")
logger.info(f"Audio duration: {audio_duration} s")
logger.info(f"Real-time factor: {rtf}")
# Append log entry to list of entries
entry = base_results + [ # noqa: RUF005
token_length,
load_audio_latency_s,
load_audio_throughput_s,
feat_ext_latency_s if feat_ext_latency_s else -1,
feat_ext_throughput_s if feat_ext_throughput_s else -1,
latency_s,
per_token_latency_ms,
throughput,
memory,
rtf,
]
entries.append(entry)
return entries
def save_results(results, filename):
import pandas as pd # noqa: PLC0415
df = pd.DataFrame(
results,
columns=[
"Warmup Runs",
"Measured Runs",
"Model Name",
"Engine",
"Precision",
"Device",
"Audio File",
"Duration (s)",
"Token Length",
"Load Audio Latency (s)",
"Load Audio Throughput (qps)",
"Feature Extractor Latency (s)",
"Feature Extractor Throughput (qps)",
"Latency (s)",
"Per Token Latency (ms/token)",
"Throughput (qps)",
"Memory (GB)",
"Real Time Factor (RTF)",
],
)
# Set column types
df["Warmup Runs"] = df["Warmup Runs"].astype("int")
df["Measured Runs"] = df["Measured Runs"].astype("int")
df["Duration (s)"] = df["Duration (s)"].astype("float")
df["Token Length"] = df["Token Length"].astype("int")
df["Load Audio Latency (s)"] = df["Load Audio Latency (s)"].astype("float")
df["Load Audio Throughput (qps)"] = df["Load Audio Throughput (qps)"].astype("float")
df["Feature Extractor Latency (s)"] = df["Feature Extractor Latency (s)"].astype("float")
df["Feature Extractor Throughput (qps)"] = df["Feature Extractor Throughput (qps)"].astype("float")
df["Latency (s)"] = df["Latency (s)"].astype("float")
df["Per Token Latency (ms/token)"] = df["Per Token Latency (ms/token)"].astype("float")
df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
df["Real Time Factor (RTF)"] = df["Real Time Factor (RTF)"].astype("float")
# get package name and version
import pkg_resources # noqa: PLC0415
installed_packages = pkg_resources.working_set
installed_packages_list = sorted(
[f"{i.key}=={i.version}" for i in installed_packages if i.key in ["onnxruntime", "onnxruntime-gpu"]]
)
ort_pkg_name = ""
ort_pkg_version = ""
if installed_packages_list:
ort_pkg_name = installed_packages_list[0].split("==")[0]
ort_pkg_version = installed_packages_list[0].split("==")[1]
# Save results to csv with standard format
records = []
for _, row in df.iterrows():
if row["Engine"] == "onnxruntime":
record = BenchmarkRecord(
row["Model Name"], row["Precision"], row["Engine"], row["Device"], ort_pkg_name, ort_pkg_version
)
else:
record = BenchmarkRecord(
row["Model Name"], row["Precision"], row["Engine"], row["Device"], torch.__name__, torch.__version__
)
record.config.customized["audio_file"] = row["Audio File"]
record.config.warmup_runs = row["Warmup Runs"]
record.config.measured_runs = row["Measured Runs"]
record.metrics.customized["duration"] = row["Duration (s)"]
record.metrics.customized["token_length"] = row["Token Length"]
record.metrics.customized["load_audio_latency"] = row["Load Audio Latency (s)"]
record.metrics.customized["load_audio_throughput"] = row["Load Audio Throughput (qps)"]
record.metrics.customized["feature_extractor_latency_s"] = row["Feature Extractor Latency (s)"]
record.metrics.customized["feature_extractor_throughput_qps"] = row["Feature Extractor Throughput (qps)"]
record.metrics.customized["per_token_latency_ms"] = row["Per Token Latency (ms/token)"]
record.metrics.customized["rtf"] = row["Real Time Factor (RTF)"]
record.metrics.latency_ms_mean = row["Latency (s)"] * 1000
record.metrics.throughput_qps = row["Throughput (qps)"]
record.metrics.max_memory_usage_GB = row["Memory (GB)"]
records.append(record)
BenchmarkRecord.save_as_csv(filename, records)
BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
logger.info(f"Results saved in {filename}!")
def benchmark(args, benchmark_cmd, engine, audio_file, duration):
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
log_path = os.path.join(args.log_folder, log_filename)
with open(log_path, "w") as log_file:
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
try:
process.wait(args.timeout)
except subprocess.TimeoutExpired:
process.kill()
# Create entries for csv
logger.info("Gathering data from log files...")
base_results = [
args.warmup_runs,
args.num_runs,
args.model_name,
engine,
args.precision,
args.device,
audio_file,
duration,
]
results = process_log_file(args.device_id, log_path, base_results)
return results
def main():
args = get_args()
setup_logger(args.verbose)
logger.info(args.__dict__)
torch.backends.cudnn.benchmark = True
config = WhisperConfig.from_pretrained(args.model_name)
processor = WhisperProcessor.from_pretrained(args.model_name)
# Calculate forced decoder input ids
hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
ort_forced_decoder_ids = [config.decoder_start_token_id] + [token_id[1] for token_id in hf_forced_decoder_ids]
hf_decoder_input_ids_cmd = (
["--decoder-input-ids", str(hf_forced_decoder_ids)] if args.language and args.task else []
)
ort_decoder_input_ids_cmd = (
["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else []
)
ort_tune_cmd = ["--tune"] if args.tune else []
all_results = []
for audio_file in os.listdir(args.audio_path):
audio_path = os.path.join(args.audio_path, audio_file)
try:
duration = librosa.get_duration(path=audio_path)
except Exception as e:
duration = -1
logger.warning(f"An error occurred while trying to calculate the audio duration: {e}", exc_info=True)
logger.warning(
f"If you get an error that says:\n\tsoundfile.LibsndfileError: Error opening '{audio_file}': File contains data in an unknown format.\nyou may not have installed `ffmpeg` in addition to installing `librosa`."
)
logger.info(f"Testing {audio_path}...")
# Benchmark PyTorch without torch.compile
if args.hf_pt_eager:
benchmark_cmd = [ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"hf-pt-eager",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
] + hf_decoder_input_ids_cmd
logger.info("Benchmark PyTorch without torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration)
all_results.extend(results)
# Benchmark PyTorch with torch.compile
if args.hf_pt_compile:
benchmark_cmd = [ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"hf-pt-compile",
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
] + hf_decoder_input_ids_cmd
logger.info("Benchmark PyTorch with torch.compile")
results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration)
all_results.extend(results)
# Benchmark Optimum + ONNX Runtime
if args.hf_ort_dir_path:
benchmark_cmd = [ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"hf-ort",
"--hf-ort-dir-path",
args.hf_ort_dir_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
] + hf_decoder_input_ids_cmd
logger.info("Benchmark Optimum + ONNX Runtime")
results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration)
all_results.extend(results)
# Benchmark ONNX Runtime
if args.ort_model_path:
benchmark_cmd = (
[ # noqa: RUF005
"python",
"-m",
"models.whisper.benchmark",
"--audio-path",
audio_path,
"--benchmark-type",
"ort",
"--ort-model-path",
args.ort_model_path,
"--model-name",
args.model_name,
"--precision",
args.precision,
"--device",
args.device,
"--device-id",
str(args.device_id),
"--warmup-runs",
str(args.warmup_runs),
"--num-runs",
str(args.num_runs),
"--log-folder",
args.log_folder,
]
+ ort_decoder_input_ids_cmd
+ ort_tune_cmd
)
logger.info("Benchmark ONNX Runtime")
results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration)
all_results.extend(results)
csv_file = f"{args.model_size}-{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_results, os.path.join(args.log_folder, csv_file))
if __name__ == "__main__":
main()
@@ -0,0 +1,573 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import logging
import os
import torch
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
from whisper_chain import chain_model
from whisper_encoder import WhisperEncoder
from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
from onnxruntime import quantization
logger = logging.getLogger("")
PROVIDERS = {
"cpu": "CPUExecutionProvider",
"cuda": "CUDAExecutionProvider",
"rocm": "ROCMExecutionProvider",
}
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
conversion_args = parser.add_argument_group("Conversion Process Args")
optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
quant_args = parser.add_argument_group("INT8 Quantization Args")
#################################
# Conversion options for Whisper
#################################
conversion_args.add_argument(
"-m",
"--model_name_or_path",
required=False,
default=PRETRAINED_WHISPER_MODELS[0],
type=str,
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
)
conversion_args.add_argument(
"--model_impl",
required=False,
default="hf",
choices=["hf", "openai"],
type=str,
help="Select implementation for export of encoder and decoder subgraphs",
)
conversion_args.add_argument(
"--cache_dir",
required=False,
type=str,
default=os.path.join(".", "cache_models"),
help="Directory to cache pre-trained models",
)
conversion_args.add_argument(
"--output",
required=False,
type=str,
default=os.path.join(".", "onnx_models"),
help="Output directory",
)
conversion_args.add_argument(
"-o",
"--optimize_onnx",
required=False,
action="store_true",
help="Use optimizer.py to optimize onnx model",
)
conversion_args.set_defaults(optimize_onnx=False)
conversion_args.add_argument(
"--use_gpu",
required=False,
action="store_true",
help="Use GPU for model inference",
)
conversion_args.set_defaults(use_gpu=False)
conversion_args.add_argument(
"-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
)
conversion_args.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int64 instead of int32 for input_ids and attention_mask.",
)
conversion_args.set_defaults(use_int64_inputs=False)
conversion_args.add_argument(
"-r",
"--provider",
required=False,
type=str,
default="cpu",
choices=list(PROVIDERS.keys()),
help="Provider to benchmark. Default is CPUExecutionProvider.",
)
conversion_args.add_argument(
"--verbose",
required=False,
action="store_true",
help="Enable verbose logging",
)
conversion_args.set_defaults(verbose=False)
conversion_args.add_argument(
"-e",
"--use_external_data_format",
required=False,
action="store_true",
help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
)
conversion_args.set_defaults(use_external_data_format=False)
conversion_args.add_argument(
"-w",
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing ONNX model",
)
conversion_args.set_defaults(overwrite=False)
conversion_args.add_argument(
"--separate_encoder_and_decoder_init",
required=False,
action="store_true",
help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
)
conversion_args.set_defaults(separate_encoder_and_decoder_init=False)
conversion_args.add_argument(
"--no_beam_search_op",
required=False,
action="store_true",
help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
)
conversion_args.set_defaults(no_beam_search_op=False)
conversion_args.add_argument(
"--use_decoder_masked_mha",
required=False,
action="store_true",
help="Use DecoderMaskedMultiHeadAttention kernel for improved performance. This is currently an experimental feature.",
)
conversion_args.set_defaults(use_decoder_masked_mha=False)
#############################################################
# Optional inputs for Whisper
# (listed below in the order that WhisperBeamSearch expects)
#############################################################
optional_inputs.add_argument(
"-v",
"--use_vocab_mask",
required=False,
action="store_true",
help="Use vocab_mask as an extra graph input to enable specific logits processing",
)
optional_inputs.set_defaults(use_vocab_mask=False)
optional_inputs.add_argument(
"-u",
"--use_prefix_vocab_mask",
required=False,
action="store_true",
help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
)
optional_inputs.set_defaults(use_prefix_vocab_mask=False)
optional_inputs.add_argument(
"-f",
"--use_forced_decoder_ids",
required=False,
action="store_true",
help="Use decoder_input_ids as an extra graph input to the beam search op",
)
optional_inputs.set_defaults(use_forced_decoder_ids=False)
optional_inputs.add_argument(
"-l",
"--use_logits_processor",
required=False,
action="store_true",
help="Use logits_processor as an extra graph input to enable specific logits processing",
)
optional_inputs.set_defaults(use_specific_logits_processor=False)
optional_inputs.add_argument(
"--collect_cross_qk",
required=False,
action="store_true",
help="Beam search model collect stacked cross QK.",
)
optional_inputs.set_defaults(collect_cross_qk=False)
optional_inputs.add_argument(
"--extra_decoding_ids",
required=False,
action="store_true",
help="Need extra starting decoding ids for some feature like cross qk. Default if false.",
)
optional_inputs.set_defaults(extra_decoding_ids=False)
optional_inputs.add_argument(
"-t",
"--use_temperature",
required=False,
action="store_true",
help="Use temperature as an extra graph input for the WhisperBeamSearch op",
)
optional_inputs.set_defaults(use_temperature=False)
optional_inputs.add_argument(
"--no_repeat_ngram_size",
type=int,
default=0,
help="default to 0",
)
#############################################################
# Optional outputs for Whisper
# (listed below in the order that WhisperBeamSearch expects)
#############################################################
optional_outputs.add_argument(
"--output_sequence_scores",
required=False,
action="store_true",
help="Beam search model output scores for each generated sequence.",
)
optional_outputs.set_defaults(output_sequence_scores=False)
optional_outputs.add_argument(
"--output_scores",
required=False,
action="store_true",
help="Beam search model output scores over vocab per generated token.",
)
optional_outputs.set_defaults(output_scores=False)
optional_outputs.add_argument(
"--output_cross_qk",
required=False,
action="store_true",
help="Beam search model output collected qk as output. Also hint collect_cross_qk",
)
optional_outputs.set_defaults(output_cross_qk=False)
optional_outputs.add_argument(
"--cross_qk_onnx_model",
required=False,
type=str,
default=None,
help="The model which consumes cross_qk outputs.",
)
optional_outputs.add_argument(
"--output_no_speech_probs",
required=False,
action="store_true",
help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.",
)
optional_outputs.set_defaults(output_no_speech_probs=False)
###################################
# Quantization options for Whisper
###################################
quant_args.add_argument(
"--quantize_embedding_layer",
required=False,
action="store_true",
help="Quantize MatMul, GEMM, and Gather.",
)
quant_args.set_defaults(quantize_embedding_layer=False)
quant_args.add_argument(
"--quantize_per_channel",
required=False,
action="store_true",
help="Quantize weights per each channel.",
)
quant_args.set_defaults(quantize_per_channel=False)
quant_args.add_argument(
"--quantize_reduce_range",
required=False,
action="store_true",
help="Quantize weights with 7 bits.",
)
quant_args.set_defaults(quantize_reduce_range=False)
args = parser.parse_args(argv)
# Collect cross QKs if either flag is enabled
args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk
# FP32 CPU can be supported here once the DMMHA CPU kernel bugs are fixed
args.use_decoder_masked_mha = args.use_decoder_masked_mha and args.provider == "cuda"
return args
def export_onnx_models(
model_name_or_path,
model_impl,
cache_dir,
output_dir,
use_gpu,
use_external_data_format,
optimize_onnx,
precision,
verbose,
use_forced_decoder_ids: bool = False,
merge_encoder_and_decoder_init: bool = True,
no_beam_search_op: bool = False,
use_decoder_masked_mha: bool = False,
output_qk: bool = False,
overwrite: bool = False,
use_int32_inputs: bool = True,
quantize_embedding_layer: bool = False,
quantize_per_channel: bool = False,
quantize_reduce_range: bool = False,
provider: str = "cpu",
):
device = torch.device("cuda" if use_gpu else "cpu")
models = WhisperHelper.load_model(
model_name_or_path,
model_impl,
cache_dir,
device,
torch.float16 if precision == Precision.FLOAT16 else torch.float32,
merge_encoder_and_decoder_init,
no_beam_search_op,
output_qk,
)
config = models["decoder"].config
if (not use_external_data_format) and (config.num_hidden_layers > 24):
logger.warning("You MUST pass `--use_external_data_format` because model size > 2GB")
raise Exception("Please pass `--use_external_data_format` for this model.")
output_paths = []
for name, model in models.items():
print(f"========> Handling {name} model......")
filename_suffix = "_" + name
onnx_path = WhisperHelper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix,
new_folder=False,
)
# Export to ONNX
if overwrite or not os.path.exists(onnx_path):
logger.info(f"Exporting ONNX model to {onnx_path}")
WhisperHelper.export_onnx(
model,
onnx_path,
PROVIDERS[provider],
verbose,
use_external_data_format,
use_fp16_inputs=(precision == Precision.FLOAT16),
use_int32_inputs=use_int32_inputs,
use_encoder_hidden_states=(name == "decoder_init"),
use_kv_cache_inputs=(name == "decoder"),
)
else:
logger.info(f"Skip exporting: existing ONNX model {onnx_path}")
# Optimize ONNX model
if optimize_onnx or precision != Precision.FLOAT32:
output_path = WhisperHelper.get_onnx_path(
output_dir,
model_name_or_path,
suffix=filename_suffix + "_" + str(precision),
new_folder=False,
)
if overwrite or not os.path.exists(output_path):
if optimize_onnx:
logger.info(f"Optimizing model to {output_path}")
WhisperHelper.optimize_onnx(
onnx_path,
output_path,
precision == Precision.FLOAT16,
model.config.encoder_attention_heads,
model.config.d_model,
model.config.decoder_layers,
use_external_data_format,
use_gpu=use_gpu,
provider=provider,
is_decoder=(name == "decoder"),
no_beam_search_op=no_beam_search_op,
use_decoder_masked_mha=use_decoder_masked_mha,
output_qk=output_qk,
)
# Remove old ONNX model and old data file
if os.path.exists(onnx_path):
os.remove(onnx_path)
if os.path.exists(onnx_path + ".data"):
os.remove(onnx_path + ".data")
onnx_path = output_path
if isinstance(model, WhisperEncoder):
model.verify_onnx(
onnx_path,
PROVIDERS[provider],
use_fp16_inputs=(precision == Precision.FLOAT16),
)
else:
model.verify_onnx(
onnx_path,
PROVIDERS[provider],
use_fp16_inputs=(precision == Precision.FLOAT16),
use_int32_inputs=use_int32_inputs,
)
if precision == Precision.INT8:
quantization.quantize_dynamic(
onnx_path,
output_path,
op_types_to_quantize=(
["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
),
use_external_data_format=use_external_data_format,
per_channel=quantize_per_channel,
reduce_range=quantize_reduce_range,
extra_options={"MatMulConstBOnly": True},
)
else:
logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
else:
output_path = onnx_path
output_paths.append(output_path)
return output_paths
def main(argv=None):
args = parse_arguments(argv)
setup_logger(args.verbose)
logger.info(f"Arguments:{args}")
cache_dir = args.cache_dir
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
prepare_environment(cache_dir, output_dir, args.use_gpu)
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 requires --use_gpu"
output_paths = export_onnx_models(
args.model_name_or_path,
args.model_impl,
cache_dir,
output_dir,
args.use_gpu,
args.use_external_data_format,
args.optimize_onnx,
args.precision,
args.verbose,
args.use_forced_decoder_ids,
not args.separate_encoder_and_decoder_init,
args.no_beam_search_op,
args.use_decoder_masked_mha,
args.output_cross_qk,
args.overwrite,
not args.use_int64_inputs,
args.quantize_embedding_layer,
args.quantize_per_channel,
args.quantize_reduce_range,
args.provider,
)
max_diff = 0
if not args.no_beam_search_op:
logger.info("Chaining model ... :")
args.beam_model_output_dir = WhisperHelper.get_onnx_path(
output_dir,
args.model_name_or_path,
suffix="_beamsearch",
new_folder=False,
)
for path in output_paths:
if "encoder_decoder" in path or "encoder" in path:
args.encoder_path = path
elif "decoder" in path:
args.decoder_path = path
chain_model(args)
output_paths.append(args.beam_model_output_dir)
# Check chained model
ort_session = create_onnxruntime_session(
args.beam_model_output_dir,
use_gpu=args.use_gpu,
provider=args.provider,
)
device = torch.device("cuda" if args.use_gpu else "cpu")
# Wrap parity check in try-except to allow export to continue in case this produces an error
try:
with torch.no_grad():
# Verify batched decoding with prompts for OpenAI implementation
if args.model_impl == "openai" and args.use_forced_decoder_ids:
max_diff = WhisperHelper.verify_onnx(
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
)
else:
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
if max_diff > 1e-4:
logger.warning("PyTorch and ONNX Runtime results are NOT close")
else:
logger.info("PyTorch and ONNX Runtime results are close")
except Exception as e:
logger.warning(
f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
)
# Remove extra ONNX models saved in output directory
for _file in os.listdir(output_dir):
if "_beamsearch" not in _file and "_jump_times" not in _file:
path = os.path.join(output_dir, _file)
os.remove(path)
if path in output_paths:
output_paths.remove(path)
else:
# Create ancillary JSON files for ONNX Runtime GenAI and/or Hugging Face's Optimum
WhisperHelper.save_processing(
args.model_name_or_path,
args.provider,
args.separate_encoder_and_decoder_init,
args.use_decoder_masked_mha,
args.output_cross_qk,
next(iter(filter(lambda path: "encoder" in path, output_paths))),
next(iter(filter(lambda path: "decoder" in path, output_paths))),
output_dir,
cache_dir,
)
logger.info(f"Done! Outputs: {output_paths}")
return max_diff
if __name__ == "__main__":
main()
@@ -0,0 +1,331 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import onnx
from benchmark_helper import Precision
from convert_generation import (
get_shared_initializers,
update_decoder_subgraph_output_cross_attention,
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
)
from onnx import TensorProto, helper
from transformers import WhisperConfig, WhisperTokenizer
logger = logging.getLogger(__name__)
def verify_inputs(beam_inputs, graph_inputs):
# Verify that ONNX graph's inputs match beam search op's inputs
beam_required_inputs = list(filter(lambda beam_input: beam_input, beam_inputs))
assert len(graph_inputs) == len(beam_required_inputs)
for graph_input, beam_input in zip(graph_inputs, beam_required_inputs, strict=False):
# Check if graph_input is in beam_input to handle beam_input names with the "_fp16" suffix
assert graph_input.name in beam_input
def clean_list(arr, remove_all_strings=True):
if remove_all_strings:
# Remove all empty strings in list
return list(filter(lambda elm: elm != "", arr))
# Remove empty strings at end of list
while len(arr) > 0:
if arr[-1] == "":
arr.pop()
else:
break
return arr
def chain_model(args):
# Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op
encoder_model = onnx.load_model(args.encoder_path, load_external_data=True)
encoder_model.graph.name = "encoderdecoderinit subgraph"
decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
decoder_model.graph.name = "decoder subgraph"
config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
# Create inputs/outputs for WhisperBeamSearch op
temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
beam_inputs = [
"input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
"max_length",
"min_length",
"num_beams",
"num_return_sequences",
"length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
"repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
"vocab_mask" if args.use_vocab_mask else "",
"prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
"", # attention mask
"decoder_input_ids" if args.use_forced_decoder_ids else "",
"logits_processor" if args.use_logits_processor else "",
"cross_qk_layer_head" if args.collect_cross_qk else "",
"extra_decoding_ids" if args.extra_decoding_ids else "",
temperature_name if args.use_temperature else "",
]
sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
beam_outputs = [
"sequences",
sequence_scores_name if args.output_sequence_scores else "",
scores_name if args.output_scores else "",
"cross_qk" if args.collect_cross_qk else "",
"no_speech_probs_beam" if args.output_no_speech_probs else "",
]
graph_nodes = []
if args.precision == Precision.FLOAT16:
input_features_cast_node = helper.make_node(
"Cast",
inputs=["input_features"],
outputs=["input_features_fp16"],
name="CastInputFeaturesToFp16",
to=TensorProto.FLOAT16,
)
len_pen_cast_node = helper.make_node(
"Cast",
inputs=["length_penalty"],
outputs=["length_penalty_fp16"],
name="CastLengthPenaltyToFp16",
to=TensorProto.FLOAT16,
)
rep_pen_cast_node = helper.make_node(
"Cast",
inputs=["repetition_penalty"],
outputs=["repetition_penalty_fp16"],
name="CastRepetitionPenaltyToFp16",
to=TensorProto.FLOAT16,
)
graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])
if args.use_temperature:
temp_cast_node = helper.make_node(
"Cast",
inputs=["temperature"],
outputs=["temperature_fp16"],
name="temperature_to_fp16",
to=TensorProto.FLOAT16,
)
graph_nodes.append(temp_cast_node)
if args.output_sequence_scores:
output_sequence_scores_cast_node = helper.make_node(
"Cast",
inputs=["sequence_scores_fp16"],
outputs=["sequence_scores"],
name="CastOutputSequenceScoresToFp32",
to=TensorProto.FLOAT,
)
graph_nodes.append(output_sequence_scores_cast_node)
if args.output_scores:
output_scores_cast_node = helper.make_node(
"Cast",
inputs=["scores_fp16"],
outputs=["scores"],
name="CastScoresToFp32",
to=TensorProto.FLOAT,
)
graph_nodes.append(output_scores_cast_node)
# Create WhisperBeamSearch op
beam_search_attrs = [
helper.make_attribute("eos_token_id", config.eos_token_id),
helper.make_attribute("pad_token_id", config.pad_token_id),
helper.make_attribute(
"decoder_start_token_id", config.decoder_start_token_id
), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0]
helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]),
helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]),
(
helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0])
if args.output_no_speech_probs
else ""
),
helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]),
helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]),
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
helper.make_attribute("early_stopping", True),
helper.make_attribute("model_type", 2),
helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "",
]
node = helper.make_node(
"WhisperBeamSearch",
inputs=clean_list(beam_inputs, remove_all_strings=False),
outputs=clean_list(beam_outputs, remove_all_strings=False),
name="BeamSearch",
domain="com.microsoft",
)
node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True))
# Graph inputs
input_features = helper.make_tensor_value_info(
"input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
)
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
prefix_vocab_mask = helper.make_tensor_value_info(
"prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
)
decoder_input_ids = helper.make_tensor_value_info(
"decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
)
logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2])
extra_decoding_ids = helper.make_tensor_value_info(
"extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
)
temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
graph_inputs = clean_list(
[
input_features,
max_length,
min_length,
num_beams,
num_return_sequences,
length_penalty,
repetition_penalty,
vocab_mask if args.use_vocab_mask else "",
prefix_vocab_mask if args.use_prefix_vocab_mask else "",
decoder_input_ids if args.use_forced_decoder_ids else "",
logits_processor if args.use_logits_processor else "",
cross_qk_layer_head if args.collect_cross_qk else "",
extra_decoding_ids if args.extra_decoding_ids else "",
temperature if args.use_temperature else "",
]
)
# Graph outputs
sequences = helper.make_tensor_value_info(
"sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
)
sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
cross_qk = helper.make_tensor_value_info(
"cross_qk",
TensorProto.FLOAT,
["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
)
no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])
graph_outputs = clean_list(
[
sequences,
sequence_scores if args.output_sequence_scores else "",
scores if args.output_scores else "",
cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "",
no_speech_probs if args.output_no_speech_probs else "",
]
)
# Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
if hasattr(args, "use_gpu") and args.use_gpu:
if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
else:
logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
if hasattr(args, "collect_cross_qk") and args.collect_cross_qk:
update_decoder_subgraph_output_cross_attention(decoder_model.graph)
# Initializers/opsets
# Delete shared data between decoder/encoder and move to larger graph initializers
initializers = get_shared_initializers(encoder_model, decoder_model)
node.attribute.extend(
[
helper.make_attribute("decoder", decoder_model.graph),
helper.make_attribute("encoder", encoder_model.graph),
]
)
opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
graph_nodes.append(node)
if args.output_no_speech_probs:
prob_cast_node = helper.make_node(
"Cast",
inputs=["no_speech_probs_beam"],
outputs=["no_speech_probs"],
name="no_speech_probs_cast_to_fp32",
to=TensorProto.FLOAT,
)
graph_nodes.append(prob_cast_node)
# Make graph with WhisperBeamSearch op
beam_graph = helper.make_graph(
graph_nodes,
name="WhisperBeamSearch Graph",
inputs=graph_inputs,
outputs=graph_outputs,
initializer=initializers,
)
beam_graph_input_names = [gi.name for gi in graph_inputs]
beam_graph_output_names = [go.name for go in graph_outputs]
if args.cross_qk_onnx_model:
post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True)
post_qk_graph = post_qk_model.graph
beam_graph.initializer.extend(post_qk_graph.initializer)
beam_graph.node.extend(post_qk_graph.node)
# If tensor from cross_qk_onnx_model has same name as tensor in beamsearch graph, treat them as same tensor.
# User should notice this rule when provide cross_qk_onnx_model to append to the beamsearch node.
for pgi in post_qk_graph.input:
if (
(pgi.name not in beam_graph_input_names)
and (pgi.name not in beam_graph_output_names)
and (pgi.name != "cross_qk")
):
beam_graph.input.extend([pgi])
beam_graph.output.extend(post_qk_graph.output)
# Verify graph's inputs match beam search's inputs
verify_inputs(beam_inputs, graph_inputs)
assert decoder_model.ir_version == encoder_model.ir_version
logger.info(f"Using IR version {decoder_model.ir_version} for chained model")
# Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
beam_model = helper.make_model_gen_version(
beam_graph,
producer_name="onnxruntime.transformers",
opset_imports=opset_import,
ir_version=decoder_model.ir_version,
)
# Save WhisperBeamSearch graph and external data
if os.path.isfile(args.beam_model_output_dir):
logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
if os.path.exists(args.beam_model_output_dir):
os.remove(args.beam_model_output_dir)
if os.path.exists(args.beam_model_output_dir + ".data"):
os.remove(args.beam_model_output_dir + ".data")
onnx.save(
beam_model,
args.beam_model_output_dir,
save_as_external_data=args.use_external_data_format,
all_tensors_to_one_file=True,
convert_attribute=True,
location=f"{os.path.basename(args.beam_model_output_dir)}.data",
)
try:
onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
except Exception as e:
logger.error(f"An error occurred while running the ONNX checker: {e}", exc_info=True) # noqa: G201
@@ -0,0 +1,464 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from itertools import chain
from pathlib import Path
import numpy as np
import onnx
import torch
from float16 import convert_float_to_float16
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from onnx import ModelProto, ValueInfoProto
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from transformers import WhisperConfig
from whisper_inputs import (
convert_inputs_for_ort,
get_model_dynamic_axes,
get_sample_decoder_inputs,
group_past_key_values,
)
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class WhisperDecoder(torch.nn.Module):
"""A Whisper decoder with optional past key values"""
def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
super().__init__()
self.config = config
self.device = model.device
self.model_impl = model_impl
self.no_beam_search_op = no_beam_search_op
self.decoder = None if model_impl == "openai" else model.model.decoder
self.proj_out = None if model_impl == "openai" else model.proj_out
self.model = model if model_impl == "openai" else None
self.max_source_positions = self.config.max_source_positions
self.num_heads = self.config.decoder_attention_heads
self.head_size = self.config.d_model // self.num_heads
def hf_forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
past_key_values: list[tuple[torch.Tensor]] | None = None,
):
outputs = self.decoder(
encoder_hidden_states=encoder_hidden_states,
input_ids=decoder_input_ids,
past_key_values=past_key_values,
use_cache=True,
)
logits = self.proj_out(outputs.last_hidden_state)
present_key_values = outputs.past_key_values
if past_key_values is None:
# Return present_self_* and present_cross_* for decoder-init
return logits, present_key_values
# Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
# (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
# After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1), ...,
# (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1), ...
present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(present_key_values)
# Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
return logits, present_self
def oai_forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
past_key_values: list[tuple[torch.Tensor]] | None = None,
):
past_kv_cache = {}
if past_key_values is not None:
# Convert past KV caches (BxNxSxH --> BxSxNxH --> BxSxD) for OpenAI's forward pass
self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
self_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in self_attn_kv_caches]
self_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in self_attn_kv_caches]
cross_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in cross_attn_kv_caches]
cross_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in cross_attn_kv_caches]
for idx, block in enumerate(self.model.decoder.blocks):
past_kv_cache[block.attn.key] = self_attn_kv_caches[2 * idx]
past_kv_cache[block.attn.value] = self_attn_kv_caches[2 * idx + 1]
past_kv_cache[block.cross_attn.key] = cross_attn_kv_caches[2 * idx]
past_kv_cache[block.cross_attn.value] = cross_attn_kv_caches[2 * idx + 1]
# Install OpenAI's hooks on the forward pass of each nn.Linear for key and value
# since the hooks will capture the output of the key and value MatMuls, which
# represent the current keys and values.
#
# For OpenAI's forward pass, the hook function will also perform the concat
# operation (past_kv + curr_kv --> pres_kv) if needed. However, the ONNX model
# will not contain this concat operation because the present KV caches aren't
# returned by OpenAI's forward pass.
kv_cache, hooks = self.model.install_kv_cache_hooks()
# Run forward pass
# NOTE: There is a bug with openai-whisper==20240930 with the introduction of SDPA.
# In the Whisper codebase, the following line
#
# is_causal = mask is not None and n_ctx > 1
#
# has been added where `mask` is a torch tensor. The right-hand side evaluates to `tensor(True/False)`
# but `is_causal` only accepts the boolean value. The fix is to apply `.item()` after the right-hand
# side has been evaluated. In other words, the line should be
#
# is_causal = (mask is not None and n_ctx > 1).item()
#
# instead.
logits = self.model.decoder(x=decoder_input_ids, xa=encoder_hidden_states, kv_cache=past_kv_cache)
# Re-do concat operation on self attention KV caches for ONNX export (if past self attention KV caches exist)
if past_key_values is not None:
for block in self.model.decoder.blocks:
kv_cache[block.attn.key] = torch.cat(
[past_kv_cache[block.attn.key], kv_cache[block.attn.key]], dim=1
).detach()
kv_cache[block.attn.value] = torch.cat(
[past_kv_cache[block.attn.value], kv_cache[block.attn.value]], dim=1
).detach()
present_self, present_cross = [], []
for block in self.model.decoder.blocks:
# Group self and cross values
present_self.append(kv_cache[block.attn.key])
present_self.append(kv_cache[block.attn.value])
if past_key_values is None:
# Return present_self_* and present_cross_* for decoder-init
present_cross.append(kv_cache[block.cross_attn.key])
present_cross.append(kv_cache[block.cross_attn.value])
# Convert present KV caches (BxSxD --> BxSxNxH --> BxNxSxH) after OpenAI's forward pass
present_self = [
present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
for present_kv in present_self
]
present_cross = [
present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
for present_kv in present_cross
]
# Remove OpenAI's hooks since they can persist after this function completes
for hook in hooks:
hook.remove()
if past_key_values is None:
# Return present_self_* and present_cross_* for decoder-init
present_key_values = PastKeyValuesHelper.group_by_layer(
present_self + present_cross, len(present_self) // 2
)
return logits, present_key_values
# Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
return logits, present_self
def forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
past_key_values: list[tuple[torch.Tensor]] | None = None,
):
if self.model_impl == "openai":
return self.oai_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
return self.hf_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
def input_names(self):
if self.first_pass:
input_names = ["input_ids", "encoder_hidden_states"]
else:
input_names = [
"input_ids",
"encoder_hidden_states",
*list(
chain.from_iterable(
(f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
for i in range(self.config.decoder_layers)
)
),
]
return input_names
def output_names(self):
if self.first_pass:
output_names = [
"logits",
*list(
chain.from_iterable(
(
f"present_key_self_{i}",
f"present_value_self_{i}",
f"present_key_cross_{i}",
f"present_value_cross_{i}",
)
for i in range(self.config.decoder_layers)
)
),
]
else:
output_names = [
"logits",
*list(
chain.from_iterable(
(f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
)
),
]
return output_names
def dynamic_axes(self, input_names, output_names):
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
if "input_ids" in dynamic_axes and not self.no_beam_search_op:
# Set dynamic axes for `input_ids` when using beam search op to {0: "batch_size"} only
del dynamic_axes["input_ids"][1]
return dynamic_axes
def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
inputs = get_sample_decoder_inputs(
self.config,
self.device,
batch_size=2,
past_sequence_length=(0 if self.first_pass else 6),
sequence_length=(6 if self.first_pass else 1),
use_fp16=use_fp16_inputs,
use_int32=use_int32_inputs,
)
if return_dict:
if self.first_pass:
del inputs["past_key_values"]
return inputs
if self.first_pass:
return (
inputs["decoder_input_ids"],
inputs["encoder_hidden_states"],
)
return (
inputs["decoder_input_ids"],
inputs["encoder_hidden_states"],
inputs["past_key_values"],
)
def fix_key_value_cache_dims(self, io: ValueInfoProto, is_cross: bool = False, is_output: bool = False):
# Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
# and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
num_heads = io.type.tensor_type.shape.dim[1]
if "_dim_" in num_heads.dim_param:
num_heads.Clear()
num_heads.dim_value = self.num_heads
sequence_length = io.type.tensor_type.shape.dim[2]
if "_dim_" in sequence_length.dim_param:
sequence_length.Clear()
if is_cross:
sequence_length.dim_value = self.max_source_positions
else:
sequence_length.dim_param = "total_sequence_length" if is_output else "past_sequence_length"
head_size = io.type.tensor_type.shape.dim[3]
if "_dim_" in head_size.dim_param:
head_size.Clear()
head_size.dim_value = self.head_size
return io
def fix_io(self, io_list: RepeatedCompositeFieldContainer, is_output: bool = False):
# Fix order of inputs/outputs and each dim_value of input/output
reordered_io = []
self_attn_kv_caches = []
cross_attn_kv_caches = []
for io in io_list:
if "past" not in io.name and "present" not in io.name:
reordered_io.append(io)
elif "self" in io.name:
# Self attention KV caches
new_io = self.fix_key_value_cache_dims(io, is_cross=False, is_output=is_output)
if self.no_beam_search_op:
reordered_io.append(new_io)
else:
self_attn_kv_caches.append(new_io)
else:
# Cross attention KV caches
new_io = self.fix_key_value_cache_dims(io, is_cross=True, is_output=is_output)
if self.no_beam_search_op:
reordered_io.append(new_io)
else:
cross_attn_kv_caches.append(new_io)
if not self.no_beam_search_op:
reordered_io += self_attn_kv_caches + cross_attn_kv_caches
return reordered_io
def fix_inputs_and_outputs(self, model: ModelProto):
# ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
# We now change the dim_values to the correct one.
reordered_inputs = self.fix_io(model.graph.input, is_output=False)
while len(model.graph.input) > 0:
model.graph.input.pop()
model.graph.input.extend(reordered_inputs)
reordered_outputs = self.fix_io(model.graph.output, is_output=True)
while len(model.graph.output) > 0:
model.graph.output.pop()
model.graph.output.extend(reordered_outputs)
return model
def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
if self.model_impl == "openai" and use_fp16_inputs:
# Cast ONNX model to float16 to ensure LayerNorm weights are converted from
# float32 to float16 since exported model already has float16 weights everywhere
# except for LayerNorm ops. This happens because OpenAI always upcasts to float32
# when computing LayerNorm.
#
# Reference:
# https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
model = convert_float_to_float16(model)
return model
def export_onnx(
self,
onnx_model_path: str,
provider: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_fp16_inputs: bool = False,
use_int32_inputs: bool = True,
use_encoder_hidden_states: bool = False,
use_kv_cache_inputs: bool = True,
):
"""Export decoder to ONNX
Args:
onnx_model_path (str): path to save ONNX model
provider (str): provider to use for verifying parity on ONNX model
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_fp16_inputs (bool, optional): use float16 inputs for the KV caches. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
use_encoder_hidden_states (bool, optional): use encoder_hidden_states as model input for decoder-init/decoder-without-past models. Defaults to False.
use_kv_cache_inputs (bool, optional): use KV caches as model inputs for decoder-with-past models. Defaults to True.
"""
# Shape of decoder's tensors:
# Required Inputs:
# decoder_input_ids: (batch_size, sequence_length)
# Optional Inputs:
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
# past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
# past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# Outputs:
# logits: (batch_size, sequence_length, vocab_size)
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# For the first pass through the decoder (i.e. decoder-init/decoder-without-past)
self.first_pass = use_encoder_hidden_states and not use_kv_cache_inputs
# For subsequent passes through the decoder (i.e. decoder-with-past)
self.later_pass = not use_encoder_hidden_states and use_kv_cache_inputs
assert self.first_pass or self.later_pass, (
"Only one of `use_encoder_hidden_states` and `use_kv_cache_inputs` can be true at once."
)
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
input_names = self.input_names()
output_names = self.output_names()
dynamic_axes = self.dynamic_axes(input_names, output_names)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
torch.onnx.export(
self,
args=inputs,
f=out_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17,
do_constant_folding=True,
verbose=verbose,
)
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
model = self.fix_inputs_and_outputs(model)
model = self.fix_layernorm_weights(model, use_fp16_inputs)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
)
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
def verify_onnx(
self,
onnx_model_path: str,
provider: str,
use_fp16_inputs: bool,
use_int32_inputs: bool,
):
"""Verify ONNX model outputs and PyTorch model outputs match
Args:
onnx_model_path (str): path to save ONNX model
provider (str): execution provider for ONNX model
use_fp16_inputs (bool, optional): use float16 inputs for the KV caches
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
"""
# Shape of decoder's tensors:
# Required Inputs:
# decoder_input_ids: (batch_size, sequence_length)
# Optional Inputs:
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
# past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
# past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# Outputs:
# logits: (batch_size, sequence_length, vocab_size)
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
# Run PyTorch model
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
pt_outputs = []
if self.first_pass:
out = self.forward(**inputs)
pt_outputs.append(out[0].detach().cpu().numpy())
for present_key_value_layer in out[1]:
for present_key_value in present_key_value_layer:
pt_outputs.append(present_key_value.detach().cpu().numpy())
else:
out = self.forward(**inputs)
pt_outputs.append(out[0].detach().cpu().numpy())
for present_self_key_value in out[1]:
pt_outputs.append(present_self_key_value.detach().cpu().numpy())
# Run ONNX model
sess = InferenceSession(onnx_model_path, providers=[provider])
ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
# Calculate output difference
try:
for i, output_name in enumerate(self.output_names()):
diff = np.abs(pt_outputs[i] - ort_outputs[i])
logger.warning(f"Comparing {output_name}...")
logger.warning(f"Max diff: {np.max(diff)}")
except: # noqa: E722
pass
@@ -0,0 +1,164 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from pathlib import Path
import numpy as np
import onnx
import torch
from float16 import convert_float_to_float16
from onnx import ModelProto
from onnx_model import OnnxModel
from transformers import WhisperConfig
from whisper_inputs import get_model_dynamic_axes, get_sample_encoder_inputs
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class WhisperEncoder(torch.nn.Module):
"""Whisper encoder component"""
def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str):
super().__init__()
self.config = config
self.device = model.device
self.model_impl = model_impl
self.encoder = model.encoder if model_impl == "openai" else model.model.encoder
def forward(self, audio_features: torch.Tensor):
outputs = self.encoder(audio_features)
return outputs if self.model_impl == "openai" else outputs.last_hidden_state
def input_names(self):
input_names = ["audio_features"]
return input_names
def output_names(self):
output_names = ["encoder_hidden_states"]
return output_names
def dynamic_axes(self, input_names, output_names):
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
return dynamic_axes
def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
if self.model_impl == "openai" and use_fp16_inputs:
# Cast ONNX model to float16 to ensure LayerNorm weights are converted from
# float32 to float16 since exported model already has float16 weights everywhere
# except for LayerNorm ops. This happens because OpenAI always upcasts to float32
# when computing LayerNorm.
#
# Reference:
# https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
model = convert_float_to_float16(model)
return model
def export_onnx(
self,
onnx_model_path: str,
provider: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_fp16_inputs: bool = False,
):
"""Export encoder to ONNX
Args:
onnx_model_path (str): path to save ONNX model
provider (str): provider to use for verifying parity on ONNX model
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
"""
# Shape of encoder's tensors:
# Inputs:
# audio_features: (batch_size, num_mels, num_frames)
# Outputs:
# encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
inputs = get_sample_encoder_inputs(
self.config,
self.device,
batch_size=2,
use_fp16=use_fp16_inputs,
)
input_names = self.input_names()
output_names = self.output_names()
dynamic_axes = self.dynamic_axes(input_names, output_names)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
torch.onnx.export(
self,
args=(inputs["audio_features"]),
f=out_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17,
do_constant_folding=True,
verbose=verbose,
)
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
model = self.fix_layernorm_weights(model, use_fp16_inputs)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
)
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs)
def verify_onnx(
self,
onnx_model_path: str,
provider: str,
use_fp16_inputs: bool,
):
"""Verify ONNX model outputs and PyTorch model outputs match
Args:
onnx_model_path (str): path to save ONNX model
provider (str): execution provider for ONNX model
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features
"""
# Shape of encoder's tensors:
# Inputs:
# audio_features: (batch_size, num_mels, num_frames)
# Outputs:
# encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
inputs = get_sample_encoder_inputs(
self.config,
self.device,
batch_size=2,
use_fp16=use_fp16_inputs,
)
# Run PyTorch model
pt_outputs = self.forward(inputs["audio_features"]).detach().cpu().numpy()
# Run ONNX model
sess = InferenceSession(onnx_model_path, providers=[provider])
ort_outputs = sess.run(None, {"audio_features": inputs["audio_features"].detach().cpu().numpy()})[0]
# Calculate output difference
diff = np.abs(pt_outputs - ort_outputs)
logger.warning("Comparing encoder_hidden_states...")
logger.warning(f"Max diff: {np.max(diff)}")
@@ -0,0 +1,371 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
from itertools import chain
from pathlib import Path
import numpy as np
import onnx
import torch
from float16 import convert_float_to_float16
from onnx import ModelProto, ValueInfoProto
from onnx_model import OnnxModel
from transformers import WhisperConfig
from whisper_decoder import WhisperDecoder
from whisper_encoder import WhisperEncoder
from whisper_inputs import (
convert_inputs_for_ort,
get_model_dynamic_axes,
get_sample_encoder_decoder_init_inputs,
group_past_key_values,
)
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
class WhisperEncoderDecoderInit(torch.nn.Module):
"""Whisper encoder component + first pass through Whisper decoder component to initialize KV caches"""
def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
super().__init__()
self.config = config
self.device = model.device
self.model_impl = model_impl
self.no_beam_search_op = no_beam_search_op
self.encoder = WhisperEncoder(config, model, model_impl)
self.decoder = WhisperDecoder(config, model, model_impl, no_beam_search_op)
self.max_source_positions = self.config.max_source_positions
self.num_heads = self.config.decoder_attention_heads
self.head_size = self.config.d_model // self.num_heads
def hf_forward_for_beam_search_op(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor):
encoder_hidden_states = self.encoder(audio_features)
logits, present_key_values = self.decoder(decoder_input_ids, encoder_hidden_states)
return logits, encoder_hidden_states, present_key_values
def hf_forward_for_no_beam_search_op(self, audio_features: torch.Tensor):
encoder_hidden_states = self.encoder(audio_features)
# Get cross attention KV caches and return them for this model
# We do this because these MatMuls are only run once before their outputs are being re-used in the decoder
present_cross_attention_key_value_caches = []
for layer in self.decoder.decoder.layers:
cross_attn_key_cache = (
layer.encoder_attn.k_proj(encoder_hidden_states)
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
.transpose(1, 2)
)
cross_attn_value_cache = (
layer.encoder_attn.v_proj(encoder_hidden_states)
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
.transpose(1, 2)
)
present_cross_attention_key_value_caches.append(cross_attn_key_cache)
present_cross_attention_key_value_caches.append(cross_attn_value_cache)
return encoder_hidden_states, present_cross_attention_key_value_caches
def oai_forward_for_beam_search_op(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor):
encoder_hidden_states = self.encoder(audio_features)
logits, present_key_values = self.decoder(decoder_input_ids, encoder_hidden_states)
return logits, encoder_hidden_states, present_key_values
def oai_forward_for_no_beam_search_op(self, audio_features: torch.Tensor):
encoder_hidden_states = self.encoder(audio_features)
# Get cross attention KV caches and return them for this model
# We do this because these MatMuls are only run once before their outputs are being re-used in the decoder
present_cross_attention_key_value_caches = []
for block in self.decoder.model.decoder.blocks:
cross_attn_key_cache = (
block.cross_attn.key(encoder_hidden_states)
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
.transpose(1, 2)
)
cross_attn_value_cache = (
block.cross_attn.value(encoder_hidden_states)
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
.transpose(1, 2)
)
present_cross_attention_key_value_caches.append(cross_attn_key_cache)
present_cross_attention_key_value_caches.append(cross_attn_value_cache)
return encoder_hidden_states, present_cross_attention_key_value_caches
def forward(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor | None = None):
if self.model_impl == "openai":
if self.no_beam_search_op:
return self.oai_forward_for_no_beam_search_op(audio_features)
return self.oai_forward_for_beam_search_op(audio_features, decoder_input_ids)
# Hugging Face implementation
if self.no_beam_search_op:
return self.hf_forward_for_no_beam_search_op(audio_features)
return self.hf_forward_for_beam_search_op(audio_features, decoder_input_ids)
def input_names(self):
if self.no_beam_search_op:
input_names = ["audio_features"]
else:
input_names = ["encoder_input_ids", "decoder_input_ids"]
return input_names
def output_names(self):
if self.no_beam_search_op:
output_names = [
"encoder_hidden_states",
*list(
chain.from_iterable(
(f"present_key_cross_{i}", f"present_value_cross_{i}")
for i in range(self.config.decoder_layers)
)
),
]
else:
output_names = [
"logits",
"encoder_hidden_states",
*list(
chain.from_iterable(
(
f"present_key_self_{i}",
f"present_value_self_{i}",
f"present_key_cross_{i}",
f"present_value_cross_{i}",
)
for i in range(self.config.decoder_layers)
)
),
]
return output_names
def dynamic_axes(self, input_names, output_names):
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
return dynamic_axes
def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
inputs = get_sample_encoder_decoder_init_inputs(
self.config,
self.device,
batch_size=2,
decoder_sequence_length=6,
use_fp16=use_fp16_inputs,
use_int32=use_int32_inputs,
)
if return_dict:
if self.no_beam_search_op:
del inputs["decoder_input_ids"]
return inputs
if self.no_beam_search_op:
return (inputs["audio_features"],)
return (
inputs["audio_features"],
inputs["decoder_input_ids"],
)
def fix_key_value_cache_dims(self, output: ValueInfoProto, is_cross: bool = False):
# Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
# and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
num_heads = output.type.tensor_type.shape.dim[1]
if "_dim_" in num_heads.dim_param:
num_heads.Clear()
num_heads.dim_value = self.num_heads
sequence_length = output.type.tensor_type.shape.dim[2]
if "_dim_" in sequence_length.dim_param:
sequence_length.Clear()
if is_cross:
sequence_length.dim_value = self.max_source_positions
else:
sequence_length.dim_param = "total_sequence_length"
head_size = output.type.tensor_type.shape.dim[3]
if "_dim_" in head_size.dim_param:
head_size.Clear()
head_size.dim_value = self.head_size
return output
def fix_outputs(self, model: ModelProto):
# ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
# We now change the dim_values to the correct one.
reordered_outputs = []
self_attn_kv_caches = []
cross_attn_kv_caches = []
for output in model.graph.output:
if "present" not in output.name:
reordered_outputs.append(output)
elif "self" in output.name:
# Self attention KV caches
new_output = self.fix_key_value_cache_dims(output, is_cross=False)
if self.no_beam_search_op:
reordered_outputs.append(new_output)
else:
self_attn_kv_caches.append(new_output)
else:
# Cross attention KV caches
new_output = self.fix_key_value_cache_dims(output, is_cross=True)
if self.no_beam_search_op:
reordered_outputs.append(new_output)
else:
cross_attn_kv_caches.append(new_output)
if not self.no_beam_search_op:
reordered_outputs += self_attn_kv_caches + cross_attn_kv_caches
while len(model.graph.output) > 0:
model.graph.output.pop()
model.graph.output.extend(reordered_outputs)
return model
def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
if self.model_impl == "openai" and use_fp16_inputs:
# Cast ONNX model to float16 to ensure LayerNorm weights are converted from
# float32 to float16 since exported model already has float16 weights everywhere
# except for LayerNorm ops. This happens because OpenAI always upcasts to float32
# when computing LayerNorm.
#
# Reference:
# https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
model = convert_float_to_float16(model)
return model
def export_onnx(
self,
onnx_model_path: str,
provider: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_fp16_inputs: bool = False,
use_int32_inputs: bool = True,
):
"""Export encoder-decoder-init to ONNX
Args:
onnx_model_path (str): path to save ONNX model
provider (str): provider to use for verifying parity on ONNX model
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
"""
# Shape of encoder's tensors:
# Inputs:
# audio_features: (batch_size, num_mels, num_frames)
# Outputs:
# encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
# Shape of decoder's tensors:
# Inputs:
# decoder_input_ids: (batch_size, sequence_length)
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
# Outputs:
# logits: (batch_size, sequence_length, vocab_size)
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
input_names = self.input_names()
output_names = self.output_names()
dynamic_axes = self.dynamic_axes(input_names, output_names)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
torch.onnx.export(
self,
args=inputs,
f=out_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17,
do_constant_folding=True,
verbose=verbose,
)
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
model = self.fix_outputs(model)
model = self.fix_layernorm_weights(model, use_fp16_inputs)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
)
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
def verify_onnx(
self,
onnx_model_path: str,
provider: str,
use_fp16_inputs: bool,
use_int32_inputs: bool,
):
"""Verify ONNX model outputs and PyTorch model outputs match
Args:
onnx_model_path (str): path to save ONNX model
provider (str): execution provider for ONNX model
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
"""
# Shape of encoder's tensors:
# Inputs:
# audio_features: (batch_size, num_mels, num_frames)
# Outputs:
# encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
# Shape of decoder's tensors:
# Inputs:
# decoder_input_ids: (batch_size, sequence_length)
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
# Outputs:
# logits: (batch_size, sequence_length, vocab_size)
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
# Run PyTorch model
pt_outputs = []
if self.no_beam_search_op:
out = self.forward(**inputs)
pt_outputs.append(out[0].detach().cpu().numpy())
for present_cross_attn_cache in out[1]:
pt_outputs.append(present_cross_attn_cache.detach().cpu().numpy())
else:
out = self.forward(**inputs)
pt_outputs.append(out[0].detach().cpu().numpy())
pt_outputs.append(out[1].detach().cpu().numpy())
(self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(out[2])
pt_outputs.extend([self_attn_kv_cache.detach().cpu().numpy() for self_attn_kv_cache in self_attn_kv_caches])
pt_outputs.extend(
[cross_attn_kv_cache.detach().cpu().numpy() for cross_attn_kv_cache in cross_attn_kv_caches]
)
# Run ONNX model
sess = InferenceSession(onnx_model_path, providers=[provider])
ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
# Calculate output difference
for i, output_name in enumerate(self.output_names()):
diff = np.abs(pt_outputs[i] - ort_outputs[i])
logger.warning(f"Comparing {output_name}...")
logger.warning(f"Max diff: {np.max(diff)}")
@@ -0,0 +1,380 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import numpy as np
import torch
from transformers import WhisperConfig
from onnxruntime import InferenceSession
logger = logging.getLogger(__name__)
# Create audio_features for encoder
# Shape is (batch_size, feature_size, sequence_length) = (batch_size, num_mel_filters, num_frames)
# where num_mel_filters is a model attribute and num_frames = (chunk_length * sample_rate) // hop_length.
#
# Hard-coded audio hyperparameters:
# SAMPLE_RATE = 16000
# N_FFT = 400
# HOP_LENGTH = 160
# CHUNK_LENGTH = 30 (i.e. 30-second chunk of audio)
# N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000 (i.e. 480,000 samples in a 30-second chunk of audio)
# N_FRAMES = N_SAMPLES // HOP_LENGTH = 480000 // 160 = 3000 (i.e. 3000 frames in a mel spectrogram input)
#
# N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 = 160 * 2 = 320
# FRAMES_PER_TOKEN = SAMPLE_RATE // HOP_LENGTH = 16000 // 160 = 100 (i.e. 10 ms per audio frame)
# TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN = 16000 // 320 = 50 (i.e. 20 ms per audio token)
def get_sample_audio_features(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int = 3000,
use_fp16: bool = False,
):
torch_dtype = torch.float16 if use_fp16 else torch.float32
audio_features = torch.randn(batch_size, config.num_mel_bins, sequence_length, device=device, dtype=torch_dtype)
return audio_features
# Create input_ids for decoder
# Shape is (batch_size, sequence_length) where sequence_length is the initial decoder sequence length
def get_sample_decoder_input_ids(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int,
use_int32: bool = True,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
decoder_input_ids = torch.randint(
low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=device, dtype=torch_dtype
)
return decoder_input_ids
# Create encoder_hidden_states for decoder-init
# Shape is (batch_size, num_frames // 2, hidden_size)
def get_sample_encoder_hidden_states(
config: WhisperConfig,
device: torch.device,
batch_size: int,
use_fp16: bool = False,
):
torch_dtype = torch.float16 if use_fp16 else torch.float32
encoder_hidden_states = torch.randn(
batch_size, config.max_source_positions, config.d_model, device=device, dtype=torch_dtype
)
return encoder_hidden_states
# Create past_key_values
# Self-attention KV caches are of shape (batch_size, num_heads, past_sequence_length, head_size)
# Cross-attention KV caches are of shape (batch_size, num_heads, num_frames // 2, head_size)
def get_sample_past_key_values(
config: WhisperConfig,
device: torch.device,
batch_size: int,
past_seq_len: int,
use_fp16: bool = False,
):
num_heads = config.decoder_attention_heads
head_size = config.d_model // num_heads
max_source_positions = (
config.max_source_positions
) # equal to num_frames // 2 = encoder's sequence_length // 2 = 3000 // 2 = 1500
torch_dtype = torch.float16 if use_fp16 else torch.float32
self_attention_kv_caches = [
(
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
)
for _ in range(config.decoder_layers)
]
cross_attention_kv_caches = [
(
torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
)
for _ in range(config.decoder_layers)
]
return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches)
# Flatten KV caches into pairs-of-4 where each pair is defined as:
# (self_attn_key_cache, self_attn_value_cache, cross_attn_key_cache, cross_attn_value_cache)
def flatten_past_key_values(
self_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
cross_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
):
past_key_values = []
for (self_k_cache, self_v_cache), (cross_k_cache, cross_v_cache) in zip(
self_attn_kv_caches, cross_attn_kv_caches, strict=False
):
layer_kv_caches = (self_k_cache, self_v_cache, cross_k_cache, cross_v_cache)
past_key_values.append(layer_kv_caches)
return past_key_values
# Group KV caches into two 1D lists where one list contains the self attention KV caches and
# one list contains the cross attention KV caches
def group_past_key_values(
kv_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
):
self_attn_kv_caches, cross_attn_kv_caches = [], []
for self_k_cache, self_v_cache, cross_k_cache, cross_v_cache in kv_caches:
self_attn_kv_caches.append(self_k_cache)
self_attn_kv_caches.append(self_v_cache)
cross_attn_kv_caches.append(cross_k_cache)
cross_attn_kv_caches.append(cross_v_cache)
return self_attn_kv_caches, cross_attn_kv_caches
# Create alignment heads for timestamps
# Shape is (num_alignment_heads, 2)
def get_sample_alignment_heads(
config: WhisperConfig,
device: torch.device,
num_alignment_heads: int = 6,
use_int32: bool = True,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
alignment_heads = torch.ones((num_alignment_heads, 2), device=device, dtype=torch_dtype)
return alignment_heads
# Create length of start-of-transcription sequence for timestamps
# Shape is (1)
def get_sample_sot_sequence_length(
device: torch.device,
sot_sequence_length: int,
use_int32: bool = False,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
sot_length = torch.tensor([sot_sequence_length], device=device, dtype=torch_dtype)
return sot_length
# Create segment length for timestamps
# Shape is (1)
def get_sample_segment_length(
device: torch.device,
segment_length: int,
use_int32: bool = False,
):
torch_dtype = torch.int32 if use_int32 else torch.int64
segment_size = torch.tensor([segment_length], device=device, dtype=torch_dtype)
return segment_size
# Create QKs for timestamps
# Shape is (batch_size, num_heads, sequence_length, num_frames // 2)
def get_sample_QKs( # noqa: N802
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int,
use_fp16: bool = False,
):
num_heads = config.decoder_attention_heads
torch_dtype = torch.float16 if use_fp16 else torch.float32
QKs = [ # noqa: N806
torch.rand(
batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype
)
for _ in range(config.decoder_layers)
]
return QKs
# Create inputs for encoder component of Whisper
def get_sample_encoder_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int = 3000,
use_fp16: bool = False,
):
audio_features = get_sample_audio_features(config, device, batch_size, sequence_length, use_fp16)
return {"audio_features": audio_features}
# Create inputs for encoder component + first pass through decoder component of Whisper
def get_sample_encoder_decoder_init_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
decoder_sequence_length: int,
encoder_sequence_length: int = 3000,
use_fp16: bool = False,
use_int32: bool = True,
):
audio_features = get_sample_audio_features(config, device, batch_size, encoder_sequence_length, use_fp16)
decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, decoder_sequence_length, use_int32)
return {"audio_features": audio_features, "decoder_input_ids": decoder_input_ids}
# Create inputs for decoder component of Whisper
# Inputs for first pass through the decoder (i.e. decoder-init): decoder_input_ids, encoder_hidden_states
# Inputs for subsequent passes through the decoder (i.e. decoder-with-past): decoder_input_ids, past_key_values
def get_sample_decoder_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
past_sequence_length: int,
sequence_length: int,
use_fp16: bool = False,
use_int32: bool = True,
):
decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, sequence_length, use_int32)
encoder_hidden_states = get_sample_encoder_hidden_states(config, device, batch_size, use_fp16)
past_key_values = get_sample_past_key_values(config, device, batch_size, past_sequence_length, use_fp16)
return {
"decoder_input_ids": decoder_input_ids,
"encoder_hidden_states": encoder_hidden_states,
"past_key_values": past_key_values,
}
# Create inputs for timestamps component of Whisper
def get_sample_jump_times_inputs(
config: WhisperConfig,
device: torch.device,
batch_size: int,
sequence_length: int,
num_alignment_heads: int,
sot_sequence_length: int,
segment_length: int,
use_fp16: bool = False,
use_int32: bool = True,
):
alignment_heads = get_sample_alignment_heads(config, device, num_alignment_heads, use_int32)
# lengths need to be int64 because subsequent 'Slice' ops only take int64 inputs
sot_sequence_length = get_sample_sot_sequence_length(device, sot_sequence_length)
segment_length = get_sample_segment_length(device, segment_length)
QKs = get_sample_QKs(config, device, batch_size, sequence_length, use_fp16) # noqa: N806
return {
"alignment_heads": alignment_heads,
"sot_sequence_length": sot_sequence_length,
"segment_length": segment_length,
"QKs": QKs,
}
# Convert PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
inputs: dict,
model: InferenceSession,
):
self_attn_kv_caches, cross_attn_kv_caches = None, None
batch_size, num_heads, past_seq_len, head_size = 0, 0, 0, 0
num_beams, max_seq_len = 1, 448
if "past_key_values" in inputs:
(self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(inputs["past_key_values"])
batch_size, num_heads, past_seq_len, head_size = self_attn_kv_caches[0].shape
ort_inputs = {}
model_inputs = list(map(lambda i: i.name, model.get_inputs())) # noqa: C417
use_buffer_sharing = "cache_indirection" in model_inputs
for name in model_inputs:
if name in {"audio_features", "encoder_input_ids"}:
# Encoder input
ort_inputs[name] = inputs["audio_features"].detach().cpu().numpy()
elif name == "encoder_hidden_states":
# Encoder output
ort_inputs[name] = inputs["encoder_hidden_states"].detach().cpu().numpy()
elif name in {"decoder_input_ids", "input_ids"}:
# Decoder input
ort_inputs[name] = inputs["decoder_input_ids"].detach().cpu().numpy()
elif "past_key_self" in name or "past_value_self" in name:
# Decoder input
orig_kv_cache = self_attn_kv_caches.pop(0).detach().cpu().numpy()
if use_buffer_sharing:
new_kv_cache = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=orig_kv_cache.dtype)
new_kv_cache[:batch_size, :num_heads, :past_seq_len, :head_size] = orig_kv_cache
ort_inputs[name] = new_kv_cache
else:
ort_inputs[name] = orig_kv_cache
elif "past_key_cross" in name or "past_value_cross" in name:
# Decoder input
orig_kv_cache = cross_attn_kv_caches.pop(0).detach().cpu().numpy()
ort_inputs[name] = orig_kv_cache
elif name == "past_sequence_length":
# Decoder input
ort_inputs[name] = np.array([past_seq_len], dtype=np.int32)
elif name == "cache_indirection":
# Decoder input
ort_inputs[name] = np.zeros((batch_size, num_beams, max_seq_len), dtype=np.int32)
elif name == "alignment_heads":
# Jump times input
ort_inputs[name] = inputs["alignment_heads"].detach().cpu().numpy()
elif name == "sot_sequence_length":
# Jump times input
ort_inputs[name] = inputs["sot_sequence_length"].detach().cpu().numpy()
elif name == "segment_length":
# Jump times input
ort_inputs[name] = inputs["segment_length"].detach().cpu().numpy()
elif "cross_qk" in name:
# Jump times input
ort_inputs[name] = inputs["QKs"].pop(0).detach().cpu().numpy()
else:
raise ValueError(f"Unknown name not recognized: {name}")
return ort_inputs
# Get dynamic axes for all inputs and outputs to the model
def get_model_dynamic_axes(
config: WhisperConfig,
input_names: list[str],
output_names: list[str],
):
dynamic_axes = {}
for name in input_names + output_names:
if name in {"audio_features", "encoder_input_ids"}:
# shape is (batch_size, num_mels, num_frames)
dynamic_axes[name] = {0: "batch_size"}
elif name in {"input_ids", "decoder_input_ids"}:
# shape is (batch_size, sequence_length)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif name == "alignment_heads":
# shape is (num_alignment_heads, 2)
dynamic_axes[name] = {0: "num_alignment_heads"}
elif name in {"sot_sequence_length", "segment_length"}:
# shape is (1)
pass
elif name == "logits":
# shape is (batch_size, sequence_length, vocab_size)
dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
elif name == "encoder_hidden_states":
# shape is (batch_size, num_frames // 2, hidden_size)
dynamic_axes[name] = {0: "batch_size"}
elif "past_key_self" in name or "past_value_self" in name:
# shape is (batch_size, num_heads, past_sequence_length, head_size)
dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
elif "present_key_self" in name or "present_value_self" in name:
# shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size),
# which is equal to (batch_size, num_heads, total_sequence_length, head_size)
dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
elif (
"past_key_cross" in name
or "past_value_cross" in name
or "present_key_cross" in name
or "present_value_cross" in name
):
# shape is (batch_size, num_heads, num_frames // 2, head_size)
dynamic_axes[name] = {0: "batch_size"}
elif "cross_qk" in name:
# shape is (batch_size, num_heads, source_sequence_length, target_sequence_length)
dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
elif "jump_times" in name:
# shape is (batch_size, max_length)
dynamic_axes[name] = {0: "batch_size", 1: "max_length"}
else:
raise Exception(f"Unknown input or output name found: {name}")
return dynamic_axes
@@ -0,0 +1,477 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import logging
import os
import tempfile
import textwrap
from pathlib import Path
import numpy as np
import onnx
import torch
import torch.nn.functional as F
import torch.utils.cpp_extension
from onnx_model import OnnxModel
from transformers import WhisperConfig
from whisper_inputs import convert_inputs_for_ort, get_model_dynamic_axes, get_sample_jump_times_inputs
from onnxruntime import InferenceSession
from onnxruntime.tools import pytorch_export_contrib_ops
logger = logging.getLogger(__name__)
##################################################
# Functions that have to be outside of the class
# for torch.jit.script_if_tracing to work
##################################################
@torch.jit.script_if_tracing
def index_QKs(alignment_heads: torch.Tensor, QKs: list[torch.Tensor]): # noqa: N802
"""
Compute the following to get stacked QK tensor that has been indexed for the desired attention heads:
weights = torch.stack([QKs[_l][:, _h] for _l, _h in alignment_heads], dim=1)
"""
indexed_QKs = [] # noqa: N806
for pair in alignment_heads:
# Each QK is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
# The `QKs[_l]` selects the right QK from the list of QKs
# The `QKs[_l][:, _h]` selects the right attention heads from the chosen QK. The `:` is to do this for the batch dim.
#
# PyTorch:
# QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
# QKs[_l][:, _h] is of shape (batch_size, sequence_length, num_frames // 2)
#
# ONNX:
# QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
# QKs[_l][:, _h] is of shape (batch_size, 1, sequence_length, num_frames // 2) because
# the `[:, _h]` operation maps to a Gather op and that op does not reduce dimensions
_l, _h = pair[0], pair[1]
indexed_QKs.append(QKs[_l][:, _h])
# PyTorch:
# torch.stack will return a tensor of shape (batch_size, num_alignment_heads, sequence_length, num_frames // 2).
#
# ONNX:
# torch.stack will return a tensor of shape (batch_size, num_alignment_heads, 1, sequence_length, num_frames // 2)
# because the Gather op does not reduce dimensions. To remove the unneeded dimension, torch.squeeze with a specified
# dim (dim = 2) is added. The torch.squeeze op with a specified dim only runs if the specified dim has a size of 1.
# Since the dim won't be of size 1 in the PyTorch tensor but it is of size 1 in the ONNX tensor, it will be a no-op
# in PyTorch and an op in ONNX. Thus, the Squeeze op will only affect the ONNX model.
weights = torch.stack(indexed_QKs, dim=1)
weights = torch.squeeze(weights, dim=2)
return weights
def jump_timings(text_indices, time_indices):
"""
Calculate jump times from text_indices and time_indices where
text_indices and time_indices are both 1d vectors
"""
TOKENS_PER_SECOND = 50.0 # noqa: N806
diff = text_indices[1:] - text_indices[:-1]
padding = torch.tensor([1], dtype=torch.int32)
jumps = torch.cat((padding, diff)).to(torch.bool)
jump_times = time_indices[jumps].to(torch.float) / TOKENS_PER_SECOND
return jump_times
def padded_jump_from_dtw(matrix_2d: torch.Tensor, max_length: torch.Tensor):
"""
Run Dynamic Time Warping (DTW) on batched tensor
"""
trace = torch.ops.onnxruntime.DynamicTimeWarping(matrix_2d)
text_indices = trace[0, :]
time_indices = trace[1, :]
jump_times = jump_timings(text_indices, time_indices)
return F.pad(jump_times, [0, int((max_length - jump_times.size(-1)).item())], mode="constant", value=-1.0)
@torch.jit.script_if_tracing
def batch_jump_times(matrix: torch.Tensor, max_decoded_length: torch.Tensor):
"""
Compute the following to calculate jump times for all batches:
batched_jump_times = torch.stack([self.padded_jump_from_dtw(matrix[b], max_decoded_length) for b in range(matrix.size(0))])
"""
list_of_jump_times = []
for b in range(matrix.size(0)):
jump_times = padded_jump_from_dtw(matrix[b], max_decoded_length)
list_of_jump_times.append(jump_times)
batched_jump_times = torch.stack(list_of_jump_times)
return batched_jump_times
class WhisperJumpTimes(torch.nn.Module):
"""Whisper jump times component"""
def __init__(self, config: WhisperConfig, device: torch.device, cache_dir: str | os.PathLike):
super().__init__()
self.config = config
self.device = device
self.cache_dir = cache_dir
self.filter_width = 7
self.qk_scale = 1.0
def median_filter(self, weights: torch.Tensor):
"""
Apply a median filter of width `filter_width` along the last dimension of `weights`
"""
pad_width = self.filter_width // 2
x = F.pad(weights, (pad_width, pad_width, 0, 0), mode="reflect")
x_unfolded = torch.ops.onnxruntime.UnfoldTensor(x, -1, self.filter_width, 1)
result = torch.select(x_unfolded.sort()[0], dim=-1, index=pad_width)
return result
def forward(
self,
alignment_heads: torch.Tensor,
sot_sequence_length: torch.Tensor,
segment_length: torch.Tensor,
QKs: list[torch.Tensor],
):
# Get stacked QKs tensor
weights = index_QKs(alignment_heads, QKs)
weights = weights[:, :, : segment_length // 2]
weights = weights.to(torch.float32)
weights = (weights * self.qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = self.median_filter(weights)
matrix = torch.mean(weights, 1)
matrix = -matrix[:, sot_sequence_length:-1]
max_decoded_length = torch.tensor([matrix.size(1)], dtype=torch.int64)
batched_jump_times = batch_jump_times(matrix, max_decoded_length)
return batched_jump_times
def input_names(self):
input_names = [
"alignment_heads",
"sot_sequence_length",
"segment_length",
*[f"cross_qk_{i}" for i in range(self.config.decoder_layers)],
]
return input_names
def output_names(self):
output_names = ["jump_times"]
return output_names
def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
inputs = get_sample_jump_times_inputs(
self.config,
self.device,
batch_size=2,
sequence_length=8,
num_alignment_heads=6,
sot_sequence_length=3,
segment_length=1332,
use_fp16=use_fp16_inputs,
use_int32=use_int32_inputs,
)
if return_dict:
return inputs
return (
inputs["alignment_heads"],
inputs["sot_sequence_length"],
inputs["segment_length"],
inputs["QKs"],
)
def create_torch_ops(self):
"""
1) Create UnfoldTensor and DynamicTimeWarping as torch ops
3) Provide a symbolic mapping from torch ops to ORT contrib ops
See https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html#building-with-jit-compilation
for more details on how this works.
"""
# Set torch extensions directory to cache directory
os.environ["TORCH_EXTENSIONS_DIR"] = self.cache_dir
# Try to import `ninja` pip package
try:
assert torch.utils.cpp_extension.verify_ninja_availability()
except Exception as e:
logger.error(f"An error occurred while verifying `ninja` is available: {e}", exc_info=True) # noqa: G201
install_cmd = "pip install ninja"
logger.warning(f"Could not import `ninja`. Attempting to install `ninja` via `{install_cmd}`.")
os.system(install_cmd)
# Create UnfoldTensor torch op
unfold_op_source = textwrap.dedent("""\
#include "torch/script.h"
torch::Tensor UnfoldTensor(torch::Tensor input, int64_t dim, int64_t size, int64_t step) {
return input.unfold(dim, size, step);
}
// namespace is onnxruntime
static auto registry = torch::RegisterOperators("onnxruntime::UnfoldTensor", &UnfoldTensor);
""")
torch.utils.cpp_extension.load_inline(
name="UnfoldTensor",
cpp_sources=unfold_op_source,
is_python_module=False,
verbose=True,
)
# Create DynamicTimeWarping torch op
dtw_op_source = textwrap.dedent("""\
#include "torch/script.h"
#include "torch/torch.h"
#include <stdexcept>
#include <tuple>
#include <vector>
torch::Tensor Backtrace(torch::Tensor trace) {
int64_t i = trace.size(0) - 1;
int64_t j = trace.size(1) - 1;
trace.index({0, torch::indexing::Slice()}) = 2;
trace.index({torch::indexing::Slice(), 0}) = 1;
std::vector<int32_t> result_vec;
while (i > 0 || j > 0) {
result_vec.push_back(static_cast<int32_t>(i - 1));
result_vec.push_back(static_cast<int32_t>(j - 1));
int value = trace[i][j].item<int>();
if (value == 0) {
i--;
j--;
} else if (value == 1) {
i--;
} else if (value == 2) {
j--;
} else {
throw std::runtime_error("Unexpected trace[i, j]");
}
}
// Compute result[::-1, :].T
torch::Tensor result = torch::from_blob(result_vec.data(), {static_cast<long int>(result_vec.size() / 2), 2}, torch::kInt32).clone();
torch::Tensor reversed = result.flip(0); // result[::-1, :]
torch::Tensor transposed = reversed.transpose(0, 1); // .T
return transposed;
}
torch::Tensor DynamicTimeWarping(torch::Tensor x) {
int64_t N = x.size(0);
int64_t M = x.size(1);
torch::Tensor cost = torch::full({N + 1, M + 1}, std::numeric_limits<float>::infinity(), torch::dtype(torch::kFloat32));
torch::Tensor trace = torch::full({N + 1, M + 1}, -1, torch::dtype(torch::kFloat32));
cost[0][0] = 0;
for (int j = 1; j < M + 1; j++) {
for (int i = 1; i < N + 1; i++) {
float c0 = cost[i - 1][j - 1].item<float>();
float c1 = cost[i - 1][j].item<float>();
float c2 = cost[i][j - 1].item<float>();
float c = 0;
float t = 0;
if (c0 < c1 && c0 < c2) {
c = c0;
t = 0;
} else if (c1 < c0 && c1 < c2) {
c = c1;
t = 1;
} else {
c = c2;
t = 2;
}
cost[i][j] = x[i - 1][j - 1].item<float>() + c;
trace[i][j] = t;
}
}
return Backtrace(trace);
}
// namespace is onnxruntime
static auto registry = torch::RegisterOperators("onnxruntime::DynamicTimeWarping", &DynamicTimeWarping);
""")
torch.utils.cpp_extension.load_inline(
name="DynamicTimeWarping",
cpp_sources=dtw_op_source,
is_python_module=False,
verbose=True,
)
# Create symbolic mapping from torch ops to ORT contrib ops
pytorch_export_contrib_ops.register()
def export_onnx(
self,
onnx_model_path: str,
provider: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_fp16_inputs: bool = False,
use_int32_inputs: bool = True,
):
"""Export word-level timestamps to ONNX
Args:
onnx_model_path (str): path to save ONNX model
provider (str): provider to use for verifying parity on ONNX model
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
"""
# Shape of timestamps's tensors:
# Inputs:
# alignment_heads: (num_alignment_heads, 2)
# sot_sequence_length: (1)
# segment_length: (1)
# cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
# Outputs:
# jump_times: (batch_size, max_length)
# Definitions:
# alignment_heads: the attention head indices where the Q*K values are highly correlated with word-level timestamps
# (i.e. the alignment between audio and text tokens)
# This is calculated as follows:
#
# ```
# import base64
# import gzip
# import numpy as np
# import torch
#
# # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
# _ALIGNMENT_HEADS = {
# "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
# "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
# "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
# "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
# "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
# "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
# "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
# "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
# "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
# "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
# "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
# "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
# "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
# "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
# }
#
# model_name = "large-v3-turbo"
# array = np.frombuffer(
# gzip.decompress(base64.b85decode(_ALIGNMENT_HEADS[model_name])), dtype=bool
# ).copy()
# mask = torch.from_numpy(array).reshape(
# self.dims.n_text_layer, self.dims.n_text_head
# )
# self.alignment_heads = mask.to_sparse().indices().T
# ```
#
# sot_sequence_length: the length of the start-of-transcription sequence before the first token is generated
# Typically the start-of-transcription sequence is [<|startoftranscription|>, <|language_token|>, <|task_token|>]
# so its length is 3.
#
# segment_length: the length (in frames) of the audio segment that is being transcribed
#
# cross_qk_*: the Q*K values for the cross-attention blocks in the decoder
# Every decoder layer has a self-attention block and a cross-attention block so there are `n` cross-attention blocks
# where `n` is the number of decoder layers.
#
# jump_times: the timings where jumps occur in speech
# This allows us to detect when a word began to be spoken by the speaker (start_times) and when a word was finished
# being spoken by the speaker (end_times).
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
input_names = self.input_names()
output_names = self.output_names()
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
# Create torch ops and map them to ORT contrib ops before export
self.create_torch_ops()
torch.onnx.export(
self,
args=inputs,
f=out_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=17,
do_constant_folding=True,
verbose=verbose,
custom_opsets={"com.microsoft": 1},
)
if use_external_data_format:
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
OnnxModel.save(
model,
onnx_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
)
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
def verify_onnx(
self,
onnx_model_path: str,
provider: str,
use_fp16_inputs: bool,
use_int32_inputs: bool,
):
"""Verify ONNX model outputs and PyTorch model outputs match
Args:
onnx_model_path (str): path to save ONNX model
provider (str): execution provider for ONNX model
use_fp16_inputs (bool, optional): use float16 inputs for the cross_qk_{i}
use_int32_inputs (bool, optional): use int32 inputs for the alignment_heads and sot_sequence_length
"""
# Shape of jump times's tensors:
# Inputs:
# alignment_heads: (num_alignment_heads, 2)
# sot_sequence_length: (1)
# segment_length: (1)
# cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
# Outputs:
# jump_times: (batch_size, max_length)
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
# Run PyTorch model
pt_outputs = (
self.forward(
inputs["alignment_heads"], inputs["sot_sequence_length"], inputs["segment_length"], inputs["QKs"]
)
.detach()
.cpu()
.numpy()
)
# Run ONNX model
sess = InferenceSession(onnx_model_path, providers=[provider])
ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
# Calculate output difference
diff = np.abs(pt_outputs - ort_outputs)
print("Comparing batched jump_times...", flush=True)
print(f"Max diff: {np.max(diff)}", flush=True)