switching to high quality piper tts and added label translations
+12
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys

sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
+98
@@ -0,0 +1,98 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from utils import (
|
||||
chain_enc_dec_with_beamsearch,
|
||||
export_summarization_edinit,
|
||||
export_summarization_enc_dec_past,
|
||||
onnx_inference,
|
||||
)
|
||||
|
||||
# GLOBAL ENVS
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||
stream=sys.stdout,
|
||||
)
|
||||
logger = logging.getLogger("generate")
|
||||
|
||||
|
||||
def print_args(args):
|
||||
for arg in vars(args):
|
||||
logger.info(f"{arg}: {getattr(args, arg)}")
|
||||
|
||||
|
||||
def user_command():
|
||||
parent_parser = argparse.ArgumentParser(add_help=False)
|
||||
parent_parser.add_argument("--max_length", type=int, default=20, help="default to 20")
|
||||
parent_parser.add_argument("--min_length", type=int, default=0, help="default to 0")
|
||||
parent_parser.add_argument("-o", "--output", type=str, default="onnx_models", help="default name is onnx_models.")
|
||||
parent_parser.add_argument("-i", "--input_text", type=str, default=None, help="input text")
|
||||
parent_parser.add_argument("-s", "--spm_path", type=str, default=None, help="tokenizer model from sentencepice")
|
||||
parent_parser.add_argument("-v", "--vocab_path", type=str, help="vocab dictionary")
|
||||
parent_parser.add_argument("-b", "--num_beams", type=int, default=5, help="default to 5")
|
||||
parent_parser.add_argument("--repetition_penalty", type=float, default=1.0, help="default to 1.0")
|
||||
parent_parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
|
||||
parent_parser.add_argument("--early_stopping", type=bool, default=False, help="default to False")
|
||||
parent_parser.add_argument("--opset_version", type=int, default=14, help="minimum is 14")
|
||||
|
||||
parent_parser.add_argument("--no_encoder", action="store_true")
|
||||
parent_parser.add_argument("--no_decoder", action="store_true")
|
||||
parent_parser.add_argument("--no_chain", action="store_true")
|
||||
parent_parser.add_argument("--no_inference", action="store_true")
|
||||
|
||||
required_args = parent_parser.add_argument_group("required input arguments")
|
||||
required_args.add_argument(
|
||||
"-m",
|
||||
"--model_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The directory contains input huggingface model. \
|
||||
An official model like facebook/bart-base is also acceptable.",
|
||||
)
|
||||
|
||||
args = parent_parser.parse_args()
print_args(args)
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = user_command()
|
||||
if args.opset_version < 14:
|
||||
raise ValueError(f"The minimum supported opset version is 14! The given one was {args.opset_version}.")
|
||||
|
||||
isExist = os.path.exists(args.output) # noqa: N816
|
||||
if not isExist:
|
||||
os.makedirs(args.output)
|
||||
|
||||
# beam search op only supports CPU for now
|
||||
args.device = "cpu"
|
||||
logger.info("ENV: CPU ...")
|
||||
|
||||
if not args.input_text:
|
||||
args.input_text = (
|
||||
"PG&E stated it scheduled the blackouts in response to forecasts for high winds "
|
||||
"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
|
||||
"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
|
||||
)
|
||||
|
||||
if not args.no_encoder:
|
||||
logger.info("========== EXPORTING ENCODER ==========")
|
||||
export_summarization_edinit.export_encoder(args)
|
||||
if not args.no_decoder:
|
||||
logger.info("========== EXPORTING DECODER ==========")
|
||||
export_summarization_enc_dec_past.export_decoder(args)
|
||||
if not args.no_chain:
|
||||
logger.info("========== CONVERTING MODELS ==========")
|
||||
chain_enc_dec_with_beamsearch.convert_model(args)
|
||||
if not args.no_inference:
|
||||
logger.info("========== INFERENCING WITH ONNX MODEL ==========")
|
||||
onnx_inference.run_inference(args)
|
||||
+12
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys

sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
+329
@@ -0,0 +1,329 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# This script evaluates accuracy of ONNX models for question-answering task on SQuAD data set.
|
||||
# Example to evaluate raw and optimized model for CUDA in Linux:
|
||||
# pip3 install datasets evaluate optimum transformers onnxruntime-gpu
|
||||
#
|
||||
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding
|
||||
#
|
||||
# python3 -m onnxruntime.transformers.optimizer \
|
||||
# --input ./bert-large-uncased-whole-word-masking-finetuned-squad/model.onnx \
|
||||
# --output ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
|
||||
#
|
||||
# python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding \
|
||||
# --onnx ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
|
||||
#
|
||||
# Snippet of example output on an A100:
|
||||
# {'exact': 86.65089877010406, 'f1': 92.99433524952254, 'total': 10570, 'HasAns_exact': 86.65089877010406
|
||||
# 'total_time_in_seconds': 81.69239814393222, 'samples_per_second': 129.387804008115,
|
||||
# 'latency_in_seconds': 0.007728703703304846, 'provider': 'CUDAExecutionProvider',
|
||||
# 'pretrained_model_name': 'bert-large-uncased-whole-word-masking-finetuned-squad',
|
||||
# 'batch_size': 1, 'sequence_length': 384, 'use_io_binding': True}
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import time
|
||||
|
||||
try:
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
except ImportError:
|
||||
from importlib_metadata import PackageNotFoundError, version
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from datasets import load_dataset
|
||||
from evaluate import evaluator
|
||||
from optimum.onnxruntime import ORTModelForQuestionAnswering
|
||||
from optimum.version import __version__ as optimum_version
|
||||
from packaging import version as version_check
|
||||
from transformers import AutoTokenizer, pipeline
|
||||
|
||||
if version_check.parse(optimum_version) < version_check.parse("1.13.1"):
|
||||
raise ImportError(f"Please install optimum>=1.13.1. Current version: {optimum_version}.")
|
||||
|
||||
PRETRAINED_SQUAD_MODELS = [
|
||||
"bert-large-uncased-whole-word-masking-finetuned-squad",
|
||||
"deepset/roberta-base-squad2",
|
||||
"distilbert-base-cased-distilled-squad",
|
||||
]
|
||||
|
||||
|
||||
def get_package_version(package_name: str):
|
||||
try:
|
||||
return version(package_name)
|
||||
except PackageNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def load_onnx_model(
|
||||
model_id: str, onnx_path: str | None = None, provider="CUDAExecutionProvider", use_io_binding: bool = False
|
||||
):
|
||||
"""Load onnx model given pretrained model name and optional ONNX model path. If onnx_path is None,
|
||||
the default onnx model from optimum will be used.
|
||||
|
||||
Args:
|
||||
model_id (str): pretrained model name or checkpoint path
|
||||
onnx_path (Optional[str], optional): path of onnx model to evaluate. Defaults to None.
|
||||
|
||||
Returns:
|
||||
model: ORTModel for the onnx model
|
||||
onnx_path: the path of onnx model
|
||||
"""
|
||||
|
||||
if onnx_path is None:
|
||||
# Export onnx to a sub-directory named by the model id
|
||||
model = ORTModelForQuestionAnswering.from_pretrained(
|
||||
model_id, export=True, provider=provider, use_io_binding=use_io_binding
|
||||
)
|
||||
save_onnx_dir = os.path.join(".", model_id)
|
||||
model.save_pretrained(save_onnx_dir)
|
||||
onnx_path = os.path.join(save_onnx_dir, "model.onnx")
|
||||
print("Model is exported to onnx file:", onnx_path)
|
||||
else:
|
||||
model = ORTModelForQuestionAnswering.from_pretrained(
|
||||
os.path.dirname(onnx_path),
|
||||
file_name=Path(onnx_path).name,
|
||||
provider=provider,
|
||||
use_io_binding=use_io_binding,
|
||||
# provider_options={"enable_skip_layer_norm_strict_mode": True},
|
||||
)
|
||||
|
||||
return model, onnx_path
|
||||
|
||||
|
||||
def output_details(results: list[dict[str, Any]], csv_filename: str):
|
||||
"""Output a CSV file with detail of each test results.
|
||||
|
||||
Args:
|
||||
results (List[Dict[str, Any]]): list of JSON results.
|
||||
csv_filename (str): path of output CSV file
|
||||
"""
|
||||
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
|
||||
column_names = [
|
||||
"pretrained_model_name",
|
||||
"onnx_path",
|
||||
"provider",
|
||||
"disable_fused_attention",
|
||||
"batch_size",
|
||||
"sequence_length",
|
||||
"use_io_binding",
|
||||
"exact",
|
||||
"f1",
|
||||
"total",
|
||||
"HasAns_exact",
|
||||
"HasAns_f1",
|
||||
"HasAns_total",
|
||||
"best_exact",
|
||||
"best_exact_thresh",
|
||||
"best_f1",
|
||||
"best_f1_thresh",
|
||||
"total_time_in_seconds",
|
||||
"samples_per_second",
|
||||
"latency_in_seconds",
|
||||
]
|
||||
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
|
||||
csv_writer.writeheader()
|
||||
for result in results:
|
||||
csv_writer.writerow(result)
|
||||
|
||||
csv_file.flush()
|
||||
|
||||
print(f"Detail results are saved to csv file: {csv_filename}")
|
||||
|
||||
|
||||
def output_summary(results: list[dict[str, Any]], csv_filename: str, metric_name: str):
|
||||
"""Output a CSV file with summary of a metric on combinations of batch_size and sequence_length.
|
||||
|
||||
Args:
|
||||
results (List[Dict[str, Any]]): list of JSON results.
|
||||
csv_filename (str): path of output CSV file
|
||||
metric_name (str): the metric to summarize
|
||||
"""
|
||||
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
|
||||
header_names = [
|
||||
"pretrained_model_name",
|
||||
"onnx_path",
|
||||
"provider",
|
||||
"disable_fused_attention",
|
||||
"use_io_binding",
|
||||
]
|
||||
|
||||
model_list = list({result["onnx_path"] for result in results})
|
||||
model_list.sort()
|
||||
|
||||
batch_sizes = list({result["batch_size"] for result in results})
|
||||
batch_sizes.sort()
|
||||
|
||||
sequence_lengths = list({result["sequence_length"] for result in results})
|
||||
sequence_lengths.sort()
|
||||
|
||||
key_names = []
|
||||
for sequence_length in sequence_lengths:
|
||||
for batch_size in batch_sizes:
|
||||
key_names.append(f"b{batch_size}_s{sequence_length}")
|
||||
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + key_names)
|
||||
csv_writer.writeheader()
|
||||
|
||||
for model in model_list:
|
||||
row = {}
|
||||
|
||||
# Metric value for given pair of batch_size and sequence_length.
|
||||
# Assume that (onnx_path, batch_size and sequence_length) are unique so keep first occurrence only.
|
||||
values = {}
|
||||
values.update(dict.fromkeys(key_names, ""))
|
||||
|
||||
for result in results:
|
||||
if result["onnx_path"] == model and result[metric_name]:
|
||||
headers = {k: v for k, v in result.items() if k in header_names}
|
||||
if not row:
|
||||
row.update(headers)
|
||||
|
||||
batch_size = result["batch_size"]
|
||||
sequence_length = result["sequence_length"]
|
||||
key = f"b{batch_size}_s{sequence_length}"
|
||||
|
||||
if key in key_names:
|
||||
values[key] = result[metric_name]
|
||||
|
||||
if row:
|
||||
for key in key_names:
|
||||
row[key] = values.get(key, "")
|
||||
csv_writer.writerow(row)
|
||||
|
||||
csv_file.flush()
|
||||
|
||||
print(f"Summary results for {metric_name} are saved to csv file: {csv_filename}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
print(args)
|
||||
|
||||
for name in ["onnxruntime-gpu", "onnxruntime", "onnx", "torch", "transformers", "optimum", "datasets", "evaluate"]:
|
||||
package_version = get_package_version(name)
|
||||
if package_version:
|
||||
print(f"{name} version", package_version)
|
||||
|
||||
pretrained_model_name = args.model_name
|
||||
if args.onnx and not os.path.exists(args.onnx):
|
||||
raise RuntimeError(f"Onnx model path does not exist: {args.onnx}")
|
||||
|
||||
disable_fused_attention = os.environ.get("ORT_DISABLE_FUSED_ATTENTION", "0") == "1"
|
||||
|
||||
all_results = []
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
||||
for sequence_length in args.sequence_lengths:
|
||||
tokenizer.model_max_length = sequence_length
|
||||
tokenizer.doc_stride = min(sequence_length // 2, 128)
|
||||
if args.onnx is None:
|
||||
print("Exporting onnx model. It might take a few minutes...")
|
||||
start_time = time.time()
|
||||
ort_model, onnx_path = load_onnx_model(pretrained_model_name, args.onnx, args.provider, args.use_io_binding)
|
||||
latency = time.time() - start_time
|
||||
print(f"Onnx model exported or loaded in {latency:.1f} seconds")
|
||||
|
||||
print(ort_model.config)
|
||||
if sequence_length > ort_model.config.max_position_embeddings:
|
||||
raise RuntimeError("sequence length should not be larger than {ort_model.config.max_position_embeddings}")
|
||||
|
||||
qa_pipeline = pipeline(
|
||||
"question-answering", model=ort_model, tokenizer=tokenizer, question_first=True, batch_size=args.batch_size
|
||||
)
|
||||
|
||||
task_evaluator = evaluator("question-answering")
|
||||
print("Loading dataset...")
|
||||
start_time = time.time()
|
||||
squad_dataset = load_dataset("squad", split=f"validation[:{args.total}]" if args.total > 0 else "validation")
|
||||
latency = time.time() - start_time
|
||||
print(f"Dataset loaded in {latency:.1f} seconds")
|
||||
|
||||
print("Evaluating squad_v2 with ORT. It might take a few minutes...")
|
||||
start_time = time.time()
|
||||
result = task_evaluator.compute(
|
||||
model_or_pipeline=qa_pipeline,
|
||||
data=squad_dataset,
|
||||
metric="squad_v2",
|
||||
squad_v2_format=True,
|
||||
)
|
||||
latency = time.time() - start_time
|
||||
print(f"Evaluation done in {latency:.1f} seconds")
|
||||
|
||||
result["provider"] = args.provider
|
||||
result["disable_fused_attention"] = disable_fused_attention
|
||||
result["pretrained_model_name"] = pretrained_model_name
|
||||
result["onnx_path"] = onnx_path
|
||||
result["batch_size"] = args.batch_size
|
||||
result["sequence_length"] = sequence_length
|
||||
result["use_io_binding"] = args.use_io_binding
|
||||
print(result)
|
||||
|
||||
all_results.append(result)
|
||||
|
||||
output_details(all_results, "detail.csv")
|
||||
|
||||
for metric_name in ["f1", "exact", "samples_per_second"]:
|
||||
output_summary(all_results, f"{metric_name}.csv", metric_name)
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default=PRETRAINED_SQUAD_MODELS[0],
|
||||
help=f"Checkpoint directory or pre-trained model names in the list: {PRETRAINED_SQUAD_MODELS}",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sequence_lengths",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[384],
|
||||
help="Sequence lengths for onnx model inputs. It could have multiple values.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="batch size for inference.",
|
||||
)
|
||||
|
||||
parser.add_argument("-t", "--total", type=int, default=0, help="Total samples to test. 0 means all samples.")
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional onnx model path. If not specified, optimum will be used to export onnx model for testing.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
required=False,
|
||||
default="CUDAExecutionProvider",
|
||||
help="Select which Execution Provider to use for runs. Default is CUDAExecutionProvider.",
|
||||
)
|
||||
|
||||
parser.add_argument("--use_io_binding", required=False, action="store_true", help="Use IO Binding for GPU.")
|
||||
parser.set_defaults(use_io_binding=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+12
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys

sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
+413
@@ -0,0 +1,413 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
# This script benchmarks gpt2 model with past state.
|
||||
# For gpt2 model without past state, use benchmark.py to measure performance.
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from benchmark_helper import (
|
||||
Precision,
|
||||
create_onnxruntime_session,
|
||||
get_ort_environment_variables,
|
||||
prepare_environment,
|
||||
setup_logger,
|
||||
)
|
||||
from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
|
||||
from packaging import version
|
||||
from quantize_helper import QuantizeHelper
|
||||
from transformers import AutoConfig
|
||||
from transformers import __version__ as transformers_version
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name_or_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Model path, or pretrained model name selected in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_class",
|
||||
required=False,
|
||||
type=str,
|
||||
default="GPT2LMHeadModel",
|
||||
choices=list(MODEL_CLASSES.keys()),
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "cache_models"),
|
||||
help="Directory to cache pre-trained models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "onnx_models"),
|
||||
help="Directory to store onnx models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test_times",
|
||||
required=False,
|
||||
default=100,
|
||||
type=int,
|
||||
help="Number of repeat times to get average inference latency.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--validate_onnx",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Validate ONNX model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--optimize_onnx",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use optimizer.py to optimize onnx model",
|
||||
)
|
||||
parser.set_defaults(optimize_onnx=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
type=int,
|
||||
default=0,
|
||||
required=False,
|
||||
choices=[0, 1, 2],
|
||||
help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
|
||||
"1 - decode the first token when past_sequence_length is zero; "
|
||||
"2 - decode the remaining tokens when past_sequence_length is not zero; "
|
||||
"0 - one onnx model for both stages 1 and 2. "
|
||||
"Note that we will optimize 1 and 2 differently for best performance.",
|
||||
)
|
||||
|
||||
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
|
||||
parser.set_defaults(use_gpu=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=Precision,
|
||||
default=Precision.FLOAT32,
|
||||
choices=list(Precision),
|
||||
help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization",
|
||||
)
|
||||
|
||||
parser.add_argument("--torchscript", required=False, action="store_true", help="use Torchscript")
|
||||
parser.set_defaults(torchscript=False)
|
||||
|
||||
parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1], help="batch size")
|
||||
|
||||
parser.add_argument(
|
||||
"--sequence_lengths",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[1],
|
||||
help="sequence lengths (excluding past)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--past_sequence_lengths",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[8, 16, 32, 64, 128, 256],
|
||||
help="past sequence lengths",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--result_csv",
|
||||
required=False,
|
||||
default=None,
|
||||
help="CSV file for saving summary results.",
|
||||
)
|
||||
|
||||
parser.add_argument("--thread_num", required=False, type=int, default=-1, help="Threads to use")
|
||||
|
||||
parser.add_argument("--include_copy_output_latency", required=False, action="store_true")
|
||||
parser.set_defaults(include_copy_output_latency=False)
|
||||
|
||||
parser.add_argument("--verbose", required=False, action="store_true")
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument("--output_torch_latency", required=False, action="store_true")
|
||||
parser.set_defaults(output_torch_latency=False)
|
||||
|
||||
parser.add_argument("--disable_io_binding", required=False, action="store_true")
|
||||
parser.set_defaults(disable_io_binding=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
if version.parse(transformers_version) < version.parse(
|
||||
"3.1.0"
|
||||
): # past_key_values name does not exist in 3.0.2 or older
|
||||
raise RuntimeError("This tool requires transformers 3.1.0 or later.")
|
||||
|
||||
logger.info(f"Arguments:{args}")
|
||||
if args.precision == Precision.FLOAT16:
|
||||
assert args.optimize_onnx and args.use_gpu, "fp16 requires --optimize_onnx --use_gpu"
|
||||
|
||||
if args.precision == Precision.INT8:
|
||||
assert not args.use_gpu, "quantization only supports CPU"
|
||||
|
||||
if args.stage == 1:
|
||||
assert args.past_sequence_lengths == [0], "past_sequence_lengths shall be 0 for stage==1 (init decoder)"
|
||||
|
||||
torch.set_num_threads(psutil.cpu_count(logical=True) if args.thread_num <= 0 else args.thread_num)
|
||||
print(torch.__config__.parallel_info())
|
||||
|
||||
cache_dir = args.cache_dir
|
||||
output_dir = args.onnx_dir
|
||||
prepare_environment(cache_dir, output_dir, args.use_gpu)
|
||||
|
||||
model_class = MODEL_CLASSES[args.model_class][0]
|
||||
gpt2helper = Gpt2Helper
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path, torchscript=args.torchscript, cache_dir=cache_dir)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
|
||||
|
||||
# This script does not support float16 for PyTorch.
|
||||
# if args.float16:
|
||||
# model.half()
|
||||
|
||||
device = torch.device("cuda:0" if args.use_gpu else "cpu")
|
||||
model.to(device)
|
||||
use_external_data_format = config.n_layer > 24 # TODO: find a way to check model size > 2GB
|
||||
onnx_model_paths = gpt2helper.get_onnx_paths(
|
||||
output_dir,
|
||||
args.model_name_or_path,
|
||||
args.model_class,
|
||||
has_past=True,
|
||||
new_folder=use_external_data_format,
|
||||
)
|
||||
|
||||
onnx_model_path = onnx_model_paths["raw"]
|
||||
use_padding = MODEL_CLASSES[args.model_class][2]
|
||||
gpt2helper.export_onnx(
|
||||
model,
|
||||
device,
|
||||
onnx_model_path,
|
||||
args.verbose,
|
||||
use_external_data_format,
|
||||
has_position_ids=use_padding,
|
||||
has_attention_mask=use_padding,
|
||||
)
|
||||
|
||||
if args.optimize_onnx or args.precision != Precision.FLOAT32:
|
||||
onnx_model_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
|
||||
gpt2helper.optimize_onnx(
|
||||
onnx_model_paths["raw"],
|
||||
onnx_model_path,
|
||||
args.precision == Precision.FLOAT16,
|
||||
model.config.num_attention_heads,
|
||||
model.config.hidden_size,
|
||||
use_external_data_format,
|
||||
auto_mixed_precision=True,
|
||||
stage=args.stage,
|
||||
)
|
||||
|
||||
if args.precision == Precision.INT8:
|
||||
logger.info("quantizing model...")
|
||||
QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_paths["int8"], use_external_data_format)
|
||||
model = QuantizeHelper.quantize_torch_model(model)
|
||||
logger.info("finished quantizing model")
|
||||
onnx_model_path = onnx_model_paths["int8"]
|
||||
|
||||
if args.torchscript:
|
||||
model = gpt2helper.torchscript(
|
||||
model,
|
||||
config,
|
||||
device,
|
||||
has_position_ids=use_padding,
|
||||
has_attention_mask=use_padding,
|
||||
)
|
||||
|
||||
session = create_onnxruntime_session(
|
||||
onnx_model_path,
|
||||
args.use_gpu,
|
||||
enable_all_optimization=False,
|
||||
num_threads=args.thread_num,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
if session is None:
|
||||
return
|
||||
|
||||
# Allocate output buffers for IO Binding
|
||||
max_output_shapes = gpt2helper.get_output_shapes(
|
||||
max(args.batch_sizes),
|
||||
max(args.past_sequence_lengths),
|
||||
max(args.sequence_lengths),
|
||||
config,
|
||||
args.model_class,
|
||||
)
|
||||
output_buffers = gpt2helper.get_output_buffers(max_output_shapes, device, args.precision == Precision.FLOAT16)
|
||||
|
||||
csv_filename = args.result_csv or "benchmark_result_{}.csv".format(datetime.now().strftime("%Y%m%d-%H%M%S"))
|
||||
with open(csv_filename, mode="a", newline="") as csv_file:
|
||||
column_names = [
|
||||
"model_name",
|
||||
"model_class",
|
||||
"stage",
|
||||
"environment_variables",
|
||||
"gpu",
|
||||
"precision",
|
||||
"optimizer",
|
||||
"torchscript",
|
||||
"batch_size",
|
||||
"sequence_length",
|
||||
"past_sequence_length",
|
||||
"disable_io_binding",
|
||||
"torch_latency",
|
||||
"onnxruntime_latency",
|
||||
]
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
|
||||
csv_writer.writeheader()
|
||||
|
||||
for batch_size in args.batch_sizes:
|
||||
for sequence_length in args.sequence_lengths:
|
||||
for past_sequence_length in args.past_sequence_lengths:
|
||||
assert batch_size > 0 and sequence_length > 0 and past_sequence_length >= 0
|
||||
logger.debug(
|
||||
"Running test for batch_size=%d sequence_length=%d past_sequence_length=%d ...",
|
||||
batch_size,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
)
|
||||
|
||||
dummy_inputs = gpt2helper.get_dummy_inputs(
|
||||
batch_size,
|
||||
past_sequence_length,
|
||||
sequence_length,
|
||||
config.num_attention_heads,
|
||||
config.hidden_size,
|
||||
config.n_layer,
|
||||
config.vocab_size,
|
||||
device,
|
||||
float16=(args.precision == Precision.FLOAT16),
|
||||
has_position_ids=use_padding,
|
||||
has_attention_mask=use_padding,
|
||||
)
|
||||
output_shapes = gpt2helper.get_output_shapes(
|
||||
batch_size,
|
||||
past_sequence_length,
|
||||
sequence_length,
|
||||
config,
|
||||
args.model_class,
|
||||
)
|
||||
|
||||
try:
|
||||
if args.validate_onnx or args.output_torch_latency:
|
||||
outputs, torch_latency = gpt2helper.pytorch_inference(model, dummy_inputs, args.test_times)
|
||||
|
||||
# Dump Torch output shape
|
||||
for i, value in enumerate(outputs):
|
||||
if isinstance(value, tuple):
|
||||
logger.debug(
|
||||
f"torch output {i} is tuple of size {len(value)}, shape {value[0].shape}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"torch output {i} shape {value.shape}")
|
||||
else:
|
||||
outputs = None
|
||||
torch_latency = None
|
||||
|
||||
if args.disable_io_binding:
|
||||
ort_outputs, ort_latency = gpt2helper.onnxruntime_inference(
|
||||
session, dummy_inputs, args.test_times
|
||||
)
|
||||
else:
|
||||
ort_outputs, ort_latency = gpt2helper.onnxruntime_inference_with_binded_io(
|
||||
session,
|
||||
dummy_inputs,
|
||||
output_buffers,
|
||||
output_shapes,
|
||||
args.test_times,
|
||||
return_numpy=False,
|
||||
include_copy_output_latency=args.include_copy_output_latency,
|
||||
)
|
||||
|
||||
if args.validate_onnx:
|
||||
copy_outputs = ort_outputs
|
||||
if not args.disable_io_binding:
|
||||
# Results of IO binding might be in GPU. Copy outputs to CPU for comparison.
|
||||
copy_outputs = []
|
||||
for output in ort_outputs:
|
||||
copy_outputs.append(output.cpu().numpy())
|
||||
|
||||
if gpt2helper.compare_outputs(
|
||||
outputs,
|
||||
copy_outputs,
|
||||
model_class=args.model_class,
|
||||
rtol=DEFAULT_TOLERANCE[args.precision],
|
||||
atol=DEFAULT_TOLERANCE[args.precision],
|
||||
):
|
||||
logger.info(
|
||||
f"Pytorch and ONNX Runtime outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]})."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"batch_size=%d, sequence_length=%d, past_sequence_length=%d, onnxruntime_latency=%.2f %s %s",
|
||||
batch_size,
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
ort_latency,
|
||||
"(disable_io_binding)" if args.disable_io_binding else "",
|
||||
", torch_latency={torch_latency}" if torch_latency else "",
|
||||
)
|
||||
|
||||
row = {
|
||||
"model_name": args.model_name_or_path,
|
||||
"model_class": args.model_class,
|
||||
"stage": args.stage,
|
||||
"environment_variables": get_ort_environment_variables(),
|
||||
"gpu": args.use_gpu,
|
||||
"precision": args.precision,
|
||||
"optimizer": args.optimize_onnx,
|
||||
"torchscript": args.torchscript,
|
||||
"batch_size": batch_size,
|
||||
"sequence_length": sequence_length,
|
||||
"past_sequence_length": past_sequence_length,
|
||||
"disable_io_binding": args.disable_io_binding,
|
||||
"torch_latency": f"{torch_latency:.2f}" if torch_latency else "None",
|
||||
"onnxruntime_latency": f"{ort_latency:.2f}",
|
||||
}
|
||||
csv_writer.writerow(row)
|
||||
except Exception:
|
||||
logger.error("Exception", exc_info=True) # noqa: G201
|
||||
return None
|
||||
|
||||
logger.info(f"Results are saved to file {csv_filename}")
|
||||
return csv_filename
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
setup_logger(args.verbose)
|
||||
main(args)
|
||||
+558
@@ -0,0 +1,558 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
"""
|
||||
This converts GPT2 model to onnx. Examples:
|
||||
(1) Convert pretrained model 'gpt2' to ONNX
|
||||
python convert_to_onnx.py -m gpt2 --output gpt2.onnx
|
||||
(2) Convert pretrained model 'distilgpt2' to ONNX, and use optimizer to get float16 model.
|
||||
python convert_to_onnx.py -m distilgpt2 --output distilgpt2_fp16.onnx -o -p fp16
|
||||
(3) Convert a model check point to ONNX, and run optimization and int8 quantization
|
||||
python convert_to_onnx.py -m ./my_model_checkpoint/ --output my_model_int8.onnx -o -p int8
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from benchmark_helper import (
|
||||
Precision,
|
||||
create_onnxruntime_session,
|
||||
get_ort_environment_variables,
|
||||
prepare_environment,
|
||||
setup_logger,
|
||||
)
|
||||
from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
|
||||
from gpt2_tester import Gpt2Tester
|
||||
from packaging import version
|
||||
from quantize_helper import QuantizeHelper
|
||||
from transformers import AutoConfig
|
||||
from transformers import __version__ as transformers_version
|
||||
|
||||
from onnxruntime import __version__ as ort_version
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name_or_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_class",
|
||||
required=False,
|
||||
type=str,
|
||||
default="GPT2LMHeadModel",
|
||||
choices=list(MODEL_CLASSES.keys()),
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "cache_models"),
|
||||
help="Directory to cache pre-trained models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "onnx_models"),
|
||||
help="Output directory, or model path ends with .onnx",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--optimize_onnx",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use optimizer.py to optimize onnx model",
|
||||
)
|
||||
parser.set_defaults(optimize_onnx=False)
|
||||
|
||||
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
|
||||
parser.set_defaults(use_gpu=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
required=False,
|
||||
default=None,
|
||||
choices=["dml", "rocm", "migraphx", "cuda", "tensorrt"],
|
||||
help="use dml, rocm, cuda, tensorrt or migraphx for respective backend",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance",
|
||||
required=False,
|
||||
type=float,
|
||||
default=0,
|
||||
help="the absolute and relative tolerance for parity verification",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input_test_file",
|
||||
"-i",
|
||||
required=False,
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to the file with inputs to test with",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
required=False,
|
||||
type=Precision,
|
||||
default=Precision.FLOAT32,
|
||||
choices=list(Precision),
|
||||
help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision, and int8 for quantization",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--test_cases",
|
||||
required=False,
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of test cases per run for parity",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--test_runs",
|
||||
required=False,
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of runs for parity. It is used for significance test.",
|
||||
)
|
||||
|
||||
parser.add_argument("--verbose", required=False, action="store_true")
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
|
||||
parser.set_defaults(use_external_data_format=False)
|
||||
|
||||
parser.add_argument("--overwrite", required=False, action="store_true")
|
||||
parser.set_defaults(overwrite=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_int64_inputs",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use int32 instead of int64 for input_ids, position_ids and attention_mask.",
|
||||
)
|
||||
parser.set_defaults(use_int64_inputs=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--stage",
|
||||
type=int,
|
||||
default=0,
|
||||
required=False,
|
||||
choices=[0, 1, 2],
|
||||
help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
|
||||
"1 - decode the first token when past_sequence_length is zero; "
|
||||
"2 - decode the remaining tokens when past_sequence_length is not zero; "
|
||||
"0 - one onnx model for both stages 1 and 2. "
|
||||
"Note that we will optimize 1 and 2 differently for best performance.",
|
||||
)
|
||||
|
||||
fp16_option_group = parser.add_argument_group(
|
||||
'float to float16 conversion parameters that work when "--precision fp16" is specified'
|
||||
)
|
||||
|
||||
fp16_option_group.add_argument(
|
||||
"-a",
|
||||
"--auto_mixed_precision",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Convert to mixed precision automatically. Other float16 conversion parameters will be ignored.",
|
||||
)
|
||||
fp16_option_group.set_defaults(auto_mixed_precision=False)
|
||||
|
||||
fp16_option_group.add_argument(
|
||||
"--keep_io_types",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use float32 for past inputs, present and logits outputs.",
|
||||
)
|
||||
fp16_option_group.set_defaults(keep_io_types=False)
|
||||
|
||||
fp16_option_group.add_argument(
|
||||
"--io_block_list",
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="List of inputs or outputs in float32 instead of float16",
|
||||
)
|
||||
|
||||
fp16_option_group.add_argument(
|
||||
"--op_block_list",
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="List of operators (like Add LayerNormalization SkipLayerNormalization EmbedLayerNormalization FastGelu) "
|
||||
"to compute in float32 instead of float16.",
|
||||
)
|
||||
|
||||
fp16_option_group.add_argument(
|
||||
"--node_block_list",
|
||||
nargs="+",
|
||||
default=[],
|
||||
help="List of node names to compute in float32 instead of float16.",
|
||||
)
|
||||
|
||||
fp16_option_group.add_argument(
|
||||
"--force_fp16_initializers",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Convert all float initializers to float16.",
|
||||
)
|
||||
fp16_option_group.set_defaults(force_fp16_initializers=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_onnx_model_size(onnx_path: str, use_external_data_format: bool):
|
||||
if not use_external_data_format:
|
||||
return os.path.getsize(onnx_path)
|
||||
else:
|
||||
return sum([f.stat().st_size for f in Path(onnx_path).parent.rglob("*")])
|
||||
|
||||
|
||||
def get_latency_name(batch_size, sequence_length, past_sequence_length):
|
||||
return f"average_latency(batch_size={batch_size},sequence_length={sequence_length},past_sequence_length={past_sequence_length})"
|
||||
|
||||
|
||||
def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: str = "gpt2_parity_results.csv"):
|
||||
result = {}
|
||||
if version.parse(transformers_version) < version.parse(
|
||||
"3.1.0"
|
||||
): # past_key_values name does not exist in 3.0.2 or older
|
||||
raise RuntimeError("This tool requires transformers 3.1.0 or later.")
|
||||
|
||||
args = parse_arguments(argv)
|
||||
setup_logger(args.verbose)
|
||||
|
||||
if not experiment_name:
|
||||
experiment_name = " ".join(argv if argv else sys.argv[1:])
|
||||
|
||||
if args.tolerance == 0:
|
||||
args.tolerance = DEFAULT_TOLERANCE[args.precision]
|
||||
|
||||
logger.info(f"Arguments:{args}")
|
||||
|
||||
cache_dir = args.cache_dir
|
||||
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
|
||||
prepare_environment(cache_dir, output_dir, args.use_gpu)
|
||||
|
||||
if args.precision != Precision.FLOAT32:
|
||||
assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
|
||||
|
||||
if args.precision == Precision.FLOAT16:
|
||||
assert args.use_gpu, "fp16 requires --use_gpu"
|
||||
|
||||
if args.precision == Precision.INT8:
|
||||
assert not args.use_gpu, "quantization only supports CPU"
|
||||
|
||||
model_class = MODEL_CLASSES[args.model_class][0]
|
||||
use_padding = MODEL_CLASSES[args.model_class][2]
|
||||
|
||||
gpt2helper = Gpt2Helper
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=cache_dir)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
|
||||
|
||||
device = torch.device("cuda:0" if args.use_gpu else "cpu")
|
||||
model.eval().to(device)
|
||||
|
||||
if (not args.use_external_data_format) and (config.n_layer > 24):
|
||||
logger.info("Try --use_external_data_format when model size > 2GB")
|
||||
|
||||
onnx_model_paths = gpt2helper.get_onnx_paths(
|
||||
output_dir,
|
||||
args.model_name_or_path,
|
||||
args.model_class,
|
||||
new_folder=(args.precision == Precision.INT8),
|
||||
remove_existing=["fp32", "fp16", "int8"],
|
||||
) # Do not remove raw model to save time in parity test
|
||||
|
||||
raw_onnx_model = onnx_model_paths["raw"]
|
||||
|
||||
int_data_type = torch.int64 if args.use_int64_inputs else torch.int32
|
||||
|
||||
if os.path.exists(raw_onnx_model) and not args.overwrite:
|
||||
logger.warning(f"Skip exporting ONNX model since it existed: {raw_onnx_model}")
|
||||
else:
|
||||
logger.info(f"Exporting ONNX model to {raw_onnx_model}")
|
||||
gpt2helper.export_onnx(
|
||||
model,
|
||||
device,
|
||||
raw_onnx_model,
|
||||
args.verbose,
|
||||
args.use_external_data_format,
|
||||
has_position_ids=use_padding,
|
||||
has_attention_mask=use_padding,
|
||||
input_ids_dtype=int_data_type,
|
||||
position_ids_dtype=int_data_type,
|
||||
attention_mask_dtype=int_data_type,
|
||||
)
|
||||
|
||||
fp16_params = {"keep_io_types": args.keep_io_types}
|
||||
if args.io_block_list:
|
||||
fp16_params["keep_io_types"] = args.io_block_list
|
||||
if args.node_block_list:
|
||||
fp16_params["node_block_list"] = args.node_block_list
|
||||
if args.op_block_list:
|
||||
fp16_params["op_block_list"] = args.op_block_list
|
||||
if args.force_fp16_initializers:
|
||||
fp16_params["force_fp16_initializers"] = args.force_fp16_initializers
|
||||
|
||||
is_io_float16 = args.precision == Precision.FLOAT16 and not args.keep_io_types
|
||||
|
||||
optimized_ops = ""
|
||||
all_ops = ""
|
||||
if args.optimize_onnx or args.precision != Precision.FLOAT32:
|
||||
output_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
|
||||
|
||||
logger.info(f"Optimizing model to {output_path}")
|
||||
m = gpt2helper.optimize_onnx(
|
||||
raw_onnx_model,
|
||||
output_path,
|
||||
args.precision == Precision.FLOAT16,
|
||||
model.config.num_attention_heads,
|
||||
model.config.hidden_size,
|
||||
args.use_external_data_format,
|
||||
auto_mixed_precision=args.auto_mixed_precision,
|
||||
stage=args.stage,
|
||||
**fp16_params,
|
||||
)
|
||||
|
||||
nodes = m.nodes()
|
||||
op_list = {node.op_type for node in nodes}
|
||||
all_ops = ",".join(op_list)
|
||||
|
||||
# print optimized operators
|
||||
optimized_op_counter = m.get_fused_operator_statistics()
|
||||
if optimized_op_counter:
|
||||
optimized_ops = ",".join([key for key in optimized_op_counter if optimized_op_counter[key] > 0])
|
||||
else:
|
||||
output_path = raw_onnx_model
|
||||
|
||||
if args.precision == Precision.INT8:
|
||||
logger.info("quantizing model...")
|
||||
QuantizeHelper.quantize_onnx_model(output_path, onnx_model_paths["int8"], args.use_external_data_format)
|
||||
model = QuantizeHelper.quantize_torch_model(model)
|
||||
logger.info("finished quantizing model")
|
||||
output_path = onnx_model_paths["int8"]
|
||||
|
||||
if args.output.endswith(".onnx") and output_path != args.output and not args.use_external_data_format:
|
||||
shutil.move(output_path, args.output)
|
||||
output_path = args.output
|
||||
|
||||
logger.info(f"Output path: {output_path}")
|
||||
model_size_in_MB = int(get_onnx_model_size(output_path, args.use_external_data_format) / 1024 / 1024) # noqa: N806
|
||||
|
||||
provider = args.provider
|
||||
session = create_onnxruntime_session(
|
||||
output_path, args.use_gpu, provider, enable_all_optimization=True, verbose=args.verbose
|
||||
)
|
||||
if args.model_class == "GPT2LMHeadModel" and session is not None:
|
||||
parity_result = gpt2helper.test_parity(
|
||||
session,
|
||||
model,
|
||||
device,
|
||||
is_io_float16,
|
||||
rtol=args.tolerance,
|
||||
atol=args.tolerance,
|
||||
model_class=args.model_class,
|
||||
has_position_ids=use_padding,
|
||||
has_attention_mask=use_padding,
|
||||
input_ids_dtype=int_data_type,
|
||||
position_ids_dtype=int_data_type,
|
||||
attention_mask_dtype=int_data_type,
|
||||
test_cases_per_run=args.test_cases,
|
||||
total_runs=args.test_runs,
|
||||
stage=args.stage,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
||||
# An example configuration for testing performance
|
||||
batch_size = 8
|
||||
sequence_length = 32 if args.stage == 1 else 1
|
||||
past_sequence_length = 0 if args.stage == 1 else 32
|
||||
|
||||
latency = gpt2helper.test_performance(
|
||||
session,
|
||||
model,
|
||||
device,
|
||||
is_io_float16,
|
||||
total_runs=100,
|
||||
use_io_binding=True,
|
||||
model_class=args.model_class,
|
||||
has_position_ids=use_padding,
|
||||
has_attention_mask=use_padding,
|
||||
input_ids_dtype=int_data_type,
|
||||
position_ids_dtype=int_data_type,
|
||||
attention_mask_dtype=int_data_type,
|
||||
batch_size=batch_size,
|
||||
sequence_length=sequence_length,
|
||||
past_sequence_length=past_sequence_length,
|
||||
)
|
||||
|
||||
if args.precision == Precision.FLOAT16:
|
||||
logger.info(f"fp16 conversion parameters:{fp16_params}")
|
||||
|
||||
# Write results to file
|
||||
latency_name = get_latency_name(batch_size, sequence_length, past_sequence_length)
|
||||
csv_file_existed = os.path.exists(csv_filename)
|
||||
with open(csv_filename, mode="a", newline="") as csv_file:
|
||||
column_names = [
|
||||
"experiment",
|
||||
"run_id",
|
||||
"model_name",
|
||||
"model_class",
|
||||
"stage",
|
||||
"gpu",
|
||||
"precision",
|
||||
"optimizer",
|
||||
"test_cases",
|
||||
"runs",
|
||||
"keep_io_types",
|
||||
"io_block_list",
|
||||
"op_block_list",
|
||||
"node_block_list",
|
||||
"force_fp16_initializers",
|
||||
"auto_mixed_precision",
|
||||
"optimized_operators",
|
||||
"operators",
|
||||
"environment_variables",
|
||||
"onnxruntime",
|
||||
latency_name,
|
||||
"top1_match_rate",
|
||||
"onnx_size_in_MB",
|
||||
"diff_50_percentile",
|
||||
"diff_90_percentile",
|
||||
"diff_95_percentile",
|
||||
"diff_99_percentile",
|
||||
"diff_pass_rate",
|
||||
"nan_rate",
|
||||
"top1_match_rate_per_run",
|
||||
]
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
|
||||
if not csv_file_existed:
|
||||
csv_writer.writeheader()
|
||||
row = {
|
||||
"experiment": experiment_name,
|
||||
"run_id": run_id,
|
||||
"model_name": args.model_name_or_path,
|
||||
"model_class": args.model_class,
|
||||
"stage": args.stage,
|
||||
"gpu": args.use_gpu,
|
||||
"precision": args.precision,
|
||||
"optimizer": args.optimize_onnx,
|
||||
"test_cases": args.test_cases,
|
||||
"runs": args.test_runs,
|
||||
"keep_io_types": args.keep_io_types,
|
||||
"io_block_list": args.io_block_list,
|
||||
"op_block_list": args.op_block_list,
|
||||
"node_block_list": args.node_block_list,
|
||||
"force_fp16_initializers": args.force_fp16_initializers,
|
||||
"auto_mixed_precision": args.auto_mixed_precision,
|
||||
"optimized_operators": optimized_ops,
|
||||
"operators": all_ops,
|
||||
"environment_variables": get_ort_environment_variables(),
|
||||
"onnxruntime": ort_version,
|
||||
latency_name: f"{latency:.2f}",
|
||||
"diff_50_percentile": parity_result["max_diff_percentile_50"],
|
||||
"diff_90_percentile": parity_result["max_diff_percentile_90"],
|
||||
"diff_95_percentile": parity_result["max_diff_percentile_95"],
|
||||
"diff_99_percentile": parity_result["max_diff_percentile_99"],
|
||||
"diff_pass_rate": parity_result["diff_pass_rate"],
|
||||
"nan_rate": parity_result["nan_rate"],
|
||||
"top1_match_rate": parity_result["top1_match_rate"],
|
||||
"top1_match_rate_per_run": parity_result["top1_match_rate_per_run"],
|
||||
"onnx_size_in_MB": f"{model_size_in_MB}",
|
||||
}
|
||||
logger.info(f"result: {row}")
|
||||
result.update(row)
|
||||
csv_writer.writerow(row)
|
||||
|
||||
if args.input_test_file:
|
||||
test_inputs = []
|
||||
# Each line of test file is a JSON string like:
|
||||
# {"input_ids": [[14698, 257, 1310, 13688, 319, 326]]}
|
||||
with open(args.input_test_file) as read_f:
|
||||
for _, line in enumerate(read_f):
|
||||
line = line.rstrip() # noqa: PLW2901
|
||||
data = json.loads(line)
|
||||
input_ids = torch.from_numpy(numpy.asarray(data["input_ids"], dtype=numpy.int64)).to(device)
|
||||
|
||||
if use_padding:
|
||||
if "attention_mask" in data:
|
||||
numpy_float = numpy.float16 if is_io_float16 else numpy.float32
|
||||
attention_mask = torch.from_numpy(numpy.asarray(data["attention_mask"], dtype=numpy_float)).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
padding = -1
|
||||
attention_mask = (input_ids != padding).type(torch.float16 if is_io_float16 else torch.float32)
|
||||
input_ids.masked_fill_(input_ids == padding, 0)
|
||||
|
||||
if "position_ids" in data:
|
||||
position_ids = torch.from_numpy(numpy.asarray(data["position_ids"], dtype=numpy.int64)).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(position_ids < 0, 0)
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids.to(int_data_type),
|
||||
"position_ids": position_ids.to(int_data_type),
|
||||
"attention_mask": attention_mask.to(int_data_type),
|
||||
}
|
||||
else:
|
||||
inputs = {"input_ids": input_ids.to(int_data_type)}
|
||||
|
||||
test_inputs.append(inputs)
|
||||
|
||||
Gpt2Tester.test_generation(
|
||||
session,
|
||||
model,
|
||||
device,
|
||||
test_inputs,
|
||||
precision=args.precision,
|
||||
model_class=args.model_class,
|
||||
top_k=20,
|
||||
top_k_no_order=True,
|
||||
max_steps=24,
|
||||
max_inputs=0,
|
||||
verbose=args.verbose,
|
||||
save_test_data=3,
|
||||
save_test_data_dir=Path(output_path).parent,
|
||||
)
|
||||
|
||||
logger.info(f"Done. Output model: {output_path}")
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+1031
File diff suppressed because it is too large
+513
@@ -0,0 +1,513 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# This script uses different configurations in mixed precision conversion for the GPT-2 model, and
# measures the inference latency, top 1 match rate (compared to the PyTorch FP32 model) and ONNX model size.
# It outputs a csv file with a Mann-Whitney U test and a t-test on each pair of experiments, where
# pvalue < 0.05 means the two experiments differ significantly in top 1 match rate.
# Users can use this script to select the best mixed precision model according to these metrics.
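#
# Illustrative sketch (not part of the original script; the numbers are made up): how the
# Mann-Whitney U test flags a significant difference between two experiments, using the same
# scipy call that run_significance_test below applies to the top1_match_rate_per_run lists.
#
#   import scipy.stats
#   run_a = [0.998, 0.997, 0.999, 0.998, 0.996]  # per-run top 1 match rates, experiment A
#   run_b = [0.985, 0.986, 0.984, 0.987, 0.983]  # per-run top 1 match rates, experiment B
#   stat, pvalue = scipy.stats.mannwhitneyu(run_a, run_b, use_continuity=True, alternative="two-sided")
#   if pvalue < 0.05:
#       print("experiments A and B differ significantly in top 1 match rate")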
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import onnx
|
||||
import scipy.stats
|
||||
from benchmark_helper import get_ort_environment_variables, setup_logger
|
||||
from convert_to_onnx import main
|
||||
from gpt2_helper import PRETRAINED_GPT2_MODELS, Gpt2Helper
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name_or_path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--csv",
|
||||
required=False,
|
||||
type=str,
|
||||
default="gpt2_parity_results.csv",
|
||||
help="path of csv file to save the result",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test_cases",
|
||||
required=False,
|
||||
type=int,
|
||||
default=500,
|
||||
help="number of test cases per run",
|
||||
)
|
||||
|
||||
parser.add_argument("--runs", required=False, type=int, default=40, help="number of repeated runs")
|
||||
|
||||
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
|
||||
parser.set_defaults(use_gpu=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--all",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="run all combinations of mixed precision",
|
||||
)
|
||||
parser.set_defaults(all=False)
|
||||
|
||||
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
|
||||
parser.set_defaults(use_external_data_format=False)
|
||||
|
||||
parser.add_argument("--verbose", required=False, action="store_true")
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_test",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="do not run test, and only rank experiments based on existing csv file",
|
||||
)
|
||||
parser.set_defaults(skip_test=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Overwrite existing csv file",
|
||||
)
|
||||
parser.set_defaults(overwrite=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class ParityTask:
|
||||
def __init__(self, test_cases, total_runs, csv_path):
|
||||
self.total_runs = total_runs
|
||||
self.test_cases = test_cases
|
||||
self.csv_path = csv_path
|
||||
self.results = []
|
||||
self.run_id = 0
|
||||
|
||||
def run(self, argv, experiment_name):
|
||||
start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
run_id = f"{start_time}_{self.run_id}"
|
||||
self.run_id += 1
|
||||
|
||||
try:
|
||||
result = main(
|
||||
[*argv, "-t", f"{self.test_cases}", "-r", f"{self.total_runs}"],
|
||||
experiment_name=experiment_name,
|
||||
run_id=run_id,
|
||||
csv_filename=self.csv_path,
|
||||
)
|
||||
if result:
|
||||
self.results.append(result)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to run experiment {experiment_name}")
|
||||
result = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_results_from_csv(csv_path):
|
||||
rows = []
|
||||
import csv # noqa: PLC0415
|
||||
|
||||
with open(csv_path, newline="") as csvfile:
|
||||
reader = csv.DictReader(csvfile)
|
||||
for row in reader:
|
||||
rows.append(row) # noqa: PERF402
|
||||
return rows
|
||||
|
||||
|
||||
def get_latency(row):
|
||||
for name in row:
|
||||
if name.startswith("average_latency(batch_size="):
|
||||
return float(row[name])
|
||||
|
||||
raise RuntimeError("Failed to get average_latency from output")
|
||||
|
||||
|
||||
def score(row):
|
||||
"""Scoring function based on 3 metrics. The larger score is better."""
|
||||
latency_in_ms = get_latency(row)
|
||||
top1_match_rate = float(row["top1_match_rate"])
|
||||
onnx_size_in_MB = float(row["onnx_size_in_MB"]) # noqa: N806
|
||||
# A simple scoring function: cost of 0.1ms latency ~ 0.1% match rate ~ 100MB size
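# e.g. top1_match_rate=0.995, latency=5 ms, size=500 MB -> 0.995*1000 - 5*10 - 500/100 = 995 - 50 - 5 = 940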
|
||||
return top1_match_rate * 1000 - latency_in_ms * 10 - onnx_size_in_MB / 100
|
||||
|
||||
|
||||
def print_wins(wins, rows, test_name):
|
||||
print()
|
||||
print("*" * 10)
|
||||
|
||||
row_map = {}
|
||||
for row in rows:
|
||||
row_map[row["run_id"]] = row
|
||||
|
||||
sorted_wins = dict(
|
||||
sorted(
|
||||
wins.items(),
|
||||
key=lambda item: (item[1], score(row_map[item[0]])),
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
logger.debug(f"{test_name} Wins:{sorted_wins}")
|
||||
logger.info(f"Based on {test_name} wins and a scoring function, the ranking:")
|
||||
|
||||
rank = 0
|
||||
previous_value = -1
|
||||
for count, (key, value) in enumerate(sorted_wins.items()):
|
||||
if value != previous_value:
|
||||
rank = count
|
||||
previous_value = value
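# Runs with an equal number of wins share the same rank; the score() value only breaks ties in the sort order above.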
|
||||
|
||||
for row in rows:
|
||||
if row["run_id"] == key:
|
||||
logger.info(
|
||||
"{:02d}: WINs={:02d}, run_id={}, latency={:5.2f}, top1_match={:.4f}, size={}_MB, experiment={}, {}".format( # noqa: G001
|
||||
rank,
|
||||
value,
|
||||
key,
|
||||
get_latency(row),
|
||||
float(row["top1_match_rate"]),
|
||||
row["onnx_size_in_MB"],
|
||||
row["experiment"],
|
||||
get_ort_environment_variables(),
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
def run_significance_test(rows, output_csv_path):
|
||||
"""Run U test and T test."""
|
||||
utest_wins = {}
|
||||
ttest_wins = {}
|
||||
for row in rows:
|
||||
run_id = row["run_id"]
|
||||
utest_wins[run_id] = 0
|
||||
ttest_wins[run_id] = 0
|
||||
|
||||
with open(output_csv_path, "w", newline="") as csvfile:
|
||||
column_names = [
|
||||
"model_name",
|
||||
"run_id_1",
|
||||
"experiment_1",
|
||||
"top1_match_rate_1",
|
||||
"run_id_2",
|
||||
"experiment_2",
|
||||
"top1_match_rate_2",
|
||||
"U_statistic",
|
||||
"U_pvalue",
|
||||
"T_statistic",
|
||||
"T_pvalue",
|
||||
]
|
||||
|
||||
writer = csv.DictWriter(csvfile, fieldnames=column_names)
|
||||
writer.writeheader()
|
||||
|
||||
required_match_columns = ["model_name", "test_cases", "runs"]
|
||||
num_results = len(rows)
|
||||
for i in range(num_results - 1):
|
||||
result1 = rows[i]
|
||||
|
||||
if isinstance(result1["top1_match_rate_per_run"], str):
|
||||
a = json.loads(result1["top1_match_rate_per_run"])
|
||||
else:
|
||||
a = result1["top1_match_rate_per_run"]
|
||||
|
||||
for j in range(i + 1, num_results, 1):
|
||||
result2 = rows[j]
|
||||
|
||||
all_matched = True
|
||||
for column in required_match_columns:
|
||||
if result1[column] != result2[column]:
|
||||
all_matched = False
|
||||
break
|
||||
if not all_matched:
|
||||
continue
|
||||
|
||||
if isinstance(result2["top1_match_rate_per_run"], str):
|
||||
b = json.loads(result2["top1_match_rate_per_run"])
|
||||
else:
|
||||
b = result2["top1_match_rate_per_run"]
|
||||
|
||||
try:
|
||||
utest_statistic, utest_pvalue = scipy.stats.mannwhitneyu(
|
||||
a, b, use_continuity=True, alternative="two-sided"
|
||||
) # TODO: shall we use one-sided: less or greater according to "top1_match_rate"
|
||||
except ValueError: # ValueError: All numbers are identical in mannwhitneyu
|
||||
utest_statistic = None
|
||||
utest_pvalue = None
|
||||
ttest_statistic, ttest_pvalue = scipy.stats.ttest_ind(a, b, axis=None, equal_var=True)
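# a and b hold the per-run top-1 match rates of the two experiments; a p-value below 0.05 is counted as a significant win below.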
|
||||
|
||||
if utest_pvalue is not None and utest_pvalue < 0.05:
|
||||
if float(result1["top1_match_rate"]) > float(result2["top1_match_rate"]):
|
||||
utest_wins[result1["run_id"]] += 1
|
||||
else:
|
||||
utest_wins[result2["run_id"]] += 1
|
||||
|
||||
if ttest_pvalue < 0.05:
|
||||
if float(result1["top1_match_rate"]) > float(result2["top1_match_rate"]):
|
||||
ttest_wins[result1["run_id"]] += 1
|
||||
else:
|
||||
ttest_wins[result2["run_id"]] += 1
|
||||
|
||||
row = {
|
||||
"model_name": result1["model_name"],
|
||||
"run_id_1": result1["run_id"],
|
||||
"experiment_1": result1["experiment"],
|
||||
"top1_match_rate_1": float(result1["top1_match_rate"]),
|
||||
"run_id_2": result2["run_id"],
|
||||
"experiment_2": result2["experiment"],
|
||||
"top1_match_rate_2": float(result2["top1_match_rate"]),
|
||||
"U_statistic": utest_statistic,
|
||||
"U_pvalue": utest_pvalue,
|
||||
"T_statistic": ttest_statistic,
|
||||
"T_pvalue": ttest_pvalue,
|
||||
}
|
||||
|
||||
writer.writerow(row)
|
||||
logger.info(f"U-Test and T-Test results are output to {output_csv_path}")
|
||||
print_wins(utest_wins, rows, "U-Test")
|
||||
print_wins(ttest_wins, rows, "T-Test")
|
||||
|
||||
|
||||
def get_last_matmul_node_name(raw_onnx_model: str):
|
||||
model = onnx.load(raw_onnx_model)
|
||||
onnx_model = OnnxModel(model)
|
||||
output_name_to_node = onnx_model.output_name_to_node()
|
||||
|
||||
assert model.graph.output[0].name in output_name_to_node
|
||||
node = output_name_to_node[model.graph.output[0].name]
|
||||
if node.op_type == "MatMul":
|
||||
logger.info(f"Found last MatMul node for logits: {node.name}")
|
||||
return node.name
|
||||
|
||||
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
|
||||
return None
|
||||
|
||||
|
||||
def get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list):
|
||||
model = args.model_name_or_path
|
||||
parameters = f"-m {model} -o --use_gpu -p fp16".split()
|
||||
if args.use_external_data_format:
|
||||
parameters.append("--use_external_data_format")
|
||||
parameters += [
|
||||
"--io_block_list",
|
||||
"logits",
|
||||
"--node_block_list",
|
||||
last_matmul_node_name,
|
||||
]
|
||||
|
||||
if op_block_list:
|
||||
parameters.extend(["--op_block_list", *op_block_list])
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def run_candidate(
|
||||
task: ParityTask,
|
||||
args,
|
||||
last_matmul_node_name,
|
||||
op_block_list=["FastGelu", "LayerNormalization"], # noqa: B006
|
||||
):
|
||||
parameters = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list)
|
||||
op_block_list_str = ",".join(sorted(op_block_list))
|
||||
|
||||
if op_block_list:
|
||||
name = f"Mixed precision baseline + {op_block_list_str} in FP32"
|
||||
else:
|
||||
name = f"Mixed precision baseline (logits output and last MatMul node {last_matmul_node_name} in FP32)"
|
||||
|
||||
env_vars = get_ort_environment_variables()
|
||||
if env_vars:
|
||||
name = name + f" ({env_vars})"
|
||||
|
||||
task.run(parameters, name)
|
||||
|
||||
|
||||
def get_baselines(args):
|
||||
model = args.model_name_or_path
|
||||
fp32_baseline = f"-m {model} -o -p fp32".split()
|
||||
if args.use_gpu:
|
||||
fp32_baseline.append("--use_gpu")
|
||||
if args.use_external_data_format:
|
||||
fp32_baseline.append("--use_external_data_format")
|
||||
|
||||
fp16_baseline = f"-m {model} -o --use_gpu -p fp16".split()
|
||||
if args.use_external_data_format:
|
||||
fp16_baseline.append("--use_external_data_format")
|
||||
|
||||
return fp32_baseline, fp16_baseline
|
||||
|
||||
|
||||
def run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops):
|
||||
"""Step 0 is to check which operator in FP16 causes most loss"""
|
||||
fp32_logits = ["--io_block_list", "logits"]
|
||||
task.run(fp16_baseline + fp32_logits, "FP16 except logits")
|
||||
|
||||
fp32_io = ["--keep_io_types"]
|
||||
task.run(fp16_baseline + fp32_io, "Graph I/O FP32, Other FP16")
|
||||
|
||||
# Only weights in FP16
|
||||
task.run(
|
||||
fp16_baseline + fp32_io + ["--op_block_list"] + list(all_ops) + ["--force_fp16_initializers"],
|
||||
"FP32 except weights in FP16",
|
||||
)
|
||||
|
||||
optimized_ops_results = []
|
||||
op_list = optimized_ops
|
||||
for op in op_list:
|
||||
op_block_list = ["--op_block_list"] + [o for o in op_list if o != op]
|
||||
result = task.run(fp16_baseline + fp32_io + op_block_list, f"FP32 except {op} in FP16")
|
||||
if result:
|
||||
optimized_ops_results.append(result)
|
||||
|
||||
# Check which optimized operator causes the most loss in precision
|
||||
min_result = min(optimized_ops_results, key=lambda y: y["top1_match_rate"])
|
||||
print("step 0: optimized operator causes the most loss in precision", min_result)
|
||||
|
||||
|
||||
def run_tuning_step1(task, mixed_precision_baseline, optimized_ops):
|
||||
"""Step 1 is to figure out which optimized operator in FP32 could benefit most"""
|
||||
for op in optimized_ops:
|
||||
op_block_list = ["--op_block_list", op]
|
||||
task.run(
|
||||
mixed_precision_baseline + op_block_list,
|
||||
f"Mixed precision baseline + {op} in FP32",
|
||||
)
|
||||
|
||||
|
||||
def run_tuning_step2(task, mixed_precision_baseline, optimized_ops):
|
||||
"""Assumed that you have run step 0 and 1 to figure out that Logits FP32 and some operators shall be in FP32,
|
||||
This step tries adding one more operator to the FP32 op block list.
|
||||
"""
|
||||
candidate_fp32_ops = ["FastGelu", "LayerNormalization", "SkipLayerNormalization"]
|
||||
fp32_ops = [x for x in candidate_fp32_ops if x in optimized_ops]
|
||||
for op in optimized_ops:
|
||||
if op not in fp32_ops:
|
||||
op_block_list = [*fp32_ops, op]
|
||||
task.run(
|
||||
[*mixed_precision_baseline, "--op_block_list", *op_block_list],
|
||||
"Mixed precision baseline + {},{} in FP32".format(",".join(fp32_ops), op),
|
||||
)
|
||||
|
||||
|
||||
def run_parity(task: ParityTask, args):
|
||||
onnx_model_paths = Gpt2Helper.get_onnx_paths(
|
||||
"onnx_models",
|
||||
args.model_name_or_path,
|
||||
new_folder=args.use_external_data_format,
|
||||
remove_existing=[],
|
||||
)
|
||||
|
||||
fp32_baseline, fp16_baseline = get_baselines(args)
|
||||
|
||||
result = task.run(fp32_baseline, "FP32 baseline")
|
||||
|
||||
optimized_ops = []
|
||||
if result and ("optimized_operators" in result) and result["optimized_operators"]:
|
||||
optimized_ops = result["optimized_operators"].split(",")
|
||||
else:
|
||||
raise RuntimeError("Failed to get optimized operators")
|
||||
|
||||
all_ops = []
|
||||
if result and ("operators" in result) and result["operators"]:
|
||||
all_ops = result["operators"].split(",")
|
||||
else:
|
||||
raise RuntimeError("Failed to get operators")
|
||||
|
||||
# The following fp16 tests require a GPU
|
||||
if not args.use_gpu:
|
||||
logger.info("skip mixed precision since --use_gpu is not specified")
|
||||
return
|
||||
|
||||
task.run(fp16_baseline, "FP16 baseline")
|
||||
|
||||
last_matmul_node_name = get_last_matmul_node_name(onnx_model_paths["raw"])
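# The last MatMul produces the logits, so keeping it in FP32 is part of every mixed precision candidate below.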
|
||||
|
||||
# Mixed precision baseline
|
||||
run_candidate(task, args, last_matmul_node_name, op_block_list=[])
|
||||
|
||||
def get_fp32_ops(x):
|
||||
return [op for op in x if op in all_ops]
|
||||
|
||||
if args.all:
|
||||
run_tuning_step0(task, fp16_baseline, all_ops, optimized_ops)
|
||||
mixed_precision_baseline = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list=[])
|
||||
run_tuning_step1(task, mixed_precision_baseline, optimized_ops)
|
||||
run_tuning_step2(task, mixed_precision_baseline, optimized_ops)
|
||||
else:
|
||||
run_candidate(
|
||||
task,
|
||||
args,
|
||||
last_matmul_node_name,
|
||||
op_block_list=get_fp32_ops(["SkipLayerNormalization", "LayerNormalization", "Add"]),
|
||||
)
|
||||
run_candidate(task, args, last_matmul_node_name, op_block_list=["FastGelu"])
|
||||
|
||||
# Run a few good candidates
|
||||
run_candidate(
|
||||
task,
|
||||
args,
|
||||
last_matmul_node_name,
|
||||
op_block_list=get_fp32_ops(["FastGelu", "SkipLayerNormalization", "LayerNormalization", "Add"]),
|
||||
)
|
||||
run_candidate(
|
||||
task,
|
||||
args,
|
||||
last_matmul_node_name,
|
||||
op_block_list=get_fp32_ops(
|
||||
["FastGelu", "EmbedLayerNormalization", "SkipLayerNormalization", "LayerNormalization", "Add"]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
setup_logger(args.verbose)
|
||||
|
||||
if args.test_cases < 100 or args.runs < 20 or args.test_cases * args.runs < 10000:
|
||||
logger.warning(
|
||||
"Not enough test cases or runs to get stable results or test significance. "
|
||||
"Recommend test_cases >= 100, runs >= 20, test_cases * runs >= 10000."
|
||||
)
|
||||
|
||||
if os.path.exists(args.csv) and not args.skip_test:
|
||||
if not args.overwrite:
|
||||
raise RuntimeError(
|
||||
f"Output file {args.csv} existed. Please remove the file, or use either --skip_test or --overwrite."
|
||||
)
|
||||
else:
|
||||
logger.info("Remove existing file %s since --overwrite is specified", args.csv)
|
||||
os.remove(args.csv)
|
||||
|
||||
task = ParityTask(args.test_cases, args.runs, args.csv)
|
||||
|
||||
if not args.skip_test:
|
||||
run_parity(task, args)
|
||||
|
||||
try:
|
||||
rows = load_results_from_csv(task.csv_path)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to load csv {task.csv_path}")
|
||||
rows = task.results
|
||||
|
||||
logger.info("Start running significance tests...")
|
||||
summary_csv = task.csv_path.replace(".csv", ".stats.csv")
|
||||
run_significance_test(rows, summary_csv)
|
||||
+501
@@ -0,0 +1,501 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
# This script helps evaluate the GPT-2 model.
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import statistics
|
||||
import timeit
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from benchmark_helper import Precision
|
||||
from gpt2_helper import Gpt2Helper, Gpt2Inputs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Gpt2Metric:
|
||||
def __init__(self, treatment_name, baseline_name="Torch", top_k=20):
|
||||
assert top_k > 1 and top_k <= 100
|
||||
self.baseline = baseline_name
|
||||
self.treatment = treatment_name
|
||||
self.name: str = f"{treatment_name} vs {baseline_name}"
|
||||
self.top_k = top_k
|
||||
self.top_1_error: int = 0
|
||||
self.top_k_error: int = 0
|
||||
self.total_samples: int = 0
|
||||
self.max_logits_diff: float = 0 # for non-empty past state
|
||||
self.max_logits_diff_no_past: float = 0 # for empty past state
|
||||
self.batch_top1_error: torch.FloatTensor = None # top 1 error for current batch
|
||||
self.batch_topk_error: torch.FloatTensor = None # top k error for current batch
|
||||
self.seq_len_latency = {}
|
||||
|
||||
def print(self):
|
||||
if self.baseline != self.treatment:
|
||||
print("---")
|
||||
print(f"Metrics for {self.treatment} (baseline={self.baseline}):")
|
||||
if self.total_samples > 0:
|
||||
top_1_error_rate = 100.0 * self.top_1_error / self.total_samples
|
||||
top_k_error_rate = 100.0 * self.top_k_error / self.total_samples
|
||||
print(
|
||||
f"Total={self.total_samples} Top1Error={self.top_1_error} ({top_1_error_rate:.2f}%) Top{self.top_k}Error={self.top_k_error} ({top_k_error_rate:.2f}%)"
|
||||
)
|
||||
print("Max logits diffs:")
|
||||
print(f"\twith past = {self.max_logits_diff:.6f}")
|
||||
print(f"\tempty past = {self.max_logits_diff_no_past:.6f}")
|
||||
else:
|
||||
print(f"Metrics for {self.treatment} (baseline):")
|
||||
|
||||
if self.seq_len_latency:
|
||||
print("Past sequence length range and average latency:")
|
||||
total = 0
|
||||
count = 0
|
||||
for key in sorted(self.seq_len_latency.keys()):
|
||||
average = statistics.mean(self.seq_len_latency[key]) * 1000.0
|
||||
if key == 0:
|
||||
print(f"\t{key}: \t{average:.2f} ms")
|
||||
else:
|
||||
print(f"\t[{2**key}, {2 ** (key + 1) - 1}]:\t{average:.2f} ms")
|
||||
total += average * len(self.seq_len_latency[key])
|
||||
count += len(self.seq_len_latency[key])
|
||||
print(f"Average Latency: {total / count:.2f} ms")
|
||||
|
||||
def diff_logits(self, baseline_logits, treatment_logits, is_empty_past: bool):
|
||||
diff = (baseline_logits - treatment_logits).abs().max()
|
||||
if is_empty_past:
|
||||
self.max_logits_diff_no_past = max(self.max_logits_diff_no_past, diff)
|
||||
else:
|
||||
self.max_logits_diff = max(self.max_logits_diff, diff)
|
||||
|
||||
return diff
|
||||
|
||||
def start_batch(self, batch_size: int):
|
||||
self.total_samples += batch_size
|
||||
self.batch_top1_error = torch.zeros((batch_size, 1), dtype=torch.bool)
|
||||
self.batch_topk_error = torch.zeros((batch_size, 1), dtype=torch.bool)
|
||||
|
||||
def eval_batch(self, baseline, treatment, past_seq_len, verbose=True):
|
||||
self._eval_topk(baseline.top_1_tokens, treatment.top_1_tokens, 1, verbose)
|
||||
self._eval_topk(baseline.top_k_tokens, treatment.top_k_tokens, self.top_k, verbose)
|
||||
|
||||
max_diff = self.diff_logits(baseline.logits, treatment.logits, past_seq_len == 0)
|
||||
if verbose:
|
||||
print(f"Max logits diffs of {self.name}: {max_diff}")
|
||||
|
||||
def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True):
|
||||
if not torch.all(torch.eq(baseline_topk, treatment_topk)):
|
||||
if top_k == 1:
|
||||
if verbose:
|
||||
print(f"Generated tokens not matched for {self.name}")
|
||||
self.batch_top1_error |= torch.eq(baseline_topk, treatment_topk).logical_not()
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results"
|
||||
)
|
||||
self.batch_topk_error |= (
|
||||
torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0
|
||||
)
|
||||
|
||||
def end_batch(self):
|
||||
self.top_1_error += self.batch_top1_error.sum()
|
||||
self.top_k_error += self.batch_topk_error.sum()
|
||||
|
||||
def add_latency(self, past_seq_len, latency):
|
||||
key = int(math.log2(past_seq_len)) + 1 if past_seq_len > 0 else 0
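# Bucket 0 holds the empty-past case; bucket k (k >= 1) holds past lengths in [2**(k-1), 2**k - 1].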
|
||||
if key not in self.seq_len_latency:
|
||||
self.seq_len_latency[key] = []
|
||||
self.seq_len_latency[key].append(latency)
|
||||
|
||||
|
||||
class Gpt2Tester:
|
||||
def __init__(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
num_attention_heads,
|
||||
hidden_size,
|
||||
num_layer,
|
||||
device,
|
||||
is_fp16=False,
|
||||
top_k=20,
|
||||
top_k_required_order=False,
|
||||
):
|
||||
self.batch_size = input_ids.shape[0]
|
||||
self.input_length = input_ids.shape[1]
|
||||
self.n_layer = num_layer
|
||||
|
||||
self.input_ids = input_ids
|
||||
self.position_ids = position_ids
|
||||
self.attention_mask = attention_mask
|
||||
|
||||
self.has_position_ids = position_ids is not None
|
||||
self.has_attention_mask = attention_mask is not None
|
||||
|
||||
# Empty past state for first inference
|
||||
self.past = []
|
||||
past_shape = [
|
||||
2,
|
||||
self.batch_size,
|
||||
num_attention_heads,
|
||||
0,
|
||||
hidden_size // num_attention_heads,
|
||||
]
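# Past/present KV cache shape: (2, batch_size, num_heads, past_seq_len=0, head_size); key and value are stacked along dim 0.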
|
||||
for _i in range(num_layer):
|
||||
empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32)
|
||||
self.past.append(empty_past.to(device))
|
||||
|
||||
self.logits = None
|
||||
self.top_1_tokens = None
|
||||
self.top_k_tokens = None
|
||||
self.top_k = top_k
|
||||
self.top_k_required_order = top_k_required_order
|
||||
|
||||
def get_inputs(self) -> Gpt2Inputs:
|
||||
return Gpt2Inputs(self.input_ids, self.position_ids, self.attention_mask, self.past)
|
||||
|
||||
def save_test_data(self, session, output, save_test_data_dir, test_case_id):
|
||||
from onnx import numpy_helper # noqa: PLC0415
|
||||
|
||||
path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id))
|
||||
if os.path.exists(path):
|
||||
print(f"Directory {path} existed. Skip saving test data")
|
||||
return
|
||||
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
def add_tensor(input_tensors, torch_tensor, name):
|
||||
input_tensors.append(numpy_helper.from_array(torch_tensor.clone().cpu().numpy(), name))
|
||||
|
||||
input_tensors = []
|
||||
add_tensor(input_tensors, self.input_ids, "input_ids")
|
||||
|
||||
if self.has_position_ids:
|
||||
add_tensor(input_tensors, self.position_ids, "position_ids")
|
||||
|
||||
if self.has_attention_mask:
|
||||
add_tensor(input_tensors, self.attention_mask, "attention_mask")
|
||||
|
||||
for i in range(self.n_layer):
|
||||
add_tensor(input_tensors, self.past[i], "past_" + str(i))
|
||||
|
||||
for i, tensor in enumerate(input_tensors):
|
||||
with open(os.path.join(path, f"input_{i}.pb"), "wb") as f:
|
||||
f.write(tensor.SerializeToString())
|
||||
|
||||
output_names = [output.name for output in session.get_outputs()]
|
||||
for i, _name in enumerate(output_names):
|
||||
tensor = numpy_helper.from_array(
|
||||
output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy()
|
||||
)
|
||||
with open(os.path.join(path, f"output_{i}.pb"), "wb") as f:
|
||||
f.write(tensor.SerializeToString())
|
||||
|
||||
print(f"Test data saved to directory {path}")
|
||||
|
||||
def update(self, output, step, device):
|
||||
"""
|
||||
Update the inputs for the next inference step.
|
||||
"""
|
||||
self.logits = (
|
||||
torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu()
|
||||
)
|
||||
|
||||
self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits)
|
||||
self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order)
|
||||
|
||||
self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device)
|
||||
|
||||
if self.has_position_ids:
|
||||
self.position_ids = (
|
||||
torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device)
|
||||
)
|
||||
|
||||
if self.has_attention_mask:
|
||||
self.attention_mask = torch.cat(
|
||||
[
|
||||
self.attention_mask,
|
||||
torch.ones([self.batch_size, 1]).type_as(self.attention_mask),
|
||||
],
|
||||
1,
|
||||
).to(device)
|
||||
|
||||
self.past = []
|
||||
|
||||
if isinstance(output[1], tuple): # past in torch output is tuple
|
||||
self.past = list(output[1])
|
||||
else:
|
||||
for i in range(self.n_layer):
|
||||
past_i = (
|
||||
torch.from_numpy(output[i + 1])
|
||||
if isinstance(output[i + 1], numpy.ndarray)
|
||||
else output[i + 1].clone().detach()
|
||||
)
|
||||
self.past.append(past_i.to(device))
|
||||
|
||||
def diff(self, baseline):
|
||||
"""
|
||||
Compare inputs and logits output.
|
||||
"""
|
||||
|
||||
print("start diff...")
|
||||
if self.logits is not None:
|
||||
max_io_diff = (self.logits - baseline.logits).abs().max()
|
||||
if max_io_diff > 1e-4:
|
||||
print(f"Max logits difference is too large: {max_io_diff}")
|
||||
|
||||
if not torch.all(self.input_ids == baseline.input_ids):
|
||||
print("Input_ids is different", self.input_ids, baseline.input_ids)
|
||||
|
||||
if self.has_position_ids:
|
||||
if not torch.all(self.position_ids == baseline.position_ids):
|
||||
print(
|
||||
"position_ids is different",
|
||||
self.position_ids,
|
||||
baseline.position_ids,
|
||||
)
|
||||
|
||||
if self.has_attention_mask:
|
||||
if not torch.all(self.attention_mask == baseline.attention_mask):
|
||||
print(
|
||||
"attention_mask is different",
|
||||
self.attention_mask,
|
||||
baseline.attention_mask,
|
||||
)
|
||||
|
||||
assert len(self.past) == len(baseline.past)
|
||||
|
||||
for i, past_i in enumerate(self.past):
|
||||
assert past_i.shape == baseline.past[i].shape
|
||||
if past_i.nelement() > 0:
|
||||
max_past_diff = (past_i - baseline.past[i]).abs().max()
|
||||
if max_past_diff > 1e-4:
|
||||
print(f"max_past_diff[{i}]={max_past_diff}")
|
||||
|
||||
@staticmethod
|
||||
def predict_next_token(logits, top_k=1, required_order=False):
|
||||
"""
|
||||
Get the top k tokens based on logits.
|
||||
"""
|
||||
|
||||
# logits has shape (batch_size, seq_len, vocab_size)
|
||||
# last token logits has shape (batch_size, vocab_size)
|
||||
lastTokenLogits = logits[:, -1] # noqa: N806
|
||||
if top_k == 1:
|
||||
generatedTokens = torch.argmax(lastTokenLogits, 1, True) # noqa: N806
|
||||
return generatedTokens
|
||||
else:
|
||||
topk = torch.argsort(lastTokenLogits, -1, descending=True)[:, :top_k]
|
||||
if not required_order:
|
||||
sorted_topk, _ = topk.sort()
|
||||
return sorted_topk
|
||||
return topk
|
||||
|
||||
@staticmethod
|
||||
def diff_present(onnx_output, onnx_io_output, n_layer):
|
||||
"""
|
||||
Compare the present state outputs of two ONNX Runtime runs (e.g. with and without IO binding).
|
||||
"""
|
||||
present_diff_max = []
|
||||
for i in range(n_layer):
|
||||
onnx_present_i = (
|
||||
torch.from_numpy(onnx_output[i + 1])
|
||||
if isinstance(onnx_output[i + 1], numpy.ndarray)
|
||||
else onnx_output[i + 1]
|
||||
)
|
||||
onnx_io_present_i = (
|
||||
torch.from_numpy(onnx_io_output[i + 1])
|
||||
if isinstance(onnx_io_output[i + 1], numpy.ndarray)
|
||||
else onnx_io_output[i + 1]
|
||||
)
|
||||
max_diff = (onnx_present_i - onnx_io_present_i).abs().max()
|
||||
present_diff_max.append(max_diff)
|
||||
print(f"present_diff_max={present_diff_max}")
|
||||
|
||||
@staticmethod
|
||||
def is_quantized_onnx_model(onnx_model_path):
|
||||
"""
|
||||
Returns True if the ONNX model is quantized.
|
||||
"""
|
||||
from onnx import load # noqa: PLC0415
|
||||
|
||||
model = load(onnx_model_path)
|
||||
from onnxruntime.quantization.quantize import __producer__ as quantize_producer # noqa: PLC0415
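# Models produced by the onnxruntime quantization tool carry this producer name in their metadata, which is what the check below relies on.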
|
||||
|
||||
return model.producer_name == quantize_producer
|
||||
|
||||
@staticmethod
|
||||
def test_generation(
|
||||
session,
|
||||
model,
|
||||
device,
|
||||
test_inputs,
|
||||
precision=Precision.FLOAT32,
|
||||
model_class="Gpt2LMHeadModel",
|
||||
top_k=20,
|
||||
top_k_no_order=True,
|
||||
max_steps=24,
|
||||
max_inputs=0,
|
||||
verbose=False,
|
||||
save_test_data=0,
|
||||
save_test_data_dir=".",
|
||||
):
|
||||
"""
|
||||
Test generation using greedy (top-1) search without sampling to compare the PyTorch and ONNX models.
|
||||
It will print top 1 and top k errors on the given test inputs.
|
||||
"""
|
||||
print(
|
||||
f"start test generation: (top_k={top_k} top_k_no_order={top_k_no_order} max_steps={max_steps} test_inputs={len(test_inputs)} max_inputs={max_inputs})"
|
||||
)
|
||||
n_layer = model.config.n_layer
|
||||
n_head = model.config.n_head
|
||||
n_embd = model.config.n_embd
|
||||
eos_token_id = model.config.eos_token_id
|
||||
test_data_saved = 0
|
||||
|
||||
is_float16 = precision == Precision.FLOAT16
|
||||
if is_float16:
|
||||
assert "float16" in session.get_outputs()[0].type
|
||||
|
||||
# We still use the fp32 torch model as the baseline even when the onnx model is fp16
|
||||
model.eval().to(device)
|
||||
|
||||
# Allocate initial buffers for IO Binding of ONNX Runtime. The buffer size will automatically increase later.
|
||||
init_output_shapes = Gpt2Helper.get_output_shapes(
|
||||
batch_size=4,
|
||||
past_sequence_length=128,
|
||||
sequence_length=32,
|
||||
config=model.config,
|
||||
model_class=model_class,
|
||||
)
|
||||
output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16)
|
||||
|
||||
baseline_name = "Torch"
|
||||
treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx"
|
||||
torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k)
|
||||
onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k)
|
||||
onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k)
|
||||
|
||||
for i, inputs in enumerate(test_inputs):
|
||||
if max_inputs > 0 and i == max_inputs:
|
||||
break
|
||||
if i % 10 == 0:
|
||||
print(f"{i}")
|
||||
input_ids = inputs["input_ids"]
|
||||
position_ids = inputs.get("position_ids", None)
|
||||
attention_mask = inputs.get("attention_mask", None)
|
||||
|
||||
onnx_runner = Gpt2Tester(
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
n_head,
|
||||
n_embd,
|
||||
n_layer,
|
||||
device,
|
||||
is_float16,
|
||||
top_k,
|
||||
not top_k_no_order,
|
||||
)
|
||||
onnx_io_runner = Gpt2Tester(
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
n_head,
|
||||
n_embd,
|
||||
n_layer,
|
||||
device,
|
||||
is_float16,
|
||||
top_k,
|
||||
not top_k_no_order,
|
||||
)
|
||||
torch_runner = Gpt2Tester(
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
n_head,
|
||||
n_embd,
|
||||
n_layer,
|
||||
device,
|
||||
False,
|
||||
top_k,
|
||||
not top_k_no_order,
|
||||
) # Torch model baseline is fp32
|
||||
|
||||
batch_size = torch_runner.batch_size
|
||||
onnx_metric.start_batch(batch_size)
|
||||
onnx_io_metric.start_batch(batch_size)
|
||||
|
||||
with torch.no_grad():
|
||||
done = torch.zeros(batch_size, dtype=torch.bool)
|
||||
for step in range(max_steps):
|
||||
seq_len = list(onnx_runner.input_ids.size())[1]
|
||||
past_seq_len = list(onnx_runner.past[0].size())[3]
|
||||
|
||||
start_time = timeit.default_timer()
|
||||
pytorch_output = Gpt2Helper.pytorch_inference(model, torch_runner.get_inputs())
|
||||
torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time)
|
||||
torch_runner.update(pytorch_output, step, device)
|
||||
|
||||
onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference(
|
||||
session, onnx_runner.get_inputs(), total_runs=1
|
||||
)
|
||||
onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
|
||||
onnx_runner.update(onnx_output, step, device)
|
||||
|
||||
output_shapes = Gpt2Helper.get_output_shapes(
|
||||
batch_size,
|
||||
past_seq_len,
|
||||
seq_len,
|
||||
model.config,
|
||||
model_class=model_class,
|
||||
)
|
||||
Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes)
|
||||
|
||||
(
|
||||
onnx_io_output,
|
||||
avg_latency_ms,
|
||||
) = Gpt2Helper.onnxruntime_inference_with_binded_io(
|
||||
session,
|
||||
onnx_io_runner.get_inputs(),
|
||||
output_buffers,
|
||||
output_shapes,
|
||||
total_runs=1,
|
||||
return_numpy=False,
|
||||
include_copy_output_latency=True,
|
||||
)
|
||||
onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
|
||||
|
||||
if test_data_saved < save_test_data:
|
||||
onnx_io_runner.save_test_data(session, onnx_io_output, save_test_data_dir, test_data_saved)
|
||||
test_data_saved += 1
|
||||
|
||||
onnx_io_runner.update(onnx_io_output, step, device)
|
||||
|
||||
if verbose:
|
||||
onnx_runner.diff(onnx_io_runner)
|
||||
Gpt2Tester.diff_present(onnx_output, onnx_io_output, n_layer)
|
||||
|
||||
print("Top 1 tokens:")
|
||||
print("\tTorch", torch_runner.top_1_tokens)
|
||||
print("\tONNX", onnx_runner.top_1_tokens)
|
||||
print("\tONNX with IO binding", onnx_io_runner.top_1_tokens)
|
||||
|
||||
onnx_metric.eval_batch(torch_runner, onnx_runner, past_seq_len, verbose=verbose)
|
||||
onnx_io_metric.eval_batch(torch_runner, onnx_io_runner, past_seq_len, verbose=verbose)
|
||||
|
||||
done = done | (torch_runner.top_1_tokens == eos_token_id).any()
|
||||
if torch.all(done):
|
||||
break
|
||||
|
||||
onnx_metric.end_batch()
|
||||
onnx_io_metric.end_batch()
|
||||
|
||||
torch_metric.print()
|
||||
onnx_metric.print()
|
||||
onnx_io_metric.print()
|
||||
+146
@@ -0,0 +1,146 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
# This script helps debug parity issues between two copies of the same onnx model in fp16 and fp32 formats
|
||||
# Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON
|
||||
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from benchmark_helper import create_onnxruntime_session
|
||||
from gpt2_helper import Gpt2Helper
|
||||
from onnx import TensorProto, numpy_helper
|
||||
|
||||
NON_ZERO_VALUE = str(1)
|
||||
ZERO_VALUE = str(0)
|
||||
|
||||
|
||||
def environ_setting_nodes(node_name_filter=None, node_type_filter=None):
|
||||
# Dump input and output data (not shapes only) by default
|
||||
os.environ["ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA"] = ZERO_VALUE
|
||||
os.environ["ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA"] = NON_ZERO_VALUE
|
||||
os.environ["ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA"] = NON_ZERO_VALUE
|
||||
if node_name_filter is not None:
|
||||
os.environ["ORT_DEBUG_NODE_IO_NAME_FILTER"] = node_name_filter
|
||||
elif node_type_filter is not None:
|
||||
os.environ["ORT_DEBUG_NODE_IO_OP_TYPE_FILTER"] = node_type_filter
|
||||
else:
|
||||
os.environ["ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK"] = NON_ZERO_VALUE
|
||||
|
||||
|
||||
def environ_setting_paths(output_path):
|
||||
# Dump values to files by default
|
||||
os.environ["ORT_DEBUG_NODE_IO_DUMP_DATA_DESTINATION"] = "files"
|
||||
os.environ["ORT_DEBUG_NODE_IO_OUTPUT_DIR"] = output_path
|
||||
|
||||
|
||||
def environ_reset():
|
||||
for flag in [
|
||||
"ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA",
|
||||
"ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA",
|
||||
"ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA",
|
||||
"ORT_DEBUG_NODE_IO_NAME_FILTER",
|
||||
"ORT_DEBUG_NODE_IO_OP_TYPE_FILTER",
|
||||
"ORT_DEBUG_NODE_IO_DUMP_DATA_TO_FILES",
|
||||
"ORT_DEBUG_NODE_IO_OUTPUT_DIR",
|
||||
"ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK",
|
||||
]:
|
||||
if flag in os.environ:
|
||||
del os.environ[flag]
|
||||
|
||||
|
||||
def inference(model_path, dummy_inputs, outputs_path, use_gpu):
|
||||
environ_reset()
|
||||
environ_setting_nodes()
|
||||
environ_setting_paths(outputs_path)
|
||||
session = create_onnxruntime_session(model_path, use_gpu, enable_all_optimization=False)
|
||||
Gpt2Helper.onnxruntime_inference(session, dummy_inputs)
|
||||
|
||||
|
||||
def generate_outputs_files(model_path, dummy_inputs, outputs_path, use_gpu):
|
||||
dir_path = Path(outputs_path)
|
||||
if dir_path.exists() and dir_path.is_dir():
|
||||
import shutil # noqa: PLC0415
|
||||
|
||||
shutil.rmtree(outputs_path)
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
process = multiprocessing.Process(target=inference, args=(model_path, dummy_inputs, outputs_path, use_gpu))
|
||||
process.start()
|
||||
process.join()
|
||||
|
||||
|
||||
def post_processing(outputs_path, outputs_path_other):
|
||||
# Compare node outputs dumped from two runs (e.g. fp16 vs fp32)
|
||||
record = {}
|
||||
if_close = {}
|
||||
|
||||
import glob # noqa: PLC0415
|
||||
|
||||
for filename in glob.glob(os.path.join(outputs_path, "*.tensorproto")):
|
||||
filename_other = os.path.join(outputs_path_other, Path(filename).name)
|
||||
if not os.path.exists(filename_other):
|
||||
continue
|
||||
with open(filename, "rb") as f:
|
||||
tensor = TensorProto()
|
||||
tensor.ParseFromString(f.read())
|
||||
array = numpy_helper.to_array(tensor)
|
||||
with open(filename_other, "rb") as f: # noqa: PLW2901
|
||||
tensor_other = TensorProto()
|
||||
tensor_other.ParseFromString(f.read())
|
||||
array_other = numpy_helper.to_array(tensor_other)
|
||||
if array_other.size == 0:
|
||||
continue
|
||||
diff = numpy.average(numpy.abs(array_other - array) / (numpy.abs(array_other) + 1e-6))
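# Average element-wise relative difference; the 1e-6 term avoids division by zero.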
|
||||
if math.isnan(diff):
|
||||
continue
|
||||
record[Path(filename).name.split(".")[0]] = diff
|
||||
if_close[Path(filename).name.split(".")[0]] = numpy.allclose(array, array_other, rtol=1e-04, atol=1e-04)
|
||||
|
||||
results = ["Node\tDiff\tClose"]
|
||||
for k, v in sorted(record.items(), key=lambda x: x[1], reverse=True):
|
||||
results.append(f"{k}\t{v}\t{if_close[k]}")
|
||||
for line in results:
|
||||
print(line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# The example below shows how to use this helper to investigate parity issues between GPT-2 fp32 and fp16 onnx models
|
||||
# Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON !!
|
||||
multiprocessing.set_start_method("spawn")
|
||||
|
||||
# Generate Inputs
|
||||
sequence_length = 8
|
||||
past_sequence_length = 8
|
||||
batch_size = 5
|
||||
dummy_inputs_fp16 = Gpt2Helper.get_dummy_inputs(
|
||||
batch_size,
|
||||
past_sequence_length,
|
||||
sequence_length,
|
||||
12,
|
||||
768,
|
||||
12,
|
||||
50257,
|
||||
device=torch.device("cpu"),
|
||||
float16=True,
|
||||
)
|
||||
dummy_inputs_fp32 = dummy_inputs_fp16.to_fp32()
|
||||
|
||||
# Get GPT-2 model from huggingface using convert_to_onnx.py
|
||||
os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp32.onnx -o -p fp32 --use_gpu")
|
||||
os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu")
|
||||
|
||||
# Specify the directories to dump each node's I/O
|
||||
outputs_path_fp32_gpu = "./fp32_gpu"
|
||||
outputs_path_fp16_gpu = "./fp16_gpu"
|
||||
generate_outputs_files("./gpt2_fp32.onnx", dummy_inputs_fp32, outputs_path_fp32_gpu, use_gpu=True)
|
||||
generate_outputs_files("./gpt2_fp16.onnx", dummy_inputs_fp16, outputs_path_fp16_gpu, use_gpu=True)
|
||||
|
||||
# Compare each node's I/O values and sort by average relative difference
|
||||
post_processing(outputs_path_fp16_gpu, outputs_path_fp32_gpu)
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.append(transformers_dir)
|
||||
+703
@@ -0,0 +1,703 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import argparse
|
||||
import datetime
|
||||
import gc
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import psutil
|
||||
import torch
|
||||
from benchmark_helper import measure_memory, setup_logger
|
||||
from dist_settings import get_rank, get_size
|
||||
from llama_inputs import (
|
||||
add_io_bindings_as_ortvalues,
|
||||
get_merged_sample_with_past_kv_inputs,
|
||||
get_msft_sample_inputs,
|
||||
get_sample_inputs,
|
||||
get_sample_with_past_kv_inputs,
|
||||
verify_ort_inputs,
|
||||
)
|
||||
from optimum.onnxruntime import ORTModelForCausalLM
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from tqdm import trange
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
|
||||
def get_ort_model_inputs_len(args, model):
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
|
||||
return 0
|
||||
if args.benchmark_type == "hf-ort":
|
||||
try:
|
||||
# New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
|
||||
return len(model.inputs_names)
|
||||
except Exception:
|
||||
# Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
|
||||
return len(model.decoder.input_names)
|
||||
return len(model.get_inputs())
|
||||
|
||||
|
||||
def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
|
||||
init_inputs, iter_inputs = None, None
|
||||
|
||||
# For past_present_share_buffer:
|
||||
# Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
|
||||
# Set max_seq_len to config value for other models
|
||||
max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
|
||||
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
|
||||
init_inputs = get_sample_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
args.sequence_length,
|
||||
return_dict=True,
|
||||
)
|
||||
iter_inputs = get_sample_with_past_kv_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
args.sequence_length,
|
||||
use_fp16=args.use_fp16,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
elif args.benchmark_type in {"hf-ort"}:
|
||||
if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids]
|
||||
# Using split models in Optimum (e.g. created by Optimum export)
|
||||
init_inputs = get_sample_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
args.sequence_length,
|
||||
return_dict=True,
|
||||
)
|
||||
iter_inputs = get_sample_with_past_kv_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
args.sequence_length,
|
||||
use_fp16=args.use_fp16,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
# Using merged model in Optimum (e.g. created by convert_to_onnx export)
|
||||
init_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
seq_len=args.sequence_length,
|
||||
past_seq_len=0,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
engine="pt",
|
||||
return_dict=True,
|
||||
)
|
||||
iter_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
seq_len=1,
|
||||
past_seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
engine="pt",
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
elif args.benchmark_type == "ort-convert-to-onnx":
|
||||
# Microsoft export from convert_to_onnx
|
||||
init_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
seq_len=args.sequence_length,
|
||||
past_seq_len=0,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
engine="ort",
|
||||
return_dict=True,
|
||||
world_size=args.world_size,
|
||||
)
|
||||
iter_inputs = get_merged_sample_with_past_kv_inputs(
|
||||
args.config,
|
||||
args.target_device,
|
||||
args.batch_size,
|
||||
seq_len=1,
|
||||
past_seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
engine="ort",
|
||||
return_dict=True,
|
||||
world_size=args.world_size,
|
||||
)
|
||||
|
||||
elif args.benchmark_type == "ort-msft":
|
||||
# Microsoft export from https://github.com/microsoft/Llama-2-Onnx
|
||||
split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos]
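# More than 5 inputs implies the KV cache is passed per layer (split) rather than as single stacked k/v tensors.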
|
||||
|
||||
init_inputs = get_msft_sample_inputs(
|
||||
args.config,
|
||||
args.batch_size,
|
||||
past_seq_len=0,
|
||||
seq_len=args.sequence_length,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
split_kv=split_kv,
|
||||
)
|
||||
iter_inputs = get_msft_sample_inputs(
|
||||
args.config,
|
||||
args.batch_size,
|
||||
past_seq_len=args.sequence_length,
|
||||
seq_len=1,
|
||||
max_seq_len=max_seq_len,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
split_kv=split_kv,
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Unable to auto-detect inputs for provided model")
|
||||
|
||||
return init_inputs, iter_inputs
|
||||
|
||||
|
||||
def get_model(args: argparse.Namespace):
|
||||
model, sess_options = None, None
|
||||
start_time, end_time = None, None
|
||||
|
||||
# There are multiple sources that the model could come from:
|
||||
# 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
|
||||
# 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
|
||||
# 3) Benchmark LLaMA-2 from local download of model
|
||||
# 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
|
||||
# 5) Benchmark LLaMA-2 from convert_to_onnx
|
||||
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
|
||||
source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
|
||||
start_time = time.time()
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
source,
|
||||
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
|
||||
use_auth_token=args.auth,
|
||||
trust_remote_code=args.auth,
|
||||
use_cache=True,
|
||||
cache_dir=args.cache_dir,
|
||||
).to(args.target_device)
|
||||
end_time = time.time()
|
||||
|
||||
if args.benchmark_type == "hf-pt-compile":
|
||||
model = torch.compile(model)
|
||||
|
||||
elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.enable_profiling = args.profile
|
||||
if args.verbose:
|
||||
sess_options.log_verbosity_level = 1
|
||||
sess_options.log_severity_level = 1
|
||||
|
||||
else:
|
||||
raise Exception(f"Cannot recognize {args.benchmark_type}")
|
||||
|
||||
if args.benchmark_type == "hf-ort":
|
||||
# Optimum export or convert_to_onnx.py export
|
||||
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
|
||||
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
|
||||
|
||||
decoder_file_name = None
|
||||
decoder_with_past_file_name = None
|
||||
for filename in os.listdir(args.hf_ort_dir_path):
|
||||
if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
|
||||
continue
|
||||
if "decoder_model" in filename or filename == "model.onnx":
|
||||
decoder_file_name = filename
|
||||
if "decoder_with_past_model" in filename:
|
||||
decoder_with_past_file_name = filename
|
||||
if "decoder_merged_model" in filename:
|
||||
decoder_file_name = filename
|
||||
decoder_with_past_file_name = filename
|
||||
|
||||
start_time = time.time()
|
||||
model = ORTModelForCausalLM.from_pretrained(
|
||||
args.hf_ort_dir_path,
|
||||
decoder_file_name=decoder_file_name,
|
||||
decoder_with_past_file_name=decoder_with_past_file_name,
|
||||
use_auth_token=args.auth,
|
||||
trust_remote_code=args.auth,
|
||||
use_io_binding=True, # Large perf gain even for cpu due to avoiding output copy.
|
||||
use_merged=(True if decoder_file_name == "model.onnx" else None),
|
||||
provider=provider,
|
||||
provider_options=provider_options,
|
||||
session_options=sess_options,
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
|
||||
# Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
|
||||
logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
|
||||
start_time = time.time()
|
||||
model = ort.InferenceSession(
|
||||
args.ort_model_path.format(args.rank),
|
||||
sess_options,
|
||||
providers=[args.execution_provider],
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
logger.info(f"Loaded model in {end_time - start_time} s")
|
||||
return model
|
||||
|
||||
|
||||
def time_fn(args, fn, inputs):
|
||||
# Warm up
|
||||
warmup_range = (
|
||||
range(args.warmup_runs)
|
||||
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
|
||||
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
|
||||
)
|
||||
|
||||
if args.verbose:
|
||||
outputs = fn(inputs)
|
||||
logger.info(outputs)
|
||||
|
||||
# Synchronize device work around each timed run; without this, GPU latencies would be under-reported.
input_sync = lambda *kwargs: (  # noqa: E731
args.io_binding.synchronize_inputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
else torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
else None  # no-op on CPU
)
|
||||
|
||||
output_sync = lambda *kwargs: (  # noqa: E731
args.io_binding.synchronize_outputs()
if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
else torch.cuda.synchronize()
if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
else None  # no-op on CPU
)
|
||||
|
||||
for _ in warmup_range:
|
||||
input_sync()
|
||||
fn(inputs)
|
||||
output_sync()
|
||||
|
||||
# Benchmark
|
||||
total_time = 0
|
||||
bench_range = (
|
||||
range(args.num_runs)
|
||||
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
|
||||
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
|
||||
)
|
||||
for _ in bench_range:
|
||||
input_sync()
|
||||
start_time = time.time()
|
||||
|
||||
fn(inputs)
|
||||
|
||||
output_sync()
|
||||
end_time = time.time()
|
||||
|
||||
total_time += end_time - start_time
|
||||
|
||||
# Newline print after trange in order to print metrics on new lines without progress bar on same line
|
||||
if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
|
||||
logger.info("")
|
||||
|
||||
latency = total_time / args.num_runs
|
||||
throughput = args.batch_size / latency
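# Each timed run processes a single batch, so throughput = batch_size / average latency, in examples per second.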
|
||||
|
||||
if args.rank == 0:
|
||||
logger.info(f"Batch Size: {args.batch_size}")
|
||||
logger.info(f"Sequence Length: {args.sequence_length}")
|
||||
logger.info(f"Latency: {latency} s")
|
||||
logger.info(f"Throughput: {throughput} tps")
|
||||
return
|
||||
|
||||
|
||||
def profile_fn(args, fn, inputs, inputs_type):
|
||||
# Filename prefix format:
|
||||
# "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
|
||||
prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
|
||||
filename = None
|
||||
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
|
||||
# Profile PyTorch kernels
|
||||
with profile( # noqa: SIM117
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("model_inference"):
|
||||
fn(inputs)
|
||||
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
|
||||
|
||||
filename = os.path.join(args.log_folder, f"{prefix}.log")
|
||||
with open(filename, "w") as f:
|
||||
f.write(prof_data)
|
||||
|
||||
else:
|
||||
# Profile ORT kernels
|
||||
fn(inputs)
|
||||
|
||||
# Set new log name for ORT profile log generated
|
||||
filename = f"{prefix}.json"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def measure_fn(args, fn, inputs):
|
||||
# Measure CPU usage
|
||||
pid = os.getpid()
|
||||
process = psutil.Process(pid)
|
||||
process.cpu_percent(interval=0.1)
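# First call primes psutil's counters so the later cpu_percent(interval=None) reports usage accumulated since this point.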
|
||||
|
||||
fn(inputs)
|
||||
if args.rank == 0:
|
||||
logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
|
||||
|
||||
# Measure memory usage
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
|
||||
|
||||
# Flush output so memory usage is printed
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def run_hf_inference(args, init_inputs, iter_inputs, model):
|
||||
# Inference steps to measure
|
||||
def get_logits(inputs):
|
||||
# Inference pass without decoding
|
||||
outputs = model(**inputs)
|
||||
return outputs
|
||||
|
||||
# Examples of other inference steps that can be measured:
|
||||
# To use, uncomment the function and assign it to `generate_fn`
|
||||
|
||||
# def get_pred_ids(inputs):
|
||||
# # Inference pass with predicted token ids generation
|
||||
# predicted_ids = model.generate(**inputs)
|
||||
# return predicted_ids
|
||||
|
||||
# def gen_and_dec(inputs):
|
||||
# # Inference pass with generation and decoding
|
||||
# predicted_ids = get_pred_ids(inputs)
|
||||
# transcription = []
|
||||
# for bs in range(args.batch_size):
|
||||
# for rs in range(args.num_return_sequences):
|
||||
# transcription.append(
|
||||
# args.tokenizer.batch_decode(
|
||||
# predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
|
||||
# )[0]
|
||||
# )
|
||||
# return transcription
|
||||
|
||||
generate_fn = get_logits
|
||||
|
||||
if args.benchmark_type == "hf-pt-compile":
|
||||
# Run forward pass once with each set of inputs to process through Dynamo
|
||||
generate_fn(init_inputs)
|
||||
generate_fn(iter_inputs)
|
||||
|
||||
if args.profile:
|
||||
new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
|
||||
if args.benchmark_type == "hf-ort":
|
||||
# Turn profiling off to stop appending to log
|
||||
old_logname = model.decoder.session.end_profiling()
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
|
||||
new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
|
||||
if args.benchmark_type == "hf-ort":
|
||||
# Turn profiling off to stop appending to log
|
||||
old_logname = model.decoder_with_past.session.end_profiling()
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
|
||||
return
|
||||
|
||||
# PyTorch evaluations
|
||||
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
|
||||
time_fn(args, generate_fn, init_inputs)
|
||||
measure_fn(args, generate_fn, init_inputs)
|
||||
|
||||
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
|
||||
time_fn(args, generate_fn, iter_inputs)
|
||||
measure_fn(args, generate_fn, iter_inputs)
|
||||
|
||||
|
||||
def run_ort_inference(args, init_inputs, iter_inputs, model):
|
||||
def prepare_ort_inputs(inputs, kv_cache_ortvalues):
|
||||
# Verify model inputs
|
||||
inputs = verify_ort_inputs(model, inputs)
|
||||
|
||||
# Add IO bindings for non-CPU execution providers
|
||||
if args.device != "cpu":
|
||||
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
|
||||
model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
|
||||
)
|
||||
setattr(args, "io_binding", io_binding) # noqa: B010
|
||||
return io_binding, kv_cache_ortvalues
|
||||
|
||||
return inputs, kv_cache_ortvalues
|
||||
|
||||
def with_io_binding(io_binding):
|
||||
# Inference pass with IO binding
|
||||
model.run_with_iobinding(io_binding)
|
||||
|
||||
def without_io_binding(inputs):
|
||||
# Inference pass without IO binding
|
||||
outputs = model.run(None, inputs)
|
||||
return outputs
|
||||
|
||||
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
|
||||
kv_cache_ortvalues = {}
|
||||
|
||||
if args.profile:
|
||||
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
|
||||
new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
|
||||
|
||||
# Turn profiling off to stop appending to log file
|
||||
old_logname = model.end_profiling()
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
|
||||
# Re-initialize model for new log file instead of appending to old log file
|
||||
model = get_model(args)
|
||||
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
|
||||
new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
|
||||
|
||||
# Turn profiling off to stop appending to log
|
||||
old_logname = model.end_profiling()
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
return
|
||||
|
||||
# ORT evaluations
|
||||
logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
|
||||
ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
|
||||
time_fn(args, generate_fn, ort_init_inputs)
|
||||
measure_fn(args, generate_fn, ort_init_inputs)
|
||||
|
||||
logger.info("\nEvaluating `model(inputs)` step with past_key_values")
|
||||
ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
|
||||
time_fn(args, generate_fn, ort_iter_inputs)
|
||||
measure_fn(args, generate_fn, ort_iter_inputs)
|
||||
|
||||
|
||||
def run_inference(args, init_inputs, iter_inputs, model):
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
|
||||
run_hf_inference(args, init_inputs, iter_inputs, model)
|
||||
elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
|
||||
run_ort_inference(args, init_inputs, iter_inputs, model)
|
||||
else:
|
||||
raise Exception(f"Cannot recognize {args.benchmark_type}")
|
||||
|
||||
|
||||
def get_args(rank=0):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-bt",
|
||||
"--benchmark-type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"hf-pt-eager",
|
||||
"hf-pt-compile",
|
||||
"hf-ort",
|
||||
"ort-msft",
|
||||
"ort-convert-to-onnx",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
|
||||
)
|
||||
|
||||
# Args for choosing the model
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
required=True,
|
||||
type=str,
|
||||
default="fp32",
|
||||
choices=["int4", "int8", "fp16", "fp32"],
|
||||
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-pt-dir-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-ort-dir-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ort-model-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to ONNX model",
|
||||
)
|
||||
|
||||
# Args for running and evaluating the model
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-sizes",
|
||||
default="1 2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sequence-lengths",
|
||||
default="32 64 128 256 512",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
choices=["cpu", "cuda", "rocm"],
|
||||
)
|
||||
parser.add_argument("-id", "--device-id", type=int, default=0)
|
||||
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
|
||||
parser.add_argument("-n", "--num-runs", type=int, default=10)
|
||||
parser.add_argument("--seed", type=int, default=2)
|
||||
|
||||
# Args for decoding logic
|
||||
parser.add_argument("--max-length", type=int, default=32)
|
||||
parser.add_argument("--num-return-sequences", type=int, default=1)
|
||||
|
||||
# Args for accessing detailed info
|
||||
parser.add_argument("--profile", default=False, action="store_true")
|
||||
parser.add_argument(
|
||||
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
|
||||
)
|
||||
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
|
||||
parser.add_argument("--verbose", default=False, action="store_true")
|
||||
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
|
||||
parser.add_argument(
|
||||
"--cache-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
default="./model_cache",
|
||||
help="Cache dir where Hugging Face files are stored",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set seed properties
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
# Set runtime properties
|
||||
if "ort" in args.benchmark_type:
|
||||
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
|
||||
if args.execution_provider == "CUDAExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": rank})
|
||||
elif args.execution_provider == "ROCMExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": rank})
|
||||
args.device = "cuda"
|
||||
|
||||
# Check that paths have been specified for any benchmarking with ORT
|
||||
if args.benchmark_type == "hf-ort":
|
||||
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
|
||||
if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
|
||||
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
|
||||
|
||||
args.batch_sizes = args.batch_sizes.split(" ")
|
||||
args.sequence_lengths = args.sequence_lengths.split(" ")
|
||||
|
||||
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
|
||||
args.precision = (
|
||||
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
|
||||
)
|
||||
|
||||
# Check that only one (batch_size, sequence_length) combination is set for profiling
|
||||
if args.profile:
|
||||
assert len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1, (
|
||||
"Please provide only one (batch_size, sequence_length) combination for profiling"
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
rank = get_rank()
|
||||
world_size = get_size()
|
||||
|
||||
args = get_args(rank)
|
||||
setup_logger(args.verbose)
|
||||
logger.info(args.__dict__)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
args.rank = rank
|
||||
args.world_size = world_size
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
|
||||
)
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
|
||||
)
|
||||
target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
|
||||
use_fp16 = args.precision == "fp16"
|
||||
|
||||
setattr(args, "tokenizer", tokenizer) # noqa: B010
|
||||
setattr(args, "config", config) # noqa: B010
|
||||
setattr(args, "target_device", target_device) # noqa: B010
|
||||
setattr(args, "use_fp16", use_fp16) # noqa: B010
|
||||
|
||||
# Get model and model info
|
||||
model = get_model(args)
|
||||
ort_model_inputs_len = get_ort_model_inputs_len(args, model)
|
||||
|
||||
# Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
|
||||
if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
|
||||
onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
|
||||
gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
|
||||
|
||||
use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
|
||||
setattr(args, "use_buffer_share", use_buffer_share) # noqa: B010
|
||||
else:
|
||||
setattr(args, "use_buffer_share", False) # noqa: B010
|
||||
|
||||
# Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
|
||||
for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
|
||||
if args.rank == 0:
|
||||
logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
|
||||
setattr(args, "batch_size", int(batch_size)) # noqa: B010
|
||||
setattr(args, "sequence_length", int(sequence_length)) # noqa: B010
|
||||
|
||||
init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
|
||||
run_inference(args, init_inputs, iter_inputs, model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+488
@@ -0,0 +1,488 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from metrics import BenchmarkRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-sizes",
|
||||
type=str,
|
||||
default="1 2",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sequence-lengths",
|
||||
type=str,
|
||||
default="8 16 32 64 128 256 512",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--warmup-runs",
|
||||
type=int,
|
||||
default=5,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--num-runs",
|
||||
type=int,
|
||||
default=1000,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-pt-eager",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Benchmark in PyTorch without `torch.compile`",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-pt-compile",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Benchmark in PyTorch with `torch.compile`",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-ort-dir-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ort-msft-model-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ort-convert-to-onnx-model-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to ONNX model from convert_to_onnx",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache-dir",
|
||||
type=str,
|
||||
default="./model_cache",
|
||||
help="Cache dir where Hugging Face files are stored",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model name in Hugging Face",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["int4", "int8", "fp16", "fp32"],
|
||||
help="Precision to run model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["cpu", "cuda", "rocm"],
|
||||
help="Device to benchmark models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device-id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="GPU device ID",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Print detailed logs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of mins to attempt the benchmark before moving on",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to folder to save logs and results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
|
||||
log_folder_name = f"./{args.model_size}_{args.precision}"
|
||||
if not args.log_folder:
|
||||
args.log_folder = log_folder_name
|
||||
os.makedirs(args.log_folder, exist_ok=True)
|
||||
|
||||
# Convert timeout value to secs
|
||||
args.timeout *= 60
|
||||
|
||||
return args
|
||||
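# A hypothetical invocation of this driver script (the script filename and file paths below
# are placeholders, not taken from the repository; only the flags are the ones defined above):
#
#   python benchmark_all.py \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --precision fp16 \
#       --device cuda \
#       --hf-pt-eager \
#       --ort-convert-to-onnx-model-path ./llama2-7b-fp16/model.onnx \
#       --batch-sizes "1 2" \
#       --sequence-lengths "32 128" \
#       --num-runs 100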
|
||||
|
||||
def process_log_file(device_id, log_file, base_results):
|
||||
entries = []
|
||||
batch_size, sequence_length, step = None, None, None
|
||||
latency_s, latency_ms, throughput, memory = None, None, None, None
|
||||
|
||||
batch_pattern = "Batch Size: "
|
||||
sequence_pattern = "Sequence Length: "
|
||||
prompt_step_pattern = "to get past_key_values"
|
||||
per_token_step_pattern = "with past_key_values"
|
||||
latency_pattern = "Latency: "
|
||||
throughput_pattern = "Throughput: "
|
||||
memory_pattern = "peak="
|
||||
|
||||
with open(log_file) as f:
|
||||
for input_line in f:
|
||||
line = input_line.replace("\n", "")
|
||||
|
||||
if batch_pattern in line:
|
||||
batch_size = int(line[len(batch_pattern) :])
|
||||
elif sequence_pattern in line:
|
||||
sequence_length = int(line[len(sequence_pattern) :])
|
||||
elif prompt_step_pattern in line:
|
||||
step = "prompt"
|
||||
elif per_token_step_pattern in line:
|
||||
step = "per-token"
|
||||
elif latency_pattern in line:
|
||||
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
|
||||
latency_ms = latency_s * 1000
|
||||
elif throughput_pattern in line:
|
||||
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
|
||||
elif memory_pattern in line:
|
||||
if "CPU" in line:
|
||||
# Example format for log entry:
|
||||
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
|
||||
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
|
||||
else:
|
||||
# Example format for log entry:
|
||||
# GPU memory usage: before=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 69637.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}] peak=[{'device_id': 0, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 73861.25}, {'device_id': 1, 'name': 'NVIDIA A100-SXM4-80GB', 'max_used_MB': 890.625}]
|
||||
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
|
||||
usage = json.loads(peak)[device_id]["max_used_MB"]
|
||||
memory = float(usage) / 1000
|
||||
|
||||
# Append log entry to list of entries
|
||||
entry = base_results + [ # noqa: RUF005
|
||||
batch_size,
|
||||
sequence_length,
|
||||
step,
|
||||
latency_s,
|
||||
latency_ms,
|
||||
throughput,
|
||||
memory,
|
||||
]
|
||||
entries.append(entry)
|
||||
|
||||
return entries
|
||||
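# A hand-worked sketch of the parsing above, assuming the relevant log lines start with the
# patterns being searched for (the numbers are made up for illustration):
#
#   Batch Size: 1
#   Sequence Length: 128
#   Evaluating `model(inputs)` step to get past_key_values
#   Latency: 0.1250 s
#   Throughput: 1024.00 tps
#   CPU memory usage: before=1000.0 MB, peak=2000.0 MB
#
# would produce one entry: base_results + [1, 128, "prompt", 0.125, 125.0, 1024.0, 2.0]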
|
||||
|
||||
def save_results(results, filename):
|
||||
import pandas as pd # noqa: PLC0415
|
||||
|
||||
df = pd.DataFrame(
|
||||
results,
|
||||
columns=[
|
||||
"Warmup Runs",
|
||||
"Measured Runs",
|
||||
"Model Name",
|
||||
"Engine",
|
||||
"Precision",
|
||||
"Device",
|
||||
"Batch Size",
|
||||
"Sequence Length",
|
||||
"Step",
|
||||
"Latency (s)",
|
||||
"Latency (ms)",
|
||||
"Throughput (tps)",
|
||||
"Memory (GB)",
|
||||
],
|
||||
)
|
||||
|
||||
# Set column types
|
||||
df["Warmup Runs"] = df["Warmup Runs"].astype("int")
|
||||
df["Measured Runs"] = df["Measured Runs"].astype("int")
|
||||
df["Batch Size"] = df["Batch Size"].astype("int")
|
||||
df["Sequence Length"] = df["Sequence Length"].astype("int")
|
||||
df["Latency (s)"] = df["Latency (s)"].astype("float")
|
||||
df["Latency (ms)"] = df["Latency (ms)"].astype("float")
|
||||
df["Throughput (tps)"] = df["Throughput (tps)"].astype("float")
|
||||
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
|
||||
|
||||
# Get the installed ONNX Runtime package name and version
|
||||
import pkg_resources # noqa: PLC0415
|
||||
|
||||
installed_packages = pkg_resources.working_set
|
||||
installed_packages_list = sorted(
|
||||
[f"{i.key}=={i.version}" for i in installed_packages if i.key in ["onnxruntime", "onnxruntime-gpu"]]
|
||||
)
|
||||
|
||||
ort_pkg_name = ""
|
||||
ort_pkg_version = ""
|
||||
if installed_packages_list:
|
||||
ort_pkg_name = installed_packages_list[0].split("==")[0]
|
||||
ort_pkg_version = installed_packages_list[0].split("==")[1]
|
||||
|
||||
# Save results to csv with standard format
|
||||
records = []
|
||||
for _, row in df.iterrows():
|
||||
if row["Engine"] in ["optimum-ort", "onnxruntime"]:
|
||||
record = BenchmarkRecord(
|
||||
row["Model Name"], row["Precision"], "onnxruntime", row["Device"], ort_pkg_name, ort_pkg_version
|
||||
)
|
||||
elif row["Engine"] in ["pytorch-eager", "pytorch-compile"]:
|
||||
record = BenchmarkRecord(
|
||||
row["Model Name"], row["Precision"], "pytorch", row["Device"], torch.__name__, torch.__version__
|
||||
)
|
||||
else:
|
||||
record = BenchmarkRecord(row["Model Name"], row["Precision"], row["Engine"], row["Device"], "", "")
|
||||
record.config.warmup_runs = row["Warmup Runs"]
|
||||
record.config.measured_runs = row["Measured Runs"]
|
||||
record.config.batch_size = row["Batch Size"]
|
||||
record.config.seq_length = row["Sequence Length"]
|
||||
record.config.customized["measure_step"] = row["Step"]
|
||||
record.config.customized["engine"] = row["Engine"]
|
||||
record.metrics.customized["latency_s_mean"] = row["Latency (s)"]
|
||||
record.metrics.latency_ms_mean = row["Latency (ms)"]
|
||||
record.metrics.customized["throughput_tps"] = row["Throughput (tps)"]
|
||||
record.metrics.max_memory_usage_GB = row["Memory (GB)"]
|
||||
|
||||
records.append(record)
|
||||
|
||||
BenchmarkRecord.save_as_csv(filename, records)
|
||||
BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
|
||||
logger.info(f"Results saved in {filename}!")
|
||||
|
||||
|
||||
def benchmark(args, benchmark_cmd, engine):
|
||||
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
|
||||
log_path = os.path.join(args.log_folder, log_filename)
|
||||
with open(log_path, "w") as log_file:
|
||||
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
|
||||
try:
|
||||
process.wait(args.timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
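# Whatever was written to the log before the timeout is still parsed below, so
# completed (batch_size, sequence_length) combinations are kept in the results.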
|
||||
# Create entries for csv
|
||||
logger.info("Gathering data from log files...")
|
||||
base_results = [args.warmup_runs, args.num_runs, args.model_name, engine, args.precision, args.device]
|
||||
results = process_log_file(args.device_id, log_path, base_results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
setup_logger(args.verbose)
|
||||
logger.info(args.__dict__)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
all_results = []
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
|
||||
|
||||
# Benchmark PyTorch without torch.compile
|
||||
if args.hf_pt_eager:
|
||||
benchmark_cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"models.llama.benchmark",
|
||||
"--benchmark-type",
|
||||
"hf-pt-eager",
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--batch-sizes",
|
||||
args.batch_sizes,
|
||||
"--sequence-lengths",
|
||||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
"--cache-dir",
|
||||
args.cache_dir,
|
||||
"--auth",
|
||||
]
|
||||
logger.info("Benchmark PyTorch without torch.compile")
|
||||
results = benchmark(args, benchmark_cmd, "pytorch-eager")
|
||||
all_results.extend(results)
|
||||
|
||||
# Benchmark PyTorch with torch.compile
|
||||
if args.hf_pt_compile:
|
||||
benchmark_cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"models.llama.benchmark",
|
||||
"--benchmark-type",
|
||||
"hf-pt-compile",
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--batch-sizes",
|
||||
args.batch_sizes,
|
||||
"--sequence-lengths",
|
||||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
"--cache-dir",
|
||||
args.cache_dir,
|
||||
"--auth",
|
||||
]
|
||||
logger.info("Benchmark PyTorch with torch.compile")
|
||||
results = benchmark(args, benchmark_cmd, "pytorch-compile")
|
||||
all_results.extend(results)
|
||||
|
||||
# Benchmark Optimum + ONNX Runtime
|
||||
if args.hf_ort_dir_path:
|
||||
benchmark_cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"models.llama.benchmark",
|
||||
"--benchmark-type",
|
||||
"hf-ort",
|
||||
"--hf-ort-dir-path",
|
||||
args.hf_ort_dir_path,
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--batch-sizes",
|
||||
args.batch_sizes,
|
||||
"--sequence-lengths",
|
||||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
"--cache-dir",
|
||||
args.cache_dir,
|
||||
"--auth",
|
||||
]
|
||||
logger.info("Benchmark Optimum + ONNX Runtime")
|
||||
results = benchmark(args, benchmark_cmd, "optimum-ort")
|
||||
all_results.extend(results)
|
||||
|
||||
# Benchmark Microsoft model in ONNX Runtime
|
||||
if args.ort_msft_model_path:
|
||||
benchmark_cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"models.llama.benchmark",
|
||||
"--benchmark-type",
|
||||
"ort-msft",
|
||||
"--ort-model-path",
|
||||
args.ort_msft_model_path,
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--batch-sizes",
|
||||
args.batch_sizes,
|
||||
"--sequence-lengths",
|
||||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
"--cache-dir",
|
||||
args.cache_dir,
|
||||
]
|
||||
logger.info("Benchmark Microsoft model in ONNX Runtime")
|
||||
results = benchmark(args, benchmark_cmd, "ort-msft")
|
||||
all_results.extend(results)
|
||||
|
||||
# Benchmark convert_to_onnx model in ONNX Runtime
|
||||
if args.ort_convert_to_onnx_model_path:
|
||||
benchmark_cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"models.llama.benchmark",
|
||||
"--benchmark-type",
|
||||
"ort-convert-to-onnx",
|
||||
"--ort-model-path",
|
||||
args.ort_convert_to_onnx_model_path,
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--batch-sizes",
|
||||
args.batch_sizes,
|
||||
"--sequence-lengths",
|
||||
args.sequence_lengths,
|
||||
"--device",
|
||||
args.device,
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
"--cache-dir",
|
||||
args.cache_dir,
|
||||
]
|
||||
logger.info("Benchmark convert_to_onnx model in ONNX Runtime")
|
||||
results = benchmark(args, benchmark_cmd, "onnxruntime")
|
||||
all_results.extend(results)
|
||||
|
||||
csv_file = f"{args.model_size}_{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
|
||||
save_results(all_results, os.path.join(args.log_folder, csv_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+608
@@ -0,0 +1,608 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# This is an end-to-end benchmarking script for the Hugging Face LLaMA-2 model.
|
||||
#
|
||||
# Prerequisites:
|
||||
# 1) Install `huggingface-cli`:
|
||||
#
|
||||
# $ pip install huggingface_hub
|
||||
#
|
||||
# 2) Authenticate with Hugging Face's CLI:
|
||||
#
|
||||
# $ huggingface-cli login
|
||||
#
|
||||
# 3) Accept Meta's license in Hugging Face to access the models at https://huggingface.co/meta-llama/
|
||||
#
|
||||
# 4) Install the latest ONNX Runtime version
|
||||
#
|
||||
# $ pip install onnxruntime-gpu
|
||||
#
|
||||
# 5) Install flash attention v2
|
||||
#
|
||||
# $ pip install flash-attn --no-build-isolation
|
||||
#
|
||||
# 6) Install bitsandbytes
|
||||
#
|
||||
# $ pip install bitsandbytes
|
||||
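# A hypothetical example invocation (script filename and file paths are placeholders; the
# flags are the ones defined in get_args below):
#
#   python benchmark_e2e.py \
#       --benchmark-type ort \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --onnx-model-path ./llama2-7b-fp16/model.onnx \
#       --prompts-file ./models/llama/prompts.json \
#       --precision fp16 \
#       --device cuda \
#       --batch-sizes "1" \
#       --prompt-lengths "16 64" \
#       --generation-length 256 \
#       --auth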
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import gc
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import textwrap
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_model(args: argparse.Namespace):
|
||||
if args.benchmark_type in {"pt-eager", "pt-compile"}:
|
||||
model = None
|
||||
if args.onnx_precision == "int4" and args.device == "cuda":
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
torch_dtype=args.torch_dtype,
|
||||
use_auth_token=args.auth,
|
||||
trust_remote_code=args.trust,
|
||||
use_cache=True,
|
||||
attn_implementation="flash_attention_2",
|
||||
quantization_config=bnb_config,
|
||||
max_memory={args.device_id: "80GB"},
|
||||
)
|
||||
else:
|
||||
try:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
torch_dtype=args.torch_dtype,
|
||||
use_auth_token=args.auth,
|
||||
trust_remote_code=args.trust,
|
||||
use_cache=True,
|
||||
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
|
||||
).to(args.target_device)
|
||||
except Exception as e:
|
||||
# When flash_attention or sdpa doesn't support a model, it throws an exception.
|
||||
# Rather than stopping the process, fall back to eager attention.
|
||||
print("Try to load a model using eager mode: ", e)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
torch_dtype=args.torch_dtype,
|
||||
use_auth_token=args.auth,
|
||||
trust_remote_code=args.trust,
|
||||
use_cache=True,
|
||||
attn_implementation="eager",
|
||||
).to(args.target_device)
|
||||
|
||||
model.eval()
|
||||
|
||||
if args.benchmark_type == "pt-compile":
|
||||
model = torch.compile(model)
|
||||
|
||||
else:
|
||||
sess_options = ort.SessionOptions()
|
||||
ep = (
|
||||
("CUDAExecutionProvider", {"device_id": args.device_id})
|
||||
if args.device == "cuda"
|
||||
else "CPUExecutionProvider"
|
||||
)
|
||||
model = ort.InferenceSession(args.onnx_model_path, sess_options=sess_options, providers=[ep])
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def run_inference(args, model, runs, inputs, outputs):
|
||||
if args.benchmark_type == "pt-compile":
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# Synchronize inputs
|
||||
io_binding = None
|
||||
if args.benchmark_type in {"pt-eager", "pt-compile"}:
|
||||
if args.device != "cpu":
|
||||
torch.cuda.synchronize(args.target_device)
|
||||
else:
|
||||
io_binding = add_io_bindings_as_tensors(model, inputs, outputs, args.use_fp16, args.use_buffer_share)
|
||||
io_binding.synchronize_inputs()
|
||||
|
||||
# Run inference
|
||||
start = time.perf_counter()
|
||||
for _ in range(runs):
|
||||
if args.benchmark_type in {"pt-eager", "pt-compile"}:
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
if args.device != "cpu":
|
||||
torch.cuda.synchronize(args.target_device)
|
||||
else:
|
||||
model.run_with_iobinding(io_binding)
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
end = time.perf_counter()
|
||||
avg = (end - start) / runs
|
||||
return avg, outputs
|
||||
|
||||
|
||||
def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
|
||||
clear_cache()
|
||||
inputs, outputs = get_initial_inputs_and_outputs(
|
||||
config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
|
||||
)
|
||||
_, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
|
||||
return inputs, outputs
|
||||
|
||||
|
||||
def clear_cache():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def save_results(results, filename, gen_length):
|
||||
df = pd.DataFrame(
|
||||
results,
|
||||
columns=[
|
||||
"Batch Size",
|
||||
"Prompt Length",
|
||||
"Prompt Processing Latency (ms)",
|
||||
"Prompt Processing Throughput (tps)",
|
||||
"Sampling Latency (ms)",
|
||||
"Sampling Throughput (tps)",
|
||||
"First Token Generated Latency (ms)",
|
||||
"First Token Generated Throughput (tps)",
|
||||
f"Average Latency of First {gen_length // 2} Tokens Generated (ms)",
|
||||
f"Average Throughput of First {gen_length // 2} Tokens Generated (tps)",
|
||||
f"Average Latency of First {gen_length} Tokens Generated (ms)",
|
||||
f"Average Throughput of First {gen_length} Tokens Generated (tps)",
|
||||
"Wall-Clock Latency (s)",
|
||||
"Wall-Clock Throughput (tps)",
|
||||
],
|
||||
)
|
||||
|
||||
df.to_csv(filename, index=False)
|
||||
logger.info(f"Results saved in {filename}!")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-bt",
|
||||
"--benchmark-type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["pt-eager", "pt-compile", "ort"],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--auth",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use Hugging Face authentication token to access model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--trust",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--cache-dir",
|
||||
type=str,
|
||||
default=os.path.join(".", "model_cache"),
|
||||
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(model_name, cache_dir=cache_dir)`.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-dir-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory containing all Hugging Face files (e.g. config, tokenizer, PyTorch model). Use when loading model as `AutoModel.from_pretrained(folder_path)`.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--onnx-model-path",
|
||||
required=False,
|
||||
help="Path to ONNX model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--prompts-file",
|
||||
required=True,
|
||||
default=os.path.join(".", "models", "llama", "prompts.json"),
|
||||
help="JSON file containing entries in the format 'prompt length: prompt' where prompt length = tokenized length of prompt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_buffer_share",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use when GroupQueryAttention (GQA) is in ONNX model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--anomaly-filtering",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use this flag to filter anomaly accelerator times for tokens generated. \
|
||||
This may give more accurate latency and throughput metrics for tokens generated. \
|
||||
Wall-clock metrics are still reported with anomaly times though.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-sizes",
|
||||
default="1 2",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--prompt-lengths",
|
||||
default="16 64 256 1024",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
required=True,
|
||||
type=str,
|
||||
default="fp32",
|
||||
choices=["int4", "int8", "fp16", "fp32"],
|
||||
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--generation-length",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of new tokens to generate",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
choices=["cpu", "cuda"],
|
||||
)
|
||||
|
||||
parser.add_argument("-id", "--device-id", type=int, default=0)
|
||||
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
|
||||
parser.add_argument("-n", "--num-runs", type=int, default=100)
|
||||
parser.add_argument("--seed", type=int, default=2)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set seed properties
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
# Set runtime properties
|
||||
if "ort" in args.benchmark_type:
|
||||
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
|
||||
if args.execution_provider == "CUDAExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
|
||||
|
||||
# Check that paths have been specified for any benchmarking with ORT
|
||||
if args.benchmark_type == "ort":
|
||||
assert args.onnx_model_path, "Please specify a path to `--onnx-model-path`"
|
||||
|
||||
args.batch_sizes = args.batch_sizes.split(" ")
|
||||
args.prompt_lengths = args.prompt_lengths.split(" ")
|
||||
|
||||
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
|
||||
setattr(args, "onnx_precision", args.precision) # noqa: B010
|
||||
args.precision = (
|
||||
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
|
||||
)
|
||||
|
||||
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
|
||||
torch_dtype = torch.float16 if args.precision == "fp16" else torch.float32
|
||||
engine = "ort" if args.benchmark_type == "ort" else "pt"
|
||||
setattr(args, "target_device", target_device) # noqa: B010
|
||||
setattr(args, "torch_dtype", torch_dtype) # noqa: B010
|
||||
setattr(args, "engine", engine) # noqa: B010
|
||||
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
|
||||
|
||||
args.use_buffer_share = args.use_buffer_share and engine == "ort"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
setup_logger(False)
|
||||
logger.info(args.__dict__)
|
||||
|
||||
# Get prompts and prompt sizes
|
||||
size_to_prompt = None
|
||||
with open(args.prompts_file) as f:
|
||||
size_to_prompt = json.load(f, object_hook=lambda d: {int(k): v for k, v in d.items()})
|
||||
|
||||
# Get config, tokenizer, and model
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
use_auth_token=args.auth,
|
||||
trust_remote_code=args.trust,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
|
||||
cache_dir=args.cache_dir,
|
||||
use_auth_token=args.auth,
|
||||
trust_remote_code=args.trust,
|
||||
)
|
||||
model = get_model(args)
|
||||
|
||||
all_csv_metrics = []
|
||||
for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
|
||||
batch_size, prompt_length = int(batch_size), int(prompt_length) # noqa: PLW2901
|
||||
logger.info(f"Running batch size = {batch_size}, prompt length = {prompt_length}")
|
||||
clear_cache()
|
||||
max_length = prompt_length + args.generation_length
|
||||
|
||||
if prompt_length not in size_to_prompt:
|
||||
raise NotImplementedError(
|
||||
textwrap.dedent(
|
||||
f"""
|
||||
A prompt of size {prompt_length} was not found in '{args.prompts_file}'. There are a couple of solutions to fix this.
|
||||
1) You can change one of the keys in '{args.prompts_file}' to be {prompt_length}.
|
||||
If {prompt_length} < actual prompt's length, the benchmark E2E tool will repeat the first word in the prompt until {prompt_length} = actual prompt's length.
|
||||
If {prompt_length} > actual prompt's length, the benchmark E2E tool will automatically trim the actual prompt's length so that {prompt_length} = actual prompt's length.
|
||||
2) You can add a new key-value entry in '{args.prompts_file}' of the form '{prompt_length}': 'your prompt goes here'.
|
||||
"""
|
||||
)
|
||||
)
|
||||
prompt = [size_to_prompt[prompt_length]] * batch_size
|
||||
csv_metrics = [batch_size, prompt_length]
|
||||
|
||||
try:
|
||||
# Measure prompt processing
|
||||
logger.info("Measuring prompt processing...")
|
||||
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
|
||||
accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)
|
||||
|
||||
# Calculate prompt metrics
|
||||
accelerator_prompt_latency_ms = accelerator_prompt_latency_s * 1000
|
||||
accelerator_prompt_thrpt = batch_size * (prompt_length / accelerator_prompt_latency_s)
|
||||
logger.info(f"Average Latency of Prompt Processing: {accelerator_prompt_latency_ms} ms")
|
||||
logger.info(
|
||||
f"Average Throughput of Prompt Processing: {batch_size * (prompt_length / accelerator_prompt_latency_s)} tps"
|
||||
)
|
||||
csv_metrics.extend([accelerator_prompt_latency_ms, accelerator_prompt_thrpt])
|
||||
|
||||
# Measure token generation
|
||||
logger.info("Measuring token generation...")
|
||||
clear_cache()
|
||||
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
|
||||
|
||||
all_token_ids = inputs["input_ids"].clone()
|
||||
current_length = all_token_ids.shape[-1]
|
||||
num_heads = config.num_key_value_heads
|
||||
head_size = (
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
|
||||
has_eos = torch.zeros(batch_size, device=args.target_device, dtype=torch.bool)
|
||||
|
||||
# 0th entry will have prompt accelerator time, 1st entry onwards will have token generation accelerator time
|
||||
accelerator_times = []
|
||||
sampling_times = [] # cost to sample after each model run
|
||||
|
||||
wall_clock_start_time = time.perf_counter()
|
||||
while current_length <= max_length:
|
||||
# Run inference
|
||||
accelerator_time_latency_s, outputs = run_inference(args, model, 1, inputs, outputs)
|
||||
accelerator_times.append(accelerator_time_latency_s)
|
||||
|
||||
# Sample with argmax (greedy search)
|
||||
sampling_start_time = time.perf_counter()
|
||||
if outputs["logits"].shape[1] > 1:
|
||||
prompt_end_indices = inputs["attention_mask"].sum(1) - 1
|
||||
idxs = (
|
||||
prompt_end_indices.unsqueeze(dim=1)
|
||||
.repeat(1, config.vocab_size)
|
||||
.view(batch_size, 1, config.vocab_size)
|
||||
)
|
||||
next_token_logits = torch.gather(outputs["logits"], 1, idxs).squeeze()
|
||||
else:
|
||||
next_token_logits = outputs["logits"][:, -1, :]
|
||||
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
||||
|
||||
# Check if we previously reached EOS token id or if generated token id is EOS token id
|
||||
has_eos = has_eos | (next_tokens == tokenizer.eos_token_id)
|
||||
|
||||
# Determine which new tokens to add to list of all token ids
|
||||
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
|
||||
tokens_to_add = next_tokens.masked_fill(has_eos, tokenizer.eos_token_id).reshape([batch_size, 1])
|
||||
sampling_end_time = time.perf_counter()
|
||||
sampling_times.append(sampling_end_time - sampling_start_time)
|
||||
|
||||
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
|
||||
current_length += 1
|
||||
|
||||
# Update inputs for next inference run
|
||||
inputs["input_ids"] = tokens_to_add
|
||||
inputs["attention_mask"] = torch.cat(
|
||||
[inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
|
||||
)
|
||||
if "position_ids" in inputs:
|
||||
inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
|
||||
|
||||
# Set logits to zeros for next inference run and re-use memory buffer
|
||||
if outputs["logits"].shape[1] != 1:
|
||||
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
|
||||
outputs["logits"].zero_()
|
||||
|
||||
# Update KV caches for next inference run
|
||||
if args.engine == "pt":
|
||||
# Update KV caches for PyTorch
|
||||
inputs["past_key_values"] = outputs["past_key_values"]
|
||||
elif not args.use_buffer_share:
|
||||
# Update KV caches for ONNX Runtime if buffer sharing is not used
|
||||
for i in range(config.num_hidden_layers):
|
||||
inputs[f"past_key_values.{i}.key"] = outputs[f"present.{i}.key"]
|
||||
inputs[f"past_key_values.{i}.value"] = outputs[f"present.{i}.value"]
|
||||
|
||||
new_sequence_length = inputs["attention_mask"].shape[1]
|
||||
for i in range(config.num_hidden_layers):
|
||||
present_key = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
new_sequence_length,
|
||||
head_size,
|
||||
device=args.target_device,
|
||||
dtype=args.torch_dtype,
|
||||
)
|
||||
present_value = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
new_sequence_length,
|
||||
head_size,
|
||||
device=args.target_device,
|
||||
dtype=args.torch_dtype,
|
||||
)
|
||||
outputs.update(
|
||||
{
|
||||
f"present.{i}.key": present_key.contiguous(),
|
||||
f"present.{i}.value": present_value.contiguous(),
|
||||
}
|
||||
)
|
||||
|
||||
wall_clock_end_time = time.perf_counter()
|
||||
|
||||
# Filter out any anomaly accelerator times (e.g. for `torch.compile`)
|
||||
accelerator_times.pop(0) # Remove prompt processing time
|
||||
if args.anomaly_filtering:
|
||||
anomaly_threshold_factor = 10
|
||||
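# e.g. with a 10 ms minimum token time, any accelerator time of 100 ms or more is dropped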
min_time_s = min(accelerator_times)
|
||||
orig_size = len(accelerator_times)
|
||||
accelerator_times = list(
|
||||
filter(lambda acc_time: acc_time < anomaly_threshold_factor * min_time_s, accelerator_times)
|
||||
)
|
||||
new_size = len(accelerator_times)
|
||||
logger.info(
|
||||
f"Filtered out {orig_size - new_size} anomaly accelerator times that are {anomaly_threshold_factor}x greater than {min_time_s * 1000} ms..."
|
||||
)
|
||||
|
||||
#######################################################
|
||||
# Calculate sampling and first token generated metrics
|
||||
#######################################################
|
||||
|
||||
# Calculate sampling metrics
|
||||
avg_sampling_latency_s = sum(sampling_times) / len(sampling_times)
|
||||
avg_sampling_latency_ms = avg_sampling_latency_s * 1000
|
||||
avg_sampling_thrpt = batch_size * (1 / avg_sampling_latency_s)
|
||||
logger.info(f"Average Latency of Sampling: {avg_sampling_latency_ms} ms")
|
||||
logger.info(f"Average Throughput of Sampling: {avg_sampling_thrpt} tps")
|
||||
|
||||
# Calculate first token generated metrics
|
||||
first_token_latency_s = accelerator_times[0]
|
||||
first_token_latency_ms = first_token_latency_s * 1000
|
||||
first_token_thrpt = batch_size * (1 / first_token_latency_s)
|
||||
logger.info(f"Latency of First Token Generated: {first_token_latency_ms} ms")
|
||||
logger.info(f"Throughput of First Token Generated: {first_token_thrpt} tps")
|
||||
|
||||
####################################################
|
||||
# Calculate first `halfway` token generated metrics
|
||||
####################################################
|
||||
|
||||
halfway = args.generation_length // 2
|
||||
halfway_token_latency_s = sum(accelerator_times[:halfway]) / len(accelerator_times[:halfway])
|
||||
halfway_token_latency_ms = halfway_token_latency_s * 1000
|
||||
halfway_token_thrpt = batch_size * (1 / halfway_token_latency_s)
|
||||
logger.info(f"Average Latency of First {halfway} Tokens Generated: {halfway_token_latency_ms} ms")
|
||||
logger.info(f"Average Throughput of First {halfway} Tokens Generated: {halfway_token_thrpt} tps")
|
||||
|
||||
#########################################
|
||||
# Calculate all tokens generated metrics
|
||||
#########################################
|
||||
|
||||
all_token_latency_s = sum(accelerator_times) / len(accelerator_times)
|
||||
all_token_latency_ms = all_token_latency_s * 1000
|
||||
all_token_thrpt = batch_size * (1 / all_token_latency_s)
|
||||
logger.info(
|
||||
f"Average Latency of First {args.generation_length} Tokens Generated: {all_token_latency_ms} ms"
|
||||
)
|
||||
logger.info(f"Average Throughput of First {args.generation_length} Tokens Generated: {all_token_thrpt} tps")
|
||||
|
||||
###############################
|
||||
# Calculate wall clock metrics
|
||||
###############################
|
||||
|
||||
wall_clock_latency_s = wall_clock_end_time - wall_clock_start_time
|
||||
wall_clock_thrpt = batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)
|
||||
logger.info(f"Wall-Clock Latency: {wall_clock_latency_s} s")
|
||||
logger.info(
|
||||
f"Wall-Clock Throughput: {batch_size * ((prompt_length + args.generation_length) / wall_clock_latency_s)} tps"
|
||||
)
|
||||
|
||||
# Add metrics to CSV
|
||||
logger.info("Adding results to CSV")
|
||||
csv_metrics.extend(
|
||||
[
|
||||
avg_sampling_latency_ms,
|
||||
avg_sampling_thrpt,
|
||||
first_token_latency_ms,
|
||||
first_token_thrpt,
|
||||
halfway_token_latency_ms,
|
||||
halfway_token_thrpt,
|
||||
all_token_latency_ms,
|
||||
all_token_thrpt,
|
||||
wall_clock_latency_s,
|
||||
wall_clock_thrpt,
|
||||
]
|
||||
)
|
||||
all_csv_metrics.append(csv_metrics)
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")
|
||||
|
||||
filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
|
||||
save_results(all_csv_metrics, filename, args.generation_length)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+1054
File diff suppressed because it is too large
+57
@@ -0,0 +1,57 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def init_dist():
|
||||
if "LOCAL_RANK" in os.environ:
|
||||
int(os.environ["LOCAL_RANK"])
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7645", world_size=world_size, rank=rank)
|
||||
elif "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
|
||||
int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
|
||||
rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
|
||||
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
|
||||
|
||||
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:7647", world_size=world_size, rank=rank)
|
||||
else:
|
||||
# don't need to do init for single process
|
||||
pass
|
||||
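# For reference: torchrun (e.g. `torchrun --nproc_per_node=N script.py`) sets LOCAL_RANK,
# RANK, and WORLD_SIZE, while Open MPI's mpirun (e.g. `mpirun -n N python script.py`) sets
# the OMPI_COMM_WORLD_* variables checked above. A plain `python script.py` sets neither,
# so no process group is initialized.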
|
||||
|
||||
def _get_comm():
|
||||
try:
|
||||
from mpi4py import MPI # noqa: PLC0415
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
return comm
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def get_rank():
|
||||
comm = _get_comm()
|
||||
return comm.Get_rank() if comm is not None else 0
|
||||
|
||||
|
||||
def get_size():
|
||||
comm = _get_comm()
|
||||
return comm.Get_size() if comm is not None else 1
|
||||
|
||||
|
||||
def barrier():
|
||||
comm = _get_comm()
|
||||
if comm is not None:
|
||||
comm.Barrier()
|
||||
|
||||
|
||||
def print_out(*args):
|
||||
if get_rank() == 0:
|
||||
print(*args)
|
||||
+504
@@ -0,0 +1,504 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
from onnxruntime import InferenceSession, OrtValue
|
||||
|
||||
|
||||
# Get position_ids from attention_mask
|
||||
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if use_past_kv:
|
||||
# Shape: (batch_size, 1)
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
# Shape: (batch_size, sequence_length)
|
||||
return position_ids
|
||||
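# A hand-worked sketch of what this produces (values computed by hand for illustration):
#
#   attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
#   get_position_ids(attention_mask, use_past_kv=False)  # -> tensor([[1, 1, 0, 1, 2]])
#   get_position_ids(attention_mask, use_past_kv=True)   # -> tensor([[2]])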
|
||||
|
||||
# Inputs for first pass to get initial past_key_values
|
||||
# input_ids: (batch_size, sequence_length)
|
||||
# attention_mask: (batch_size, sequence_length)
|
||||
# position_ids: (batch_size, sequence_length)
|
||||
def get_sample_inputs(
|
||||
config: AutoConfig,
|
||||
device: torch.device,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
engine: str = "pt",
|
||||
return_dict: bool = False,
|
||||
):
|
||||
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=False)
|
||||
|
||||
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
|
||||
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
|
||||
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
|
||||
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
|
||||
|
||||
if not return_dict:
|
||||
# For export
|
||||
return (input_ids, attention_mask, position_ids)
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
return inputs
|
||||
|
||||
|
||||
# Inputs for subsequent passes with past_key_values
|
||||
# input_ids: (batch_size, 1)
|
||||
# attention_mask: (batch_size, past_sequence_length + 1)
|
||||
# position_ids: (batch_size, 1)
|
||||
# past_key: (batch_size, num_heads, past_sequence_length, head_size)
|
||||
# past_value: (batch_size, num_heads, past_sequence_length, head_size)
|
||||
def get_sample_with_past_kv_inputs(
|
||||
config: AutoConfig,
|
||||
device: torch.device,
|
||||
batch_size: int,
|
||||
past_seq_len: int,
|
||||
use_fp16: bool = False,
|
||||
engine: str = "pt",
|
||||
return_dict: bool = False,
|
||||
world_size: int = 1,
|
||||
):
|
||||
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
|
||||
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
|
||||
# position_ids is of shape (batch_size, 1)
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=True)
|
||||
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
|
||||
|
||||
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
|
||||
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
|
||||
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
|
||||
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
|
||||
past_kv = (
|
||||
flatten_past_kv_inputs(past_kv) if engine == "ort" else [(kv[0].to(device), kv[1].to(device)) for kv in past_kv]
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
# For export
|
||||
assert isinstance(past_kv, list)
|
||||
return (input_ids, attention_mask, position_ids, past_kv)
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
if engine == "ort":
|
||||
assert isinstance(past_kv, dict)
|
||||
inputs.update(past_kv)
|
||||
else:
|
||||
assert isinstance(past_kv, list)
|
||||
inputs["past_key_values"] = past_kv
|
||||
|
||||
return inputs
|
||||
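# A minimal usage sketch (argument values are hypothetical; shapes follow the comments above):
#
#   inputs = get_sample_with_past_kv_inputs(
#       config, torch.device("cpu"), batch_size=2, past_seq_len=32, engine="pt", return_dict=True
#   )
#   # inputs["input_ids"].shape              -> (2, 1)
#   # inputs["attention_mask"].shape         -> (2, 33)
#   # inputs["past_key_values"][0][0].shape  -> (2, num_kv_heads, 32, head_size)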
|
||||
|
||||
# Inputs for all passes with past_key_values
|
||||
# input_ids: (batch_size, sequence_length)
|
||||
# attention_mask: (batch_size, past_sequence_length + sequence_length)
|
||||
# position_ids: (batch_size, sequence_length)
|
||||
# past_key: (batch_size, num_heads, kv_sequence_length, head_size)
|
||||
# For models with GQA, kv_sequence_length = max_sequence_length
|
||||
# For models without GQA, kv_sequence_length = past_sequence_length
|
||||
# past_value: (batch_size, num_heads, kv_sequence_length, head_size)
|
||||
# For models with GQA, kv_sequence_length = max_sequence_length
|
||||
# For models without GQA, kv_sequence_length = past_sequence_length
|
||||
def get_merged_sample_with_past_kv_inputs(
|
||||
config: AutoConfig,
|
||||
device: torch.device,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
past_seq_len: int,
|
||||
max_seq_len: int,
|
||||
use_fp16: bool = False,
|
||||
use_buffer_share: bool = False,
|
||||
engine: str = "pt",
|
||||
return_dict: bool = False,
|
||||
world_size: int = 1,
|
||||
):
|
||||
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
|
||||
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
|
||||
# position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0))
|
||||
past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)
|
||||
|
||||
# Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
|
||||
input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device)
|
||||
attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device)
|
||||
position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device)
|
||||
past_kv = (
|
||||
flatten_past_kv_inputs(past_kv) if engine == "ort" else [(kv[0].to(device), kv[1].to(device)) for kv in past_kv]
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
# For export
|
||||
assert isinstance(past_kv, list)
|
||||
return (input_ids, attention_mask, position_ids, past_kv)
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
if engine == "ort":
|
||||
assert isinstance(past_kv, dict)
|
||||
inputs.update(past_kv)
|
||||
|
||||
if use_buffer_share:
|
||||
inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)
|
||||
|
||||
else:
|
||||
assert isinstance(past_kv, list)
|
||||
inputs["past_key_values"] = past_kv
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
|
||||
def get_msft_sample_inputs(
|
||||
config: AutoConfig,
|
||||
batch_size: int,
|
||||
past_seq_len: int,
|
||||
seq_len: int,
|
||||
max_seq_len: int,
|
||||
use_fp16: bool,
|
||||
use_buffer_share: bool,
|
||||
split_kv: bool,
|
||||
):
|
||||
np_dtype = np.float16 if use_fp16 else np.float32
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
|
||||
if not split_kv:
|
||||
ort_inputs = {
|
||||
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
|
||||
"attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype),
|
||||
"k_cache": np.random.rand(
|
||||
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
|
||||
).astype(np_dtype),
|
||||
"v_cache": np.random.rand(
|
||||
batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size
|
||||
).astype(np_dtype),
|
||||
"pos": np.array(past_seq_len, dtype=np.int64),
|
||||
}
|
||||
else:
|
||||
ort_inputs = {
|
||||
"x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype),
|
||||
"attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype(
|
||||
np.int32
|
||||
),
|
||||
"pos": np.array(past_seq_len, dtype=np.int64),
|
||||
}
|
||||
for i in range(config.num_hidden_layers):
|
||||
ort_inputs.update(
|
||||
{
|
||||
f"k_{i}_cache": np.random.rand(
|
||||
batch_size, config.num_attention_heads, past_seq_len, head_size
|
||||
).astype(np_dtype),
|
||||
f"v_{i}_cache": np.random.rand(
|
||||
batch_size, config.num_attention_heads, past_seq_len, head_size
|
||||
).astype(np_dtype),
|
||||
}
|
||||
)
|
||||
|
||||
if use_buffer_share:
|
||||
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
|
||||
|
||||
return ort_inputs
|
||||
|
||||
|
||||
# Create past_key_values
|
||||
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
|
||||
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
|
||||
num_heads = config.num_key_value_heads // world_size
|
||||
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
torch_dtype = torch.float16 if use_fp16 else torch.float32
|
||||
past_kv = [
|
||||
(
|
||||
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
|
||||
torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype),
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
return past_kv
|
||||
|
||||
|
||||
# Convert list of past_key_values to dict of past_key and past_value
|
||||
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
|
||||
past_kv = {}
|
||||
for i, (past_k, past_v) in enumerate(past_key_values):
|
||||
if isinstance(past_key_values, DynamicCache):
|
||||
past_kv[f"past_key_values_key_cache_{i}"] = past_k.detach().cpu().numpy()
|
||||
past_kv[f"past_key_values_value_cache_{i}"] = past_v.detach().cpu().numpy()
|
||||
else:
|
||||
past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy()
|
||||
past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy()
|
||||
return past_kv
|
||||
|
||||
|
||||
# Format PyTorch inputs to ONNX Runtime inputs
|
||||
def convert_inputs_for_ort(
|
||||
pt_inputs: dict,
|
||||
use_buffer_share: bool = False,
|
||||
past_seq_len: int = 0,
|
||||
max_seq_len: int = 2048,
|
||||
):
|
||||
ort_inputs = {}
|
||||
for k, v in pt_inputs.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
ort_inputs[k] = v
|
||||
elif k == "past_key_values":
|
||||
ort_inputs.update(flatten_past_kv_inputs(v))
|
||||
else:
|
||||
ort_inputs[k] = v.detach().cpu().numpy()
|
||||
|
||||
# Reshape KV caches if using past-present-share-buffer
|
||||
if use_buffer_share:
|
||||
ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
|
||||
|
||||
return ort_inputs
|
||||
|
||||
|
||||
# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
|
||||
# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
|
||||
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
|
||||
for k, v in ort_inputs.items():
|
||||
# Allocate new buffers with max_sequence_length for GQA
|
||||
if "cache" in k or "past_key_values" in k:
|
||||
# Copy v (B x N x P x H) into new_v (B x N x M x H)
|
||||
batch_size, num_heads, _, head_size = v.shape
|
||||
new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype)
|
||||
new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v
|
||||
ort_inputs[k] = new_v
|
||||
return ort_inputs
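# Shape sketch (illustrative only): with buffer sharing, a past key of shape
# (1, 32, 8, 128) is copied into a zero-initialized (1, 32, max_seq_len, 128) buffer so
# that GroupQueryAttention can write the present KV cache into the same allocation.
def _example_buffer_share_sketch(max_seq_len: int = 2048):
    past = {"past_key_values.0.key": np.random.rand(1, 32, 8, 128).astype(np.float32)}
    shared = enable_past_present_share_buffer(dict(past), past_seq_len=8, max_seq_len=max_seq_len)
    assert shared["past_key_values.0.key"].shape == (1, 32, max_seq_len, 128)
    return shared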
|
||||
|
||||
|
||||
# Verify ONNX Runtime inputs with model
|
||||
def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
|
||||
# Check that all model inputs will be provided
|
||||
model_inputs = {model_input.name for model_input in model.get_inputs()}
|
||||
user_inputs = set(ort_inputs.keys())
|
||||
missing_inputs = model_inputs - user_inputs
|
||||
if len(missing_inputs):
|
||||
print(f"The following model inputs are missing: {missing_inputs}")
|
||||
raise Exception("There are missing inputs to the model. Please add them and try again.")
|
||||
|
||||
# Remove unnecessary inputs from model inputs
|
||||
unnecessary_inputs = user_inputs - model_inputs
|
||||
if len(unnecessary_inputs):
|
||||
for unnecessary_input in unnecessary_inputs:
|
||||
del ort_inputs[unnecessary_input]
|
||||
|
||||
return ort_inputs
|
||||
|
||||
|
||||
# Add IO bindings for execution providers using OrtValue
|
||||
# Use when you need to run inference once or twice to save memory
|
||||
def add_io_bindings_as_ortvalues(
|
||||
model: InferenceSession,
|
||||
ort_inputs: dict,
|
||||
device: str,
|
||||
device_id: int,
|
||||
use_buffer_share: bool,
|
||||
kv_cache_ortvalues: dict,
|
||||
):
|
||||
io_binding = model.io_binding()
|
||||
|
||||
model_inputs = {i.name for i in model.get_inputs()}
|
||||
for k, v in ort_inputs.items():
|
||||
# Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
|
||||
# GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
|
||||
# but `position_ids` is used as a PyTorch model input
|
||||
if k not in model_inputs:
|
||||
continue
|
||||
|
||||
# Bind OrtValue inputs to device
|
||||
if use_buffer_share and ("cache" in k or "past_key_values" in k):
|
||||
if k not in kv_cache_ortvalues:
|
||||
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
|
||||
io_binding.bind_ortvalue_input(k, v_device)
|
||||
kv_cache_ortvalues[k] = v_device
|
||||
else:
|
||||
kv_cache_ortvalues[k].update_inplace(v)
|
||||
io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k])
|
||||
else:
|
||||
v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id)
|
||||
io_binding.bind_ortvalue_input(k, v_device)
|
||||
|
||||
for output in model.get_outputs():
|
||||
name = output.name
|
||||
if use_buffer_share and ("out" in name or "present" in name):
|
||||
# Bind present KV cache outputs to past KV cache inputs in order to buffer share
|
||||
input_name = name.replace("out", "cache").replace("present", "past_key_values")
|
||||
io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name])
|
||||
else:
|
||||
io_binding.bind_output(name, device_type=device, device_id=device_id)
|
||||
|
||||
return io_binding, kv_cache_ortvalues
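# Minimal usage sketch (illustrative only): bind CUDA inputs once and run with IO binding.
# The ONNX model path is an assumption; `ort_inputs` is a dict of NumPy arrays such as the
# one produced by convert_inputs_for_ort above.
def _example_ortvalue_binding_sketch(onnx_model_path: str, ort_inputs: dict):
    session = InferenceSession(onnx_model_path, providers=["CUDAExecutionProvider"])
    io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
        session,
        ort_inputs,
        device="cuda",
        device_id=0,
        use_buffer_share=False,
        kv_cache_ortvalues={},
    )
    io_binding.synchronize_inputs()
    session.run_with_iobinding(io_binding)
    io_binding.synchronize_outputs()
    return io_binding.copy_outputs_to_cpu()  # list of NumPy outputs, logits first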
|
||||
|
||||
|
||||
# Add IO bindings for execution providers using PyTorch tensors
|
||||
# Use when you need to run inference many times
|
||||
def add_io_bindings_as_tensors(
|
||||
model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
|
||||
):
|
||||
# Verify model inputs
|
||||
inputs = verify_ort_inputs(model, inputs)
|
||||
|
||||
device = None
|
||||
pt_to_np = {
|
||||
"torch.int32": np.int32,
|
||||
"torch.int64": np.int64,
|
||||
"torch.float16": np.float16,
|
||||
"torch.float32": np.float32,
|
||||
}
|
||||
|
||||
# Bind inputs/outputs to IO binding
|
||||
io_binding = model.io_binding()
|
||||
for k, v in inputs.items():
|
||||
io_binding.bind_input(
|
||||
name=k,
|
||||
device_type=v.device.type,
|
||||
device_id=0 if v.device.type == "cpu" else v.device.index,
|
||||
element_type=pt_to_np[repr(v.dtype)],
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
device = v.device
|
||||
|
||||
for output in model.get_outputs():
|
||||
name = output.name
|
||||
# Bind KV cache outputs to KV cache inputs
|
||||
v = (
|
||||
inputs[name.replace("present", "past_key_values")]
|
||||
if use_buffer_share and "present" in name
|
||||
else outputs[name]
|
||||
)
|
||||
io_binding.bind_output(
|
||||
name=name,
|
||||
device_type=device.type,
|
||||
device_id=0 if device.type == "cpu" else device.index,
|
||||
element_type=(np.float16 if use_fp16 else np.float32),
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
|
||||
return io_binding
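# Minimal usage sketch (illustrative only): when generating many tokens, keep the inputs
# and outputs as device torch tensors, bind them each step, and read the logits straight
# from the pre-allocated output tensor instead of copying back to NumPy.
def _example_tensor_binding_sketch(session: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool):
    io_binding = add_io_bindings_as_tensors(session, inputs, outputs, use_fp16, use_buffer_share=True)
    io_binding.synchronize_inputs()
    session.run_with_iobinding(io_binding)
    io_binding.synchronize_outputs()
    return outputs["logits"]  # already on the target device; no host copy needed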
|
||||
|
||||
|
||||
# Get actual inputs when using real data (instead of sample data) and initialize outputs
|
||||
def get_initial_inputs_and_outputs(
|
||||
config: AutoConfig,
|
||||
tokenizer: AutoTokenizer,
|
||||
requested_length: int,
|
||||
prompt: list[str],
|
||||
device: torch.device,
|
||||
use_fp16: bool,
|
||||
use_buffer_share: bool,
|
||||
engine: str,
|
||||
):
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
|
||||
torch_dtype = torch.float16 if use_fp16 else torch.float32
|
||||
|
||||
# input_ids: pad token id is 0
|
||||
# attention_mask: pad token id is 0
|
||||
# position_ids: pad token id is 1
|
||||
input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
|
||||
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=False)
|
||||
|
||||
# Check if tokenized prompt length matches the requested prompt length
|
||||
tokenized_length = input_ids.shape[-1]
|
||||
if tokenized_length > requested_length:
|
||||
# Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
|
||||
input_ids = input_ids[:, :requested_length]
|
||||
attention_mask = attention_mask[:, :requested_length]
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=False)
|
||||
elif tokenized_length < requested_length:
|
||||
# Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
|
||||
input_ids_first_col = input_ids[:, 0].unsqueeze(0).T
|
||||
attention_mask_first_col = attention_mask[:, 0].unsqueeze(0).T
|
||||
for _ in range(requested_length - tokenized_length):
|
||||
input_ids = torch.hstack((input_ids_first_col, input_ids))
|
||||
attention_mask = torch.hstack((attention_mask_first_col, attention_mask))
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=False)
|
||||
|
||||
tokenized_length = input_ids.shape[-1]
|
||||
assert tokenized_length == requested_length
|
||||
|
||||
# Create inputs
|
||||
inputs = {
|
||||
"input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
|
||||
"attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
|
||||
"position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
|
||||
}
|
||||
if engine != "ort":
|
||||
inputs["past_key_values"] = []
|
||||
|
||||
# Get shape of KV cache inputs
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
max_sequence_length = config.max_position_embeddings
|
||||
num_heads = config.num_key_value_heads
|
||||
head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
||||
|
||||
# Create KV cache inputs
|
||||
for i in range(config.num_hidden_layers):
|
||||
past_key = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_length if use_buffer_share else 0,
|
||||
head_size,
|
||||
device=device,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
past_value = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
max_sequence_length if use_buffer_share else 0,
|
||||
head_size,
|
||||
device=device,
|
||||
dtype=torch_dtype,
|
||||
)
|
||||
if engine == "ort":
|
||||
inputs.update(
|
||||
{
|
||||
f"past_key_values.{i}.key": past_key.contiguous(),
|
||||
f"past_key_values.{i}.value": past_value.contiguous(),
|
||||
}
|
||||
)
|
||||
else:
|
||||
inputs["past_key_values"].append((past_key, past_value))
|
||||
|
||||
outputs = None
|
||||
if engine == "ort":
|
||||
# Create outputs
|
||||
logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
|
||||
outputs = {"logits": logits.contiguous()}
|
||||
if not use_buffer_share:
|
||||
for i in range(config.num_hidden_layers):
|
||||
present_key = torch.zeros(
|
||||
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
|
||||
)
|
||||
present_value = torch.zeros(
|
||||
batch_size, num_heads, sequence_length, head_size, device=device, dtype=torch_dtype
|
||||
)
|
||||
outputs.update(
|
||||
{f"present.{i}.key": present_key.contiguous(), f"present.{i}.value": present_value.contiguous()}
|
||||
)
|
||||
|
||||
return inputs, outputs
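# Usage sketch (illustrative only): prepare real-prompt inputs and pre-allocated outputs
# for an ORT GQA model with buffer sharing. The tokenizer/model id is an assumption.
def _example_initial_inputs_sketch():
    from transformers import AutoConfig, AutoTokenizer  # local imports for self-containment

    model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical model id
    config = AutoConfig.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return get_initial_inputs_and_outputs(
        config,
        tokenizer,
        requested_length=32,
        prompt=["What is the lightest element?"],
        device=torch.device("cuda:0"),
        use_fp16=True,
        use_buffer_share=True,
        engine="ort",
    )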
|
||||
+343
@@ -0,0 +1,343 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import packaging.version as pv
|
||||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from dist_settings import get_rank, get_size
|
||||
from llama_inputs import (
|
||||
add_io_bindings_as_ortvalues,
|
||||
convert_inputs_for_ort,
|
||||
get_merged_sample_with_past_kv_inputs,
|
||||
get_sample_inputs,
|
||||
get_sample_with_past_kv_inputs,
|
||||
verify_ort_inputs,
|
||||
)
|
||||
from llama_torch import setup_torch_model
|
||||
from models.torch_export_patches.cache_helper import make_dynamic_cache
|
||||
from transformers import AutoConfig
|
||||
from transformers import __version__ as transformers_version
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
|
||||
past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
|
||||
max_sequence_length = config.max_position_embeddings
|
||||
return past_sequence_length, curr_sequence_length, max_sequence_length
|
||||
|
||||
|
||||
def get_inputs(args: argparse.Namespace, config: AutoConfig):
|
||||
# Dummy values for parity
|
||||
world_size = get_size()
|
||||
batch_size = 2
|
||||
past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)
|
||||
|
||||
if args.merged:
|
||||
inputs = get_merged_sample_with_past_kv_inputs(
|
||||
config,
|
||||
args.device,
|
||||
batch_size,
|
||||
seq_len=sequence_length,
|
||||
past_seq_len=past_sequence_length,
|
||||
max_seq_len=max_sequence_length,
|
||||
use_fp16=args.use_fp16,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
return_dict=True,
|
||||
world_size=world_size,
|
||||
)
|
||||
elif args.use_past_kv:
|
||||
inputs = get_sample_with_past_kv_inputs(
|
||||
config,
|
||||
args.device,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
use_fp16=args.use_fp16,
|
||||
return_dict=True,
|
||||
world_size=world_size,
|
||||
)
|
||||
else:
|
||||
inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def torch_deepcopy(value):
|
||||
if isinstance(value, (int, float, str)):
|
||||
return value
|
||||
if isinstance(value, tuple):
|
||||
return tuple(torch_deepcopy(v) for v in value)
|
||||
if isinstance(value, list):
|
||||
return [torch_deepcopy(v) for v in value]
|
||||
if isinstance(value, set):
|
||||
return {torch_deepcopy(v) for v in value}
|
||||
if isinstance(value, dict):
|
||||
return {k: torch_deepcopy(v) for k, v in value.items()}
|
||||
if isinstance(value, np.ndarray):
|
||||
return value.copy()
|
||||
if hasattr(value, "clone"):
|
||||
return value.clone()
|
||||
if isinstance(value, DynamicCache):
|
||||
return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
|
||||
# A more complete implementation would fall back to serialization/deserialization,
|
||||
# assuming a model cannot be exported without them.
|
||||
raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
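# Sketch (illustrative only): the PyTorch model mutates KV caches in place during the
# parity run, so the inputs are copied first. torch_deepcopy clones nested containers
# of tensors recursively.
def _example_deepcopy_sketch():
    past = [(torch.rand(1, 8, 4, 64), torch.rand(1, 8, 4, 64)) for _ in range(2)]
    copied = torch_deepcopy(past)
    assert copied[0][0].data_ptr() != past[0][0].data_ptr()  # clones use distinct storage
    return copied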
|
||||
|
||||
|
||||
def verify_parity(
|
||||
args: argparse.Namespace,
|
||||
location: str,
|
||||
use_auth_token: bool,
|
||||
kv_cache_ortvalues: dict,
|
||||
pytorch_model: None | torch.nn.Module = None,
|
||||
config: None | AutoConfig = None,
|
||||
):
|
||||
# On machines with less than 36 GB of GPU memory, unload the PyTorch model from the GPU promptly to free memory for ORT.
|
||||
py_model = pytorch_model
|
||||
if py_model is None:
|
||||
config, py_model = setup_torch_model(
|
||||
args,
|
||||
location,
|
||||
use_auth_token,
|
||||
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
inputs = get_inputs(args, config)
|
||||
|
||||
if "past_key_values" in inputs and pv.Version(transformers_version) >= pv.Version("4.45"):
|
||||
# Using DynamicCache
|
||||
inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"])
|
||||
|
||||
# Run inference with PyTorch
|
||||
inputs_after_deepcopy = torch_deepcopy(inputs)
|
||||
if args.execution_provider != "cpu":
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
# If there is a cache in the inputs, we need to make a copy because the model modifies it in place.
|
||||
# DynamicCache inherits from torch.nn.Module in some versions of transformers.
|
||||
# We need to make the copy manually.
|
||||
pt_outputs = py_model(**inputs_after_deepcopy).logits.detach().cpu().numpy()
|
||||
if args.execution_provider != "cpu":
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
logger.info(f"PyTorch took {end_time - start_time} s")
|
||||
|
||||
if args.small_gpu and py_model is not None:
|
||||
del py_model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Run inference with ORT
|
||||
past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
|
||||
inputs = convert_inputs_for_ort(
|
||||
inputs,
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
past_seq_len=past_sequence_length,
|
||||
max_seq_len=max_sequence_length,
|
||||
)
|
||||
|
||||
ep = f"{args.execution_provider.upper()}ExecutionProvider"
|
||||
if ep == "CUDAExecutionProvider":
|
||||
ep = (ep, {"device_id": args.rank})
|
||||
ort_model = ort.InferenceSession(
|
||||
args.onnx_model_path,
|
||||
sess_options=ort.SessionOptions(),
|
||||
providers=[ep],
|
||||
)
|
||||
inputs = verify_ort_inputs(ort_model, inputs)
|
||||
|
||||
# Add IO bindings for non-CPU execution providers
|
||||
if args.execution_provider != "cpu":
|
||||
io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
|
||||
ort_model,
|
||||
ort_inputs=inputs,
|
||||
device=args.execution_provider,
|
||||
device_id=int(args.rank),
|
||||
use_buffer_share=args.use_buffer_share,
|
||||
kv_cache_ortvalues=kv_cache_ortvalues,
|
||||
)
|
||||
|
||||
io_binding.synchronize_inputs()
|
||||
start_time = time.time()
|
||||
ort_model.run_with_iobinding(io_binding)
|
||||
io_binding.synchronize_outputs()
|
||||
end_time = time.time()
|
||||
|
||||
ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits
|
||||
del ort_model
|
||||
|
||||
else:
|
||||
start_time = time.time()
|
||||
ort_outputs = ort_model.run(None, inputs)
|
||||
end_time = time.time()
|
||||
|
||||
ort_outputs = ort_outputs[0] # Get logits
|
||||
|
||||
logger.info(f"ONNX Runtime took {end_time - start_time} s")
|
||||
|
||||
# Compare PyTorch and ONNX Runtime accuracy
|
||||
tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
|
||||
parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
|
||||
logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
|
||||
if not parity:
|
||||
logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
|
||||
return kv_cache_ortvalues
|
||||
|
||||
|
||||
def get_args(argv: list[str]):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name",
|
||||
required=False,
|
||||
help="Model name in Hugging Face",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--torch_model_directory",
|
||||
required=False,
|
||||
default=os.path.join("."),
|
||||
help="Path to folder containing PyTorch model and associated files if saved on disk",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--onnx_model_path",
|
||||
required=True,
|
||||
default=os.path.join("."),
|
||||
help="Path to ONNX model (with external data files saved in the same folder as the model)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-ep",
|
||||
"--execution_provider",
|
||||
required=False,
|
||||
default="cpu",
|
||||
choices=["cpu", "cuda", "rocm"],
|
||||
help="Execution provider to verify parity with",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Print verbose logs",
|
||||
)
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--use_past_kv",
|
||||
action="store_true",
|
||||
help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
|
||||
)
|
||||
parser.set_defaults(use_past_kv=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--use_buffer_share",
|
||||
action="store_true",
|
||||
help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
|
||||
)
|
||||
parser.set_defaults(use_buffer_share=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--merged",
|
||||
action="store_true",
|
||||
help="Use merged model (i.e. decoder_merged_model.onnx).",
|
||||
)
|
||||
parser.set_defaults(merged=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-fp",
|
||||
"--precision",
|
||||
required=True,
|
||||
choices=["int4", "int8", "fp16", "fp32"],
|
||||
help="Precision of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./model_cache",
|
||||
help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
|
||||
)
|
||||
|
||||
# This argument is mainly used for CI, because the CI machine has at most 24 GB of GPU memory.
|
||||
parser.add_argument(
|
||||
"--small_gpu",
|
||||
action="store_true",
|
||||
help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. ",
|
||||
)
|
||||
|
||||
args = parser.parse_args() if argv == [] else parser.parse_args(argv)
|
||||
|
||||
# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
|
||||
args.precision = (
|
||||
"fp32"
|
||||
if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
|
||||
else "fp16"
|
||||
)
|
||||
return args
|
||||
|
||||
|
||||
def main(argv: list[str] = []): # noqa: B006
|
||||
args = get_args(argv)
|
||||
setup_logger(args.verbose)
|
||||
logger.info(f"Arguments: {args}")
|
||||
rank = get_rank()
|
||||
|
||||
# Load model and config
|
||||
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010
|
||||
args.rank = rank
|
||||
setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
|
||||
setattr(args, "device", torch.device(args.device_name)) # noqa: B010
|
||||
use_auth_token = args.torch_model_directory == os.path.join(".")
|
||||
location = args.model_name if use_auth_token else args.torch_model_directory
|
||||
|
||||
kv_cache_ortvalues = {}
|
||||
if not args.merged:
|
||||
verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
|
||||
else:
|
||||
config = llama = None
|
||||
if not args.small_gpu:
|
||||
config, llama = setup_torch_model(
|
||||
args,
|
||||
location,
|
||||
use_auth_token,
|
||||
torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
# Verify prompt processing in merged model (decoder_model.onnx)
|
||||
args.use_past_kv = False
|
||||
kv_cache_ortvalues = verify_parity(
|
||||
args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
|
||||
)
|
||||
|
||||
# Verify token generation in merged model (decoder_with_past_model.onnx)
|
||||
args.use_past_kv = True
|
||||
verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed = 2
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
main()
|
||||
+47
@@ -0,0 +1,47 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from dist_settings import barrier, get_rank, get_size
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=None):
|
||||
world_size = get_size()
|
||||
logger.info(f"world_size: {world_size}")
|
||||
rank = get_rank()
|
||||
barrier()
|
||||
|
||||
if not os.path.exists(args.cache_dir):
|
||||
os.makedirs(args.cache_dir, exist_ok=True)
|
||||
|
||||
for i in range(world_size):
|
||||
if i == rank % (world_size):
|
||||
l_config = AutoConfig.from_pretrained(
|
||||
location, use_auth_token=auth, cache_dir=args.cache_dir, trust_remote_code=auth
|
||||
)
|
||||
l_config.use_cache = True
|
||||
l_config._attn_implementation = "eager" # "eager" uses LlamaAttention for attention layer
|
||||
llama = AutoModelForCausalLM.from_pretrained(
|
||||
location,
|
||||
use_auth_token=auth,
|
||||
trust_remote_code=auth,
|
||||
config=l_config,
|
||||
torch_dtype=torch_dtype,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
if world_size > 1:
|
||||
llama.parallel_model()
|
||||
if device:
|
||||
llama.to(device)
|
||||
llama.eval()
|
||||
llama.requires_grad_(False)
|
||||
barrier()
|
||||
return l_config, llama
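# Usage sketch (illustrative only): load a Llama checkpoint for parity checks. The args
# namespace only needs the cache_dir field used above; the model id is an assumption.
def _example_setup_sketch():
    import argparse  # local import keeps the sketch self-contained

    args = argparse.Namespace(cache_dir="./model_cache")
    config, llama = setup_torch_model(
        args,
        "meta-llama/Llama-2-7b-hf",  # hypothetical model id
        auth=True,
        torch_dtype=torch.float16,
        device=torch.device("cuda:0"),
    )
    return config, llama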
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from benchmark_helper import create_onnxruntime_session
|
||||
from datasets import load_dataset
|
||||
from llama_inputs import get_position_ids
|
||||
from torch.nn.functional import pad
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
|
||||
class QuantKVDataLoader:
|
||||
def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
|
||||
self.batch_size = 1
|
||||
self.pad_max = args.pad_max
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
|
||||
dataset = load_dataset(args.smooth_quant_dataset, split="train")
|
||||
dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
|
||||
dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
|
||||
|
||||
self.dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=self.collate_batch,
|
||||
)
|
||||
self.decoder_model = (
|
||||
create_onnxruntime_session(
|
||||
onnx_model_path,
|
||||
args.execution_provider != "cpu", # use_gpu
|
||||
provider=args.execution_provider,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
if onnx_model_path
|
||||
else None
|
||||
)
|
||||
|
||||
def collate_batch(self, batch):
|
||||
input_ids_batched = []
|
||||
attention_mask_batched = []
|
||||
position_ids_batched = []
|
||||
labels = []
|
||||
|
||||
for text in batch:
|
||||
# Set inputs for model
|
||||
input_ids = text["input_ids"]
|
||||
attention_mask = torch.ones(len(input_ids))
|
||||
position_ids = get_position_ids(attention_mask, use_past_kv=False)
|
||||
label = len(input_ids) - 1
|
||||
|
||||
# Pad input data because all model inputs must have same shape
|
||||
pad_len = self.pad_max - input_ids.shape[0]
|
||||
input_ids = pad(input_ids, (0, pad_len), value=1)
|
||||
attention_mask = pad(attention_mask, (0, pad_len), value=0)
|
||||
position_ids = pad(position_ids, (0, pad_len), value=0)
|
||||
|
||||
input_ids_batched.append(input_ids)
|
||||
attention_mask_batched.append(attention_mask)
|
||||
position_ids_batched.append(position_ids)
|
||||
labels.append(label)
|
||||
|
||||
input_ids_batched = torch.vstack(input_ids_batched)
|
||||
attention_mask_batched = torch.vstack(attention_mask_batched)
|
||||
position_ids_batched = torch.vstack(position_ids_batched)
|
||||
labels = torch.tensor(labels)
|
||||
|
||||
return (input_ids_batched, attention_mask_batched, position_ids_batched), labels
|
||||
|
||||
def __iter__(self):
|
||||
try:
|
||||
for (input_ids, attention_mask, position_ids), labels in self.dataloader:
|
||||
# Inputs for decoder_model.onnx
|
||||
inputs = {
|
||||
"input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
|
||||
"attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
|
||||
"position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
|
||||
}
|
||||
label = labels.detach().cpu().numpy()
|
||||
|
||||
if self.decoder_model is not None:
|
||||
# Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
|
||||
outputs = self.decoder_model.run(None, inputs)
|
||||
|
||||
for i in range(int((len(outputs) - 1) / 2)):
|
||||
inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
|
||||
inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
|
||||
past_sequence_length = inputs["past_key_values.0.key"].shape[2]
|
||||
|
||||
inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
|
||||
attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
|
||||
inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
|
||||
inputs["position_ids"] = (
|
||||
get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
|
||||
)
|
||||
|
||||
# Yield (inputs, label) tuple for Intel's Neural Compressor:
|
||||
# https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
|
||||
yield (inputs, label)
|
||||
|
||||
except StopIteration:
|
||||
return
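# Usage sketch (illustrative only): iterate calibration batches for static quantization.
# The model id, calibration dataset, and pad length below are assumptions.
def _example_quant_dataloader_sketch():
    args = argparse.Namespace(
        pad_max=196,
        original_model_name="meta-llama/Llama-2-7b-hf",  # hypothetical model id
        use_auth_token=False,
        smooth_quant_dataset="NeelNanda/pile-10k",  # hypothetical calibration dataset
        execution_provider="cpu",
        verbose=False,
    )
    loader = QuantKVDataLoader(args)  # no ONNX model: yields prompt-processing inputs only
    inputs, label = next(iter(loader))
    return sorted(inputs.keys()), label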
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.append(transformers_dir)
|
||||
+821
@@ -0,0 +1,821 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# This script runs benchmarks of latency or peak memory usage for Longformer model inference.
|
||||
# Please run convert_to_onnx.py to get an ONNX model before running the benchmark.
|
||||
#
|
||||
# It is tested with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.11.0, transformers 4.18.0, CUDA 11.3 like:
|
||||
# conda create -n gpu_env python=3.8
|
||||
# conda activate gpu_env
|
||||
# pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
# pip3 install onnx transformers onnxruntime-gpu numpy sympy coloredlogs psutil py3nvml
|
||||
# python benchmark_longformer.py
|
||||
#
|
||||
# When no arguments are given, pre-defined tests run on the longformer-base-4096 model.
|
||||
|
||||
# Benchmark the latency:
|
||||
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 512 1024 2048 4096 \
|
||||
# --global_lengths 8 --onnx ./longformer-base-4096_fp16.onnx -t 100
|
||||
#
|
||||
# Benchmark GPU peak memory:
|
||||
# export ORT_LONGFORMER_COMPACT_MEMORY=0
|
||||
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
|
||||
# --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
|
||||
# export ORT_LONGFORMER_COMPACT_MEMORY=1
|
||||
# python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
|
||||
# --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
|
||||
#
|
||||
# By default, compact memory kernel is enabled. To disable it, set environment variable ORT_LONGFORMER_COMPACT_MEMORY=0.
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import timeit
|
||||
import traceback
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import benchmark_helper
|
||||
import numpy as np
|
||||
import torch
|
||||
from longformer_helper import PRETRAINED_LONGFORMER_MODELS, LongformerHelper, LongformerInputs
|
||||
from transformers import LongformerModel
|
||||
|
||||
import onnxruntime
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def test_torch_latency(
|
||||
device,
|
||||
model,
|
||||
model_name,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
global_lengths,
|
||||
test_times,
|
||||
num_threads,
|
||||
) -> list[dict[str, Any]]:
|
||||
if num_threads > 0:
|
||||
torch.set_num_threads(num_threads)
|
||||
|
||||
results = []
|
||||
for batch_size in batch_sizes:
|
||||
for sequence_length in sequence_lengths:
|
||||
for global_length in global_lengths:
|
||||
logger.info(f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}")
|
||||
inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
|
||||
batch_size, sequence_length, global_length, device
|
||||
)
|
||||
input_list = inputs.to_list()
|
||||
|
||||
_ = model(*input_list)
|
||||
runtimes = timeit.repeat(lambda: model(*input_list), repeat=test_times, number=1) # noqa: B023
|
||||
result = {
|
||||
"engine": "torch", # TODO: test torchscript
|
||||
"version": torch.__version__,
|
||||
"device": "cuda",
|
||||
"optimizer": "",
|
||||
"precision": "fp32",
|
||||
"io_binding": "",
|
||||
"model_name": model_name,
|
||||
"description": model_name + " [torch]",
|
||||
"inputs": 3,
|
||||
"threads": num_threads,
|
||||
"batch_size": batch_size,
|
||||
"sequence_length": sequence_length,
|
||||
"global_length": global_length,
|
||||
"datetime": str(datetime.now()),
|
||||
"memory": "NA",
|
||||
"diff_max": 0,
|
||||
"diff_90_percentile": 0,
|
||||
"diff_95_percentile": 0,
|
||||
"diff_99_percentile": 0,
|
||||
"use_compact_memory": "NA",
|
||||
}
|
||||
result.update(benchmark_helper.get_latency_result(runtimes, batch_size))
|
||||
logger.info("%s", result)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
def test_parity(device, model, ort_session, batch_size, sequence_length, global_length, verbose=True):
|
||||
parameters = f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}"
|
||||
logger.info(f"Comparing Torch and ORT outputs for {parameters}...")
|
||||
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
|
||||
batch_size, sequence_length, global_length, device
|
||||
)
|
||||
ort_inputs = dummy_inputs.get_ort_inputs()
|
||||
ort_outputs = ort_session.run(None, ort_inputs)
|
||||
input_list = dummy_inputs.to_list()
|
||||
torch_outputs = model(*input_list)
|
||||
max_diff = np.amax(torch_outputs[0].cpu().numpy() - ort_outputs[0])
|
||||
logger.info(f"last_state max diff = {max_diff}")
|
||||
if verbose and (math.isnan(max_diff) or max_diff > 0.001):
|
||||
print("torch last_state:", torch_outputs[0])
|
||||
print("ort last_state:", ort_outputs[0])
|
||||
return float(max_diff)
|
||||
|
||||
|
||||
def test_ort_latency(
|
||||
device,
|
||||
model,
|
||||
model_name,
|
||||
description,
|
||||
ort_session,
|
||||
batch_sizes,
|
||||
sequence_lengths,
|
||||
global_lengths,
|
||||
test_times,
|
||||
num_threads,
|
||||
optimizer=False,
|
||||
precision="fp32",
|
||||
disable_io_binding=False,
|
||||
verbose=True,
|
||||
use_compact_memory=False,
|
||||
use_half4=False,
|
||||
disable_parity=False,
|
||||
) -> list[dict[str, Any]]:
|
||||
results = []
|
||||
for batch_size in batch_sizes:
|
||||
for sequence_length in sequence_lengths:
|
||||
for global_length in global_lengths:
|
||||
assert global_length <= model.config.attention_window[0], (
|
||||
"Limitation of current implementation: number of global token <= attention_window"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Testing batch_size={batch_size} sequence_length={sequence_length} global_length={global_length} "
|
||||
f"optimizer={optimizer}, precision={precision} io_binding={not disable_io_binding}..."
|
||||
)
|
||||
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
|
||||
batch_size, sequence_length, global_length, device
|
||||
)
|
||||
|
||||
# Run OnnxRuntime
|
||||
ort_inputs = dummy_inputs.get_ort_inputs()
|
||||
|
||||
if verbose:
|
||||
print(ort_inputs)
|
||||
|
||||
# run one query for warm up
|
||||
ort_outputs = ort_session.run(None, ort_inputs)
|
||||
|
||||
result_template = {
|
||||
"model_name": model_name,
|
||||
"description": description,
|
||||
"inputs": 3,
|
||||
"engine": "OnnxRuntime",
|
||||
"version": str(onnxruntime.__version__),
|
||||
"device": "cuda",
|
||||
"precision": str(precision),
|
||||
"optimizer": int(optimizer),
|
||||
"threads": int(num_threads),
|
||||
"batch_size": int(batch_size),
|
||||
"sequence_length": int(sequence_length),
|
||||
"global_length": int(global_length),
|
||||
"test_times": int(test_times),
|
||||
"datetime": str(datetime.now()),
|
||||
"memory": "",
|
||||
"diff_max": None,
|
||||
"diff_90_percentile": None,
|
||||
"diff_95_percentile": None,
|
||||
"diff_99_percentile": None,
|
||||
"use_compact_memory": use_compact_memory,
|
||||
"use_half4": use_half4,
|
||||
}
|
||||
|
||||
if not disable_io_binding:
|
||||
max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size
|
||||
max_pooler_size = max(batch_sizes) * max(sequence_lengths)
|
||||
result = benchmark_helper.inference_ort_with_io_binding(
|
||||
ort_session,
|
||||
ort_inputs,
|
||||
result_template=result_template,
|
||||
repeat_times=test_times,
|
||||
ort_output_names=["last_state", "pooler"],
|
||||
ort_outputs=ort_outputs,
|
||||
output_buffers=[],
|
||||
output_buffer_max_sizes=[max_last_state_size, max_pooler_size],
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
data_type=np.longlong, # input data type
|
||||
)
|
||||
else:
|
||||
result = benchmark_helper.inference_ort(
|
||||
ort_session,
|
||||
ort_inputs,
|
||||
result_template=result_template,
|
||||
repeat_times=test_times,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
# measure result difference between PyTorch and OnnxRuntime
|
||||
if not disable_parity:
|
||||
diff_results = [
|
||||
test_parity(
|
||||
device,
|
||||
model,
|
||||
ort_session,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
global_length,
|
||||
verbose,
|
||||
)
|
||||
for _ in range(test_times)
|
||||
]
|
||||
|
||||
result["diff_max"] = max(diff_results)
|
||||
result["diff_90_percentile"] = np.percentile(diff_results, 90)
|
||||
result["diff_95_percentile"] = np.percentile(diff_results, 95)
|
||||
result["diff_99_percentile"] = np.percentile(diff_results, 99)
|
||||
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
def test_ort_memory(
|
||||
device,
|
||||
onnx_model_path,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
global_length,
|
||||
test_times,
|
||||
num_threads,
|
||||
) -> dict[str, Any]:
|
||||
logger.info(
|
||||
f"Testing memory for model={onnx_model_path}, batch_size={batch_size}, sequence_length={sequence_length}, "
|
||||
f"global_length={global_length}, test_times={test_times}, num_threads={num_threads}"
|
||||
)
|
||||
|
||||
def inference():
|
||||
# Update Arena strategy so that we can measure the minimum memory required
|
||||
cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
|
||||
provider_options = {"CUDAExecutionProvider": cuda_provider_options}
|
||||
session = benchmark_helper.create_onnxruntime_session(
|
||||
onnx_model_path,
|
||||
use_gpu=True,
|
||||
enable_all_optimization=True,
|
||||
num_threads=num_threads,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
|
||||
dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
|
||||
batch_size, sequence_length, global_length, device
|
||||
)
|
||||
ort_inputs = dummy_inputs.get_ort_inputs()
|
||||
for _ in range(test_times):
|
||||
_ = session.run(None, ort_inputs)
|
||||
|
||||
memory_used = benchmark_helper.measure_memory(is_gpu=True, func=inference)
|
||||
|
||||
return {
|
||||
"onnx_model": onnx_model_path,
|
||||
"batch_size": batch_size,
|
||||
"sequence_length": sequence_length,
|
||||
"global_length": global_length,
|
||||
"test_times": test_times,
|
||||
"num_threads": num_threads,
|
||||
"memory": memory_used,
|
||||
}
|
||||
|
||||
|
||||
def load_torch_model(model_name, device):
|
||||
torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS.get(model_name, model_name)
|
||||
model = LongformerModel.from_pretrained(torch_model_name_or_dir)
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
def find_onnx_model(model_name, onnx_dir="."):
|
||||
# Search onnx model in the following order: optimized fp16 model, optimized fp32 model, raw model
|
||||
onnx_model_path = os.path.join(onnx_dir, model_name + ".onnx")
|
||||
optimized_fp32_model = os.path.join(onnx_dir, model_name + "_fp32.onnx")
|
||||
optimized_fp16_model = os.path.join(onnx_dir, model_name + "_fp16.onnx")
|
||||
if os.path.isfile(optimized_fp16_model):
|
||||
onnx_model_path = optimized_fp16_model
|
||||
elif os.path.isfile(optimized_fp32_model):
|
||||
onnx_model_path = optimized_fp32_model
|
||||
return onnx_model_path
|
||||
|
||||
|
||||
def test_memory(args, device) -> dict[str, Any]:
|
||||
if len(args.batch_sizes) > 1:
|
||||
raise RuntimeError("For memory test, only one batch_size (-b) is allowed.")
|
||||
if len(args.sequence_lengths) > 1:
|
||||
raise RuntimeError("For memory test, only one sequence_length (-s) is allowed.")
|
||||
if len(args.global_lengths) > 1:
|
||||
raise RuntimeError("For memory test, only one global_length (-g) is allowed.")
|
||||
|
||||
model_name = args.model
|
||||
onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
return test_ort_memory(
|
||||
device,
|
||||
onnx_model_path,
|
||||
args.batch_sizes[0],
|
||||
args.sequence_lengths[0],
|
||||
args.global_lengths[0],
|
||||
args.test_times,
|
||||
args.num_threads,
|
||||
)
|
||||
|
||||
|
||||
def test_ort(args, device) -> list[dict[str, Any]]:
|
||||
model_name = args.model
|
||||
|
||||
onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
|
||||
|
||||
optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx") # noqa: PIE810
|
||||
precision = "fp32" if not onnx_model_path.endswith("_fp16.onnx") else "fp16"
|
||||
|
||||
model = load_torch_model(model_name, device)
|
||||
|
||||
num_threads = args.num_threads
|
||||
|
||||
cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
|
||||
provider_options = {"CUDAExecutionProvider": cuda_provider_options}
|
||||
session = benchmark_helper.create_onnxruntime_session(
|
||||
onnx_model_path,
|
||||
use_gpu=True,
|
||||
enable_all_optimization=True,
|
||||
num_threads=num_threads,
|
||||
provider_options=provider_options,
|
||||
)
|
||||
if session is None:
|
||||
raise RuntimeError(f"Failed to create ORT session from ONNX file {onnx_model_path}")
|
||||
|
||||
use_compact_memory = os.environ.get("ORT_LONGFORMER_COMPACT_MEMORY", "1") == "1"
|
||||
description = onnx_model_path
|
||||
if not use_compact_memory:
|
||||
description += "[non_compact_memory]"
|
||||
|
||||
if args.use_half4:
|
||||
description += "[half4]" if precision == "fp16" else "[float4]"
|
||||
else:
|
||||
description += "[half2]" if precision == "fp16" else "[float4]"
|
||||
|
||||
return test_ort_latency(
|
||||
device,
|
||||
model,
|
||||
model_name,
|
||||
description,
|
||||
session,
|
||||
args.batch_sizes,
|
||||
args.sequence_lengths,
|
||||
args.global_lengths,
|
||||
args.test_times,
|
||||
num_threads,
|
||||
optimized,
|
||||
precision,
|
||||
args.disable_io_binding,
|
||||
args.verbose,
|
||||
use_compact_memory,
|
||||
args.use_half4,
|
||||
args.disable_parity,
|
||||
)
|
||||
|
||||
|
||||
def test_torch(args, device) -> list[dict[str, Any]]:
|
||||
model = load_torch_model(args.model, device)
|
||||
return test_torch_latency(
|
||||
device,
|
||||
model,
|
||||
args.model,
|
||||
args.batch_sizes,
|
||||
args.sequence_lengths,
|
||||
args.global_lengths,
|
||||
args.test_times,
|
||||
args.num_threads,
|
||||
)
|
||||
|
||||
|
||||
def test_latency(args, device) -> list[dict[str, Any]]:
|
||||
if args.engine == "onnxruntime":
|
||||
return test_ort(args, device)
|
||||
|
||||
return test_torch(args, device)
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
required=False,
|
||||
type=str,
|
||||
default="longformer-base-4096",
|
||||
help="Checkpoint directory or pre-trained model names in the list: "
|
||||
+ ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--engine",
|
||||
required=False,
|
||||
type=str,
|
||||
default="onnxruntime",
|
||||
choices=["onnxruntime", "torch"],
|
||||
help="Engine to benchmark.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--test_times",
|
||||
required=False,
|
||||
default=1000,
|
||||
type=int,
|
||||
help="Number of repeat times to get average inference latency.",
|
||||
)
|
||||
|
||||
parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])
|
||||
|
||||
# If --export_padding is not used when exporting the ONNX model, the model has no padding logic,
|
||||
# and you will need to pad the inputs yourself before running the ONNX model.
|
||||
# Here, we only test sequence lengths that are multiples of the attention window size.
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--sequence_lengths",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[512, 1024, 2048, 4096],
|
||||
help="Sequence lengths. It could have multiple values in latency test."
|
||||
"If --export_padding is not used, sequence length shall be multiple of window size.",
|
||||
)
|
||||
|
||||
parser.add_argument("--onnx", required=False, type=str, default=None, help="Onnx model path")
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--global_lengths",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[0],
|
||||
help="Number of global tokens. It could have multiple values in latency test.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--num_threads",
|
||||
required=False,
|
||||
type=int,
|
||||
default=0,
|
||||
help="Threads to use.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable_io_binding",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Do not use IO Binding.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--memory",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Test memory usage instead of latency.",
|
||||
)
|
||||
|
||||
parser.add_argument("--verbose", required=False, action="store_true", help="Print more information.")
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument("--use_half4", required=False, action="store_true", help="Use half4 kernel.")
|
||||
parser.set_defaults(use_half4=False)
|
||||
|
||||
parser.add_argument("--disable_parity", required=False, action="store_true", help="Do not run parity test.")
|
||||
parser.set_defaults(disable_parity=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def output_details(results, csv_filename):
|
||||
latency_results = [result for result in results if "average_latency_ms" in result]
|
||||
if len(latency_results) == 0:
|
||||
print("No latency results for output.")
|
||||
return
|
||||
|
||||
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
|
||||
column_names = [
|
||||
"engine",
|
||||
"version",
|
||||
"device",
|
||||
"precision",
|
||||
"optimizer",
|
||||
"io_binding",
|
||||
"model_name",
|
||||
"inputs",
|
||||
"threads",
|
||||
"datetime",
|
||||
"test_times",
|
||||
"description",
|
||||
"batch_size",
|
||||
"sequence_length",
|
||||
"global_length",
|
||||
"use_compact_memory",
|
||||
"use_half4",
|
||||
"diff_max",
|
||||
"diff_90_percentile",
|
||||
"diff_95_percentile",
|
||||
"diff_99_percentile",
|
||||
"memory",
|
||||
"QPS",
|
||||
"average_latency_ms",
|
||||
"latency_variance",
|
||||
"latency_90_percentile",
|
||||
"latency_95_percentile",
|
||||
"latency_99_percentile",
|
||||
]
|
||||
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
|
||||
csv_writer.writeheader()
|
||||
for result in latency_results:
|
||||
print(result)
|
||||
csv_writer.writerow(result)
|
||||
|
||||
csv_file.flush()
|
||||
|
||||
print(f"Detail results are saved to csv file: {csv_filename}")
|
||||
|
||||
|
||||
def run(args) -> list[dict[str, Any]]:
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# set random seed manually to get deterministic results
|
||||
benchmark_helper.set_random_seed(123)
|
||||
|
||||
# Currently, the Longformer attention operator can only run on GPU (no CPU implementation yet).
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
if args.memory:
|
||||
return [test_memory(args, device)] # Convert to List so that return type is same as test_latency
|
||||
|
||||
return test_latency(args, device)
|
||||
|
||||
|
||||
def launch_test(arguments) -> list[dict[str, Any]]:
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
|
||||
|
||||
with ProcessPoolExecutor() as executor:
|
||||
results = list(executor.map(run, [arguments]))
|
||||
assert len(results) == 1
|
||||
return results[0]
|
||||
|
||||
|
||||
def run_tests(
|
||||
use_compact_memory=True,
|
||||
run_torch=False,
|
||||
run_memory=True,
|
||||
use_io_binding=True,
|
||||
use_fp16=True,
|
||||
use_merged_qkv_weights=True,
|
||||
use_half4=True,
|
||||
batch_size=1,
|
||||
):
|
||||
compact_memory = "1" if use_compact_memory else "0"
|
||||
os.environ["ORT_LONGFORMER_COMPACT_MEMORY"] = compact_memory
|
||||
logger.info(f"ORT_LONGFORMER_COMPACT_MEMORY={compact_memory}")
|
||||
|
||||
os.environ["ORT_LONGFORMER_USE_HALF4"] = "1" if use_half4 else "0"
|
||||
logger.info("ORT_LONGFORMER_USE_HALF4={}".format("1" if use_half4 else "0")) # noqa: G001
|
||||
|
||||
results = []
|
||||
test_times = 1000
|
||||
sequence_lengths = [4096, 2048, 1024, 512]
|
||||
batch_sizes = [batch_size]
|
||||
for model_name in ["longformer-base-4096"]:
|
||||
for batch_size in batch_sizes:
|
||||
for sequence_length in sequence_lengths:
|
||||
for global_length in [16]:
|
||||
if run_torch:
|
||||
engine_name = "torch"
|
||||
args = parse_arguments(
|
||||
f"-e {engine_name} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} "
|
||||
f"-t {test_times} -m {model_name}".split(" ")
|
||||
)
|
||||
results += run(args)
|
||||
|
||||
engine_name = "onnxruntime"
|
||||
file_format = 1 if use_merged_qkv_weights else 0
|
||||
onnx_path = (
|
||||
f"{model_name}_f{file_format}_fp16.onnx"
|
||||
if use_fp16
|
||||
else f"{model_name}_f{file_format}_fp32.onnx"
|
||||
)
|
||||
if not os.path.exists(onnx_path):
|
||||
raise RuntimeError(f"onnx file not exists:{onnx_path}")
|
||||
|
||||
arguments = (
|
||||
f"-e {engine_name} --onnx {onnx_path} "
|
||||
f"-b {batch_size} -s {sequence_length} -g {global_length} -m {model_name}"
|
||||
)
|
||||
|
||||
if not use_io_binding:
|
||||
arguments += " --disable_io_binding"
|
||||
|
||||
if use_half4:
|
||||
arguments += " --use_half4"
|
||||
|
||||
# Disable parity test to avoid out of memory for large batch size
|
||||
if batch_size >= 4:
|
||||
arguments += " --disable_parity"
|
||||
|
||||
memory_results = None
|
||||
try:
|
||||
if run_memory:
|
||||
args = parse_arguments(f"{arguments} -t 10 --memory".split(" "))
|
||||
memory_results = launch_test(args)
|
||||
|
||||
args = parse_arguments(f"{arguments} -t {test_times}".split(" "))
|
||||
latency_results = launch_test(args)
|
||||
except KeyboardInterrupt as exc:
|
||||
raise RuntimeError("Keyboard Interrupted") from exc
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
if len(latency_results) == 1:
|
||||
latency_results[0]["memory"] = memory_results[0]["memory"] if memory_results else "N/A"
|
||||
else:
|
||||
raise RuntimeError("length of latency_results should be 1")
|
||||
|
||||
logger.info("%s", latency_results)
|
||||
|
||||
results += latency_results
|
||||
return results
|
||||
|
||||
|
||||
def output_summary(results, csv_filename, data_field="average_latency_ms"):
|
||||
with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
|
||||
header_names = [
|
||||
"model_name",
|
||||
"precision",
|
||||
"engine",
|
||||
"version",
|
||||
"global_length",
|
||||
"use_compact_memory",
|
||||
"use_half4",
|
||||
"description",
|
||||
]
|
||||
|
||||
description_list = list({result["description"] for result in results})
|
||||
description_list.sort()
|
||||
|
||||
batch_sizes = list({result["batch_size"] for result in results})
|
||||
batch_sizes.sort()
|
||||
|
||||
sequence_lengths = list({result["sequence_length"] for result in results})
|
||||
sequence_lengths.sort()
|
||||
|
||||
data_names = []
|
||||
for sequence_length in sequence_lengths:
|
||||
for batch_size in batch_sizes:
|
||||
data_names.append(f"b{batch_size}_s{sequence_length}")
|
||||
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
|
||||
csv_writer.writeheader()
|
||||
|
||||
for description in description_list:
|
||||
row = {}
|
||||
|
||||
sum_latency = {}
|
||||
sum_latency.update(dict.fromkeys(data_names, 0))
|
||||
|
||||
count_latency = {}
|
||||
count_latency.update(dict.fromkeys(data_names, 0))
|
||||
|
||||
for result in results:
|
||||
if result["description"] == description and result[data_field]:
|
||||
headers = {k: v for k, v in result.items() if k in header_names}
|
||||
if not row:
|
||||
row.update(headers)
|
||||
else:
|
||||
for k in header_names:
|
||||
if row[k] != headers[k]:
|
||||
raise RuntimeError("Description shall be unique")
|
||||
|
||||
batch_size = result["batch_size"]
|
||||
sequence_length = result["sequence_length"]
|
||||
key = f"b{batch_size}_s{sequence_length}"
|
||||
|
||||
try:
|
||||
latency = float(result[data_field])
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
sum_latency[key] += latency
|
||||
count_latency[key] += 1
|
||||
|
||||
if row:
|
||||
for key in data_names:
|
||||
if key in count_latency and count_latency[key] > 0:
|
||||
row[key] = sum_latency[key] / count_latency[key]
|
||||
else:
|
||||
row[key] = ""
|
||||
|
||||
csv_writer.writerow(row)
|
||||
|
||||
csv_file.flush()
|
||||
|
||||
|
||||
def run_experiments(use_fp16, batch_size, is_baseline=False):
|
||||
"""Run experiments to compare different algorithms on one batch size"""
|
||||
test_results = run_tests(
|
||||
use_fp16=use_fp16,
|
||||
use_merged_qkv_weights=True,
|
||||
use_half4=False,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
if is_baseline:
|
||||
return test_results
|
||||
|
||||
if use_fp16:
|
||||
test_results += run_tests(
|
||||
use_fp16=use_fp16,
|
||||
use_merged_qkv_weights=True,
|
||||
use_half4=True,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
test_results += run_tests(
|
||||
use_fp16=use_fp16,
|
||||
use_merged_qkv_weights=False,
|
||||
use_half4=True,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
test_results += run_tests(
|
||||
use_fp16=use_fp16,
|
||||
use_merged_qkv_weights=False,
|
||||
use_half4=False,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
return test_results
|
||||
|
||||
|
||||
def main():
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
benchmark_helper.setup_logger(args.verbose)
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
test_results = launch_test(args)
|
||||
time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
csv_filename = f"benchmark_detail_{time_stamp}.csv"
|
||||
output_details(test_results, csv_filename)
|
||||
return
|
||||
|
||||
gpu_list = benchmark_helper.get_gpu_info()
|
||||
logger.info("GPU info: %s", gpu_list)
|
||||
fp16_batch_sizes = [16, 8, 4, 2, 1]
|
||||
fp32_batch_sizes = [4, 2, 1]
|
||||
if gpu_list and gpu_list[0]["total"] >= 32 * 1024 * 1024 * 1024: # 32 GB
|
||||
fp16_batch_sizes = [64, 32, 16, 8, 4, 2, 1]
|
||||
fp32_batch_sizes = [16, 8, 4, 2, 1]
|
||||
|
||||
gpu_name = re.sub(r"(?u)[^-\w.]", "_", gpu_list[0]["name"]) if gpu_list else "gpu"
|
||||
is_baseline = os.environ.get("ORT_LONGFORMER_BASELINE", "0") == "1"
|
||||
experiment_name = f"longformer_base_{gpu_name}" + ("_baseline" if is_baseline else "")
|
||||
logger.info(
|
||||
f"experiment_name={experiment_name}, fp16_batch_sizes={fp16_batch_sizes}, fp32_batch_sizes={fp32_batch_sizes}"
|
||||
)
|
||||
|
||||
total_runs = 1
|
||||
all_results = []
|
||||
for _ in range(total_runs):
|
||||
for batch_size in fp16_batch_sizes:
|
||||
fp16_results = run_experiments(use_fp16=True, batch_size=batch_size, is_baseline=is_baseline)
|
||||
output_details(fp16_results, "longformer_base_fp16.csv")
|
||||
all_results += fp16_results
|
||||
for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
|
||||
output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
|
||||
|
||||
all_results = []
|
||||
for _ in range(total_runs):
|
||||
for batch_size in fp32_batch_sizes:
|
||||
fp32_results = run_experiments(use_fp16=False, batch_size=batch_size, is_baseline=is_baseline)
|
||||
output_details(fp32_results, "longformer_base_fp32.csv")
|
||||
all_results += fp32_results
|
||||
for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
|
||||
output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+413
@@ -0,0 +1,413 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# This script converts a Longformer model from HuggingFace transformers 4.0 or later to ONNX.
|
||||
# It translates LongformerSelfAttention to the LongformerAttention operator in ONNX Runtime.
|
||||
#
|
||||
# Before running this script, prepare a python environment in Linux with PyTorch 1.9.0 and other packages installed.
|
||||
# Then run "python setup.py install" in ./torch_extensions directory. If your python version is not 3.8, you will need
|
||||
# to update this script with the correct name of longformer_attention.cpython-*.so (search TODO below).
|
||||
#
|
||||
# It is tested in Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0.
|
||||
# Warning: exporting with PyTorch 1.10 or newer might encounter issues, but those versions are fine for benchmarking.
|
||||
#
|
||||
# Example commands to export longformer base model in Linux:
|
||||
# conda create -n longformer python=3.8
|
||||
# conda activate longformer
|
||||
# python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
# python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0
|
||||
# python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu
|
||||
# cd ./torch_extensions
|
||||
# rm -rf build
|
||||
# python setup.py install
|
||||
# cd ..
|
||||
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx
|
||||
# python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx --no_merge_qkv
|
||||
#
|
||||
# GPU is not needed for this script. You can run it on CPU. For --optimize_onnx, you can use either the onnxruntime or onnxruntime-gpu package.
|
||||
#
|
||||
# For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or newer version.
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from longformer_helper import PRETRAINED_LONGFORMER_MODELS
|
||||
from onnx import load_model
|
||||
from onnx_model_bert import BertOnnxModel
|
||||
from packaging import version
|
||||
from torch.onnx import register_custom_op_symbolic
|
||||
from torch.onnx.symbolic_helper import parse_args
|
||||
from torch_onnx_export_helper import torch_onnx_export
|
||||
from transformers import LongformerModel, LongformerSelfAttention
|
||||
|
||||
# Supports format 0 or 1
|
||||
weight_bias_format = 0
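# Format 0 stacks the q/k/v weights along a new leading dimension, giving shape
# (3, hidden_size, hidden_size); format 1 concatenates them along the hidden dimension,
# giving shape (hidden_size, 3 * hidden_size). The bias packing also differs (see below).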
|
||||
|
||||
|
||||
@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i")
|
||||
def my_longformer_attention(
|
||||
g,
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
mask,
|
||||
global_weight,
|
||||
global_bias,
|
||||
global_mask,
|
||||
num_heads,
|
||||
window,
|
||||
):
|
||||
return g.op(
|
||||
"com.microsoft::LongformerAttention",
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
mask,
|
||||
global_weight,
|
||||
global_bias,
|
||||
global_mask,
|
||||
num_heads_i=num_heads,
|
||||
window_i=window,
|
||||
)
|
||||
|
||||
|
||||
# namespace is onnxruntime which is registered in longformer_attention.cpp
|
||||
register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9)
|
||||
|
||||
# TODO: search the directory to find the correct output filename of "python setup.py install" when the python version is not 3.8
|
||||
torch.ops.load_library(
|
||||
r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so"
|
||||
)
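# Hedged sketch (not part of the original script): if your Python version is not 3.8,
# the built extension could be located with a glob instead of the hard-coded path above.
# The assumed build-directory layout follows the default setuptools naming.
#
#   import glob
#   so_files = glob.glob("./torch_extensions/build/lib.*/longformer_attention.cpython-*.so")
#   if so_files:
#       torch.ops.load_library(so_files[0])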
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
"""Parse arguments
|
||||
|
||||
Returns:
|
||||
args: Namespace
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
required=False,
|
||||
type=str,
|
||||
default="longformer-base-4096",
|
||||
help="Checkpoint directory or pre-trained model names in the list: "
|
||||
+ ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--export_padding",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Export padding logic to ONNX graph. If not enabled, user need pad input so that sequence length is multiple of window size.",
|
||||
)
|
||||
parser.set_defaults(export_padding=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_merge_qkv",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Stack the weights of q, k and v on dimension 0 instead of dimension 1.",
|
||||
)
|
||||
parser.set_defaults(no_merge_qkv=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--optimize_onnx",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use optimizer.py to optimize onnx model.",
|
||||
)
|
||||
parser.set_defaults(optimize_onnx=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
required=False,
|
||||
type=str,
|
||||
default="fp32",
|
||||
choices=["fp32", "fp16"],
|
||||
help="Precision of model to run: fp32 for full precision, fp16 for mixed precision",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
# Create a dummy input for ONNX export.
|
||||
def get_dummy_inputs(config, export_padding, device):
|
||||
# When the sequence length is a multiple of the window size, there is no padding logic in the ONNX graph
|
||||
sequence_length = config.attention_window[0] + 1 if export_padding else config.attention_window[0]
|
||||
|
||||
# Create dummy inputs
|
||||
input_ids = torch.arange(sequence_length).unsqueeze(0).to(device)
|
||||
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
|
||||
attention_mask[:, sequence_length - 1] = 0 # last token is masked
|
||||
|
||||
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
|
||||
global_attention_mask[:, 0] = 1 # first token is global token
|
||||
|
||||
return input_ids, attention_mask, global_attention_mask
|
||||
|
||||
|
||||
# A new function to replace LongformerSelfAttention.forward
|
||||
# For transformers 4.0.0
|
||||
def my_longformer_self_attention_forward_4(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
):
|
||||
global_mask = is_index_global_attn.int()
|
||||
# The following check is based on the dummy inputs (only the first token is global).
|
||||
assert (
|
||||
len(global_mask.shape) == 2
|
||||
and global_mask.shape[0] == 1
|
||||
and global_mask.count_nonzero().item() == 1
|
||||
and global_mask.tolist()[0][0] == 1
|
||||
)
|
||||
|
||||
input_mask = is_index_masked.float()
|
||||
# TODO: The filtering value may be -10000.0 or -inf. Check the huggingface implementation.
|
||||
input_mask = input_mask.masked_fill(is_index_masked, -10000.0)
|
||||
# Another way to generate input_mask: torch.masked_fill(attention_mask, is_index_global_attn, 0.0)
|
||||
|
||||
# TODO: add postprocessing of ONNX model to calculate based on graph input: input_mask = (attention_mask - 1) * 10000.0
|
||||
# TODO: add postprocessing of ONNX model to use graph input directly: global_mask = global_attention_mask
|
||||
|
||||
# The following check is based on the dummy inputs (only the last token is masked).
|
||||
assert (
|
||||
len(input_mask.shape) == 2
|
||||
and input_mask.shape[0] == 1
|
||||
and input_mask.count_nonzero().item() == 1
|
||||
and input_mask.tolist()[0][-1] == -10000.0
|
||||
)
|
||||
|
||||
weight = torch.stack(
|
||||
(
|
||||
self.query.weight.transpose(0, 1),
|
||||
self.key.weight.transpose(0, 1),
|
||||
self.value.weight.transpose(0, 1),
|
||||
),
|
||||
dim=weight_bias_format,
|
||||
)
|
||||
|
||||
if weight_bias_format == 1:
|
||||
# shape is (hidden_size, 3*hidden_size) for format 1, otherwise (3, hidden_size, hidden_size) by default
|
||||
weight = weight.reshape(self.embed_dim, 3 * self.embed_dim)
|
||||
|
||||
global_weight = torch.stack(
|
||||
(
|
||||
self.query_global.weight.transpose(0, 1),
|
||||
self.key_global.weight.transpose(0, 1),
|
||||
self.value_global.weight.transpose(0, 1),
|
||||
),
|
||||
dim=weight_bias_format,
|
||||
)
|
||||
|
||||
if weight_bias_format == 1:
|
||||
global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim)
|
||||
|
||||
if weight_bias_format == 1:
|
||||
bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0)
|
||||
bias = bias.reshape(3 * self.embed_dim)
|
||||
global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0)
|
||||
global_bias = global_bias.reshape(3 * self.embed_dim)
|
||||
else:
|
||||
bias = torch.stack(
|
||||
(self.query.bias, self.key.bias, self.value.bias, self.key_global.bias, self.value_global.bias), dim=0
|
||||
)
|
||||
bias = bias.reshape(5 * self.embed_dim)
|
||||
global_bias = self.query_global.bias
|
||||
global_bias = global_bias.reshape(1 * self.embed_dim)
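# In format 0, the packed bias passed to LongformerAttention holds the q, k and v biases
# plus the global k and v biases (5 * hidden_size), while the global q bias is provided
# separately as global_bias (hidden_size).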
|
||||
|
||||
attn_output = torch.ops.onnxruntime.LongformerAttention(
|
||||
hidden_states,
|
||||
weight,
|
||||
bias,
|
||||
input_mask,
|
||||
global_weight,
|
||||
global_bias,
|
||||
global_mask,
|
||||
self.num_heads,
|
||||
self.one_sided_attn_window_size,
|
||||
)
|
||||
|
||||
assert attn_output.size() == hidden_states.size(), "Unexpected size"
|
||||
|
||||
outputs = (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
# For transformers 4.3.0
|
||||
def my_longformer_self_attention_forward_4_3(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
assert output_attentions is False
|
||||
return my_longformer_self_attention_forward_4(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
)
|
||||
|
||||
|
||||
# For transformers 4.3.2 or later versions
|
||||
def my_longformer_self_attention_forward_4_3_2(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
layer_head_mask=None,
|
||||
is_index_masked=None,
|
||||
is_index_global_attn=None,
|
||||
is_global_attn=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
assert output_attentions is False
|
||||
assert layer_head_mask is None
|
||||
return my_longformer_self_attention_forward_4(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
)
|
||||
|
||||
|
||||
def export_longformer(model: LongformerModel, onnx_model_path: str, export_padding: bool):
|
||||
"""Export longformer model to ONNX
|
||||
|
||||
Args:
|
||||
model (LongformerModel): longformer model
|
||||
onnx_model_path (str): output onnx path
|
||||
export_padding (bool): whether to export padding logic to ONNX so that the input can be any length.
|
||||
|
||||
Raises:
|
||||
RuntimeError: This tool requires transformers 4.0.0 or later.
|
||||
RuntimeError: LongformerSelfAttention.forward arguments are different.
|
||||
"""
|
||||
input_ids, attention_mask, global_attention_mask = get_dummy_inputs(
|
||||
model.config, export_padding, device=torch.device("cpu")
|
||||
)
|
||||
|
||||
_ = model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
)
|
||||
|
||||
if version.parse(transformers.__version__) < version.parse("4.0.0"):
|
||||
raise RuntimeError("This tool requires transformers 4.0.0 or later.")
|
||||
|
||||
# Here we replace LongformerSelfAttention.forward using our implementation for exporting ONNX model
|
||||
key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args)
|
||||
args_to_func = {
|
||||
"self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2,
|
||||
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3,
|
||||
"self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4,
|
||||
}
|
||||
|
||||
if key not in args_to_func:
|
||||
print(
|
||||
"Current arguments",
|
||||
inspect.getfullargspec(LongformerSelfAttention.forward).args,
|
||||
)
|
||||
raise RuntimeError(
|
||||
"LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)."
|
||||
)
|
||||
|
||||
# Store for restoring later
|
||||
original_forward = LongformerSelfAttention.forward
|
||||
|
||||
LongformerSelfAttention.forward = args_to_func[key]
|
||||
|
||||
example_inputs = (input_ids, attention_mask, global_attention_mask)
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
torch_onnx_export(
|
||||
model,
|
||||
example_inputs,
|
||||
onnx_model_path,
|
||||
opset_version=12,
|
||||
input_names=["input_ids", "attention_mask", "global_attention_mask"],
|
||||
output_names=["last_state", "pooler"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch_size", 1: "sequence_length"},
|
||||
"attention_mask": {0: "batch_size", 1: "sequence_length"},
|
||||
"global_attention_mask": {0: "batch_size", 1: "sequence_length"},
|
||||
"last_state": {0: "batch_size", 1: "sequence_length"},
|
||||
"pooler": {0: "batch_size", 1: "sequence_length"},
|
||||
},
|
||||
custom_opsets={"com.microsoft": 1},
|
||||
)
|
||||
print(f"ONNX model exported to {onnx_model_path}")
|
||||
|
||||
# Restore original implementation:
|
||||
LongformerSelfAttention.forward = original_forward
|
||||
|
||||
|
||||
def optimize_longformer(onnx_model_path: str, fp32_model_path: str, fp16_model_path=None):
|
||||
"""Optimize longformer onnx model
|
||||
|
||||
Args:
|
||||
onnx_model_path (str): path of original ONNX model.
|
||||
fp32_model_path (str): path of optimized fp32 model.
|
||||
fp16_model_path (str, optional): path of optimized fp16 model. Defaults to None.
|
||||
"""
|
||||
model = load_model(onnx_model_path, format=None, load_external_data=True)
|
||||
optimizer = BertOnnxModel(model)
|
||||
optimizer.optimize()
|
||||
|
||||
use_external_data_format = False
|
||||
if fp32_model_path:
|
||||
optimizer.save_model_to_file(fp32_model_path, use_external_data_format)
|
||||
print(f"optimized fp32 model saved to {fp32_model_path}")
|
||||
|
||||
if fp16_model_path:
|
||||
optimizer.convert_float_to_float16(keep_io_types=True)
|
||||
optimizer.save_model_to_file(fp16_model_path, use_external_data_format)
|
||||
print(f"optimized fp16 model saved to {fp16_model_path}")
|
||||
|
||||
|
||||
def main(args):
|
||||
model_name = args.model
|
||||
onnx_model_path = model_name + ".onnx"
|
||||
|
||||
global weight_bias_format # noqa: PLW0603
|
||||
weight_bias_format = 0 if args.no_merge_qkv else 1
|
||||
|
||||
model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name])
|
||||
|
||||
export_longformer(model, onnx_model_path, args.export_padding)
|
||||
|
||||
if args.optimize_onnx or args.precision != "fp32":
|
||||
fp32_model_path = model_name + f"_f{weight_bias_format}" + "_fp32.onnx"
|
||||
fp16_model_path = model_name + f"_f{weight_bias_format}" + "_fp16.onnx" if args.precision == "fp16" else None
|
||||
optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
main(args)
|
||||
+347
@@ -0,0 +1,347 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
# Generate test data for a longformer model, so that we can use onnxruntime_perf_test.exe to evaluate the inference latency.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from bert_test_data import fake_input_ids_data, fake_input_mask_data, output_test_data
|
||||
from onnx import ModelProto, TensorProto
|
||||
from onnx_model import OnnxModel
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="output test data path. If not specified, .",
|
||||
)
|
||||
|
||||
parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
|
||||
|
||||
parser.add_argument(
|
||||
"--sequence_length",
|
||||
required=False,
|
||||
type=int,
|
||||
default=128,
|
||||
help="maximum sequence length of input",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--average_sequence_length",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="average sequence length excluding padding",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--random_sequence_length",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="use uniform random instead of fixed sequence length",
|
||||
)
|
||||
parser.set_defaults(random_sequence_length=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--global_tokens",
|
||||
required=False,
|
||||
type=int,
|
||||
default=10,
|
||||
help="number of global tokens",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input_ids_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="input name for input ids",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input_mask_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="input name for attention mask",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--global_mask_name",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="input name for global attention mask",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--samples",
|
||||
required=False,
|
||||
type=int,
|
||||
default=1,
|
||||
help="number of test cases to be generated",
|
||||
)
|
||||
|
||||
parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="print verbose information",
|
||||
)
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def get_longformer_inputs(onnx_file, input_ids_name=None, input_mask_name=None, global_mask_name=None):
|
||||
"""
|
||||
Get graph inputs for longformer model.
|
||||
"""
|
||||
model = ModelProto()
|
||||
with open(onnx_file, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
|
||||
onnx_model = OnnxModel(model)
|
||||
graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
|
||||
|
||||
if input_ids_name is not None:
|
||||
input_ids = onnx_model.find_graph_input(input_ids_name)
|
||||
if input_ids is None:
|
||||
raise ValueError(f"Graph does not have input named {input_ids_name}")
|
||||
|
||||
input_mask = None
|
||||
if input_mask_name:
|
||||
input_mask = onnx_model.find_graph_input(input_mask_name)
|
||||
if input_mask is None:
|
||||
raise ValueError(f"Graph does not have input named {input_mask_name}")
|
||||
|
||||
global_mask = None
|
||||
if global_mask_name:
|
||||
global_mask = onnx_model.find_graph_input(global_mask_name)
|
||||
if global_mask is None:
|
||||
raise ValueError(f"Graph does not have input named {global_mask_name}")
|
||||
|
||||
expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else 0)
|
||||
if len(graph_inputs) != expected_inputs:
|
||||
raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
|
||||
|
||||
return input_ids, input_mask, global_mask
|
||||
|
||||
if len(graph_inputs) != 3:
|
||||
raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")
|
||||
|
||||
# Try to guess the inputs based on naming.
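# An input whose name contains "global" is treated as the global attention mask, any other
# name containing "mask" as the attention mask, and the remaining input as the input ids.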
|
||||
input_ids = None
|
||||
input_mask = None
|
||||
global_mask = None
|
||||
for input in graph_inputs:
|
||||
input_name_lower = input.name.lower()
|
||||
if "global" in input_name_lower:
|
||||
global_mask = input
|
||||
elif "mask" in input_name_lower:
|
||||
input_mask = input
|
||||
else:
|
||||
input_ids = input
|
||||
|
||||
if input_ids and input_mask and global_mask:
|
||||
return input_ids, input_mask, global_mask
|
||||
|
||||
raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
|
||||
|
||||
|
||||
def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_tokens):
|
||||
"""
|
||||
Fake data based on the graph input of the global attention mask.
Args:
    global_mask (TensorProto): graph input for the global attention mask.
Returns:
    data (np.array): the data for the input tensor
|
||||
"""
|
||||
data_type = global_mask.type.tensor_type.elem_type
|
||||
assert data_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
|
||||
|
||||
if num_global_tokens > 0:
|
||||
assert num_global_tokens <= sequence_length
|
||||
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
|
||||
temp = np.ones((batch_size, num_global_tokens), dtype=np.int32)
|
||||
data[: temp.shape[0], : temp.shape[1]] = temp
|
||||
else:
|
||||
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
|
||||
|
||||
if data_type == TensorProto.FLOAT:
|
||||
data = np.float32(data)
|
||||
elif data_type == TensorProto.INT64:
|
||||
data = np.int64(data)
|
||||
|
||||
return data
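# Example (hypothetical values): with batch_size=1, sequence_length=8 and num_global_tokens=2,
# the returned mask is [[1, 1, 0, 0, 0, 0, 0, 0]], marking the first two tokens as global.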
|
||||
|
||||
|
||||
def fake_test_data(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
test_cases,
|
||||
dictionary_size,
|
||||
verbose,
|
||||
random_seed,
|
||||
input_ids,
|
||||
input_mask,
|
||||
global_mask,
|
||||
num_global_tokens,
|
||||
average_sequence_length,
|
||||
random_sequence_length,
|
||||
):
|
||||
"""
|
||||
Generate fake input data for test.
|
||||
"""
|
||||
assert input_ids is not None
|
||||
|
||||
np.random.seed(random_seed)
|
||||
random.seed(random_seed)
|
||||
|
||||
all_inputs = []
|
||||
for _ in range(test_cases):
|
||||
input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
|
||||
inputs = {input_ids.name: input_1}
|
||||
|
||||
if input_mask:
|
||||
inputs[input_mask.name] = fake_input_mask_data(
|
||||
input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length
|
||||
)
|
||||
|
||||
if global_mask:
|
||||
inputs[global_mask.name] = fake_global_mask_data(
|
||||
global_mask, batch_size, sequence_length, num_global_tokens
|
||||
)
|
||||
|
||||
if verbose and len(all_inputs) == 0:
|
||||
print("Example inputs", inputs)
|
||||
all_inputs.append(inputs)
|
||||
|
||||
return all_inputs
|
||||
|
||||
|
||||
def generate_test_data(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
test_cases,
|
||||
seed,
|
||||
verbose,
|
||||
input_ids,
|
||||
input_mask,
|
||||
global_mask,
|
||||
num_global_tokens,
|
||||
average_sequence_length,
|
||||
random_sequence_length,
|
||||
):
|
||||
dictionary_size = 10000
|
||||
all_inputs = fake_test_data(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
test_cases,
|
||||
dictionary_size,
|
||||
verbose,
|
||||
seed,
|
||||
input_ids,
|
||||
input_mask,
|
||||
global_mask,
|
||||
num_global_tokens,
|
||||
average_sequence_length,
|
||||
random_sequence_length,
|
||||
)
|
||||
if len(all_inputs) != test_cases:
|
||||
print("Failed to create test data for test.")
|
||||
return all_inputs
|
||||
|
||||
|
||||
def create_longformer_test_data(
|
||||
model,
|
||||
output_dir,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
test_cases,
|
||||
seed,
|
||||
verbose,
|
||||
input_ids_name,
|
||||
input_mask_name,
|
||||
global_mask_name,
|
||||
num_global_tokens,
|
||||
average_sequence_length,
|
||||
random_sequence_length,
|
||||
):
|
||||
input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name, global_mask_name)
|
||||
all_inputs = generate_test_data(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
test_cases,
|
||||
seed,
|
||||
verbose,
|
||||
input_ids,
|
||||
input_mask,
|
||||
global_mask,
|
||||
num_global_tokens,
|
||||
average_sequence_length,
|
||||
random_sequence_length,
|
||||
)
|
||||
|
||||
for i, inputs in enumerate(all_inputs):
|
||||
output_test_data(output_dir, i, inputs)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
output_dir = args.output_dir
|
||||
if output_dir is None:
|
||||
# Default output directory is a sub-directory under the directory of model.
|
||||
output_dir = os.path.join(
|
||||
Path(args.model).parent,
|
||||
f"b{args.batch_size}_s{args.sequence_length}_g{args.global_tokens}",
|
||||
)
|
||||
|
||||
# Create the output directory if it does not exist; warn when it already exists.
path = Path(output_dir)
if path.exists():
    print("Directory exists. Test data files will be overwritten.")
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.average_sequence_length <= 0:
|
||||
args.average_sequence_length = args.sequence_length
|
||||
|
||||
create_longformer_test_data(
|
||||
args.model,
|
||||
output_dir,
|
||||
args.batch_size,
|
||||
args.sequence_length,
|
||||
args.samples,
|
||||
args.seed,
|
||||
args.verbose,
|
||||
args.input_ids_name,
|
||||
args.input_mask_name,
|
||||
args.global_mask_name,
|
||||
args.global_tokens,
|
||||
args.average_sequence_length,
|
||||
)
|
||||
|
||||
print("Test data is saved to directory:", output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
# This script helps create dummy inputs for the Longformer model.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_LONGFORMER_MODELS = {
|
||||
"longformer-base-4096": "allenai/longformer-base-4096",
|
||||
"longformer-large-4096": "allenai/longformer-large-4096",
|
||||
"longformer-random-tiny": "patrickvonplaten/longformer-random-tiny", # A tiny model for debugging
|
||||
}
|
||||
|
||||
|
||||
class LongformerInputs:
|
||||
def __init__(self, input_ids, attention_mask, global_attention_mask):
|
||||
self.input_ids: torch.LongTensor = input_ids
|
||||
self.attention_mask: torch.FloatTensor | torch.HalfTensor = attention_mask
|
||||
self.global_attention_mask: torch.FloatTensor | torch.HalfTensor = global_attention_mask
|
||||
|
||||
def to_list(self) -> list:
|
||||
return [v for v in [self.input_ids, self.attention_mask, self.global_attention_mask] if v is not None]
|
||||
|
||||
def to_tuple(self) -> tuple:
|
||||
return tuple(v for v in self.to_list())
|
||||
|
||||
def get_ort_inputs(self) -> dict:
|
||||
return {
|
||||
"input_ids": numpy.ascontiguousarray(self.input_ids.cpu().numpy()),
|
||||
"attention_mask": numpy.ascontiguousarray(self.attention_mask.cpu().numpy()),
|
||||
"global_attention_mask": numpy.ascontiguousarray(self.global_attention_mask.cpu().numpy()),
|
||||
}
|
||||
|
||||
|
||||
class LongformerHelper:
|
||||
"""A helper class for Longformer model conversion, inference and verification."""
|
||||
|
||||
@staticmethod
|
||||
def get_dummy_inputs(
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
num_global_tokens: int,
|
||||
device: torch.device,
|
||||
vocab_size: int = 100,
|
||||
) -> LongformerInputs:
|
||||
"""Create random inputs for Longformer model.
|
||||
Returns torch tensors for input_ids, attention_mask and global_attention_mask.
|
||||
"""
|
||||
|
||||
input_ids = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size - 1,
|
||||
size=(batch_size, sequence_length),
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
|
||||
global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
|
||||
global_token_index = list(range(num_global_tokens))
|
||||
global_attention_mask[:, global_token_index] = 1
|
||||
return LongformerInputs(input_ids, attention_mask, global_attention_mask)
|
||||
|
||||
@staticmethod
|
||||
def get_output_shapes(batch_size: int, sequence_length: int, hidden_size: int) -> dict[str, list[int]]:
|
||||
"""Returns a dictionary with output name as key, and shape as value."""
|
||||
return {
|
||||
"last_state": [batch_size, sequence_length, hidden_size],
|
||||
"pooler": [batch_size, sequence_length],
|
||||
}
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.append(transformers_dir)
|
||||
+582
@@ -0,0 +1,582 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from benchmark_helper import Precision
|
||||
from fusion_options import AttentionOpType
|
||||
from onnx_model import OnnxModel
|
||||
from packaging import version
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from onnxruntime import __version__ as ort_version
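# The 4-bit weight-only quantizer was renamed in onnxruntime 1.22
# (MatMul4BitsQuantizer in matmul_4bits_quantizer -> MatMulNBitsQuantizer in
# matmul_nbits_quantizer), so choose the import based on the installed version.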
|
||||
|
||||
if version.parse(ort_version) < version.parse("1.22.0"):
|
||||
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer as MatMulNBitsQuantizer
|
||||
else:
|
||||
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
|
||||
|
||||
|
||||
class ConvertPhi2ToONNX:
|
||||
def __init__(
|
||||
self,
|
||||
device: torch.device,
|
||||
model_class: str = "microsoft/phi-2",
|
||||
cache_dir: str = "./cache",
|
||||
):
|
||||
self.model_class = model_class
|
||||
self.device = device
|
||||
self.cache_dir = cache_dir
|
||||
self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir)
|
||||
self.phi_model = None
|
||||
self.batch_size = 2
|
||||
self.sequence_length = 8
|
||||
self.attn_op_type = None
|
||||
self.precision = None
|
||||
self.block_size = 16
|
||||
self.accuracy_level = None
|
||||
|
||||
def set_quantization_params(self, block_size: int, accuracy_level: int | None):
|
||||
self.block_size = block_size
|
||||
self.accuracy_level = accuracy_level
|
||||
|
||||
def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
|
||||
self.attn_op_type = attn_op_type
|
||||
self.precision = precision
|
||||
|
||||
def erase_onnx_model(self, onnx_path: str) -> None:
|
||||
assert onnx_path.endswith(".onnx")
|
||||
if not os.path.exists(onnx_path):
|
||||
return
|
||||
|
||||
model = onnx.load_model(onnx_path, load_external_data=False)
|
||||
onnx_data_path = None
|
||||
for initializer in model.graph.initializer:
|
||||
if initializer.data_location == 1 and initializer.external_data[0].key == "location":
|
||||
onnx_data_path = "./" + initializer.external_data[0].value
|
||||
break
|
||||
logging.info(f"Erasing {onnx_path}...")
|
||||
os.remove(onnx_path)
|
||||
if onnx_data_path is not None:
|
||||
onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path)
|
||||
logging.info(f"Erasing {onnx_data_path}...")
|
||||
os.remove(onnx_data_path)
|
||||
|
||||
def get_phi2_torch_model(self):
|
||||
logging.info("Loading phi2 torch model...")
|
||||
if self.phi_model is not None:
|
||||
return
|
||||
self.phi_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_class, trust_remote_code=True, cache_dir=self.cache_dir
|
||||
)
|
||||
self.phi_model.eval()
|
||||
self.phi_model.to(self.device)
|
||||
|
||||
def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):
|
||||
input_ids = torch.randint(
|
||||
low=0,
|
||||
high=self.phi_config.vocab_size,
|
||||
size=(batch_size, sequence_length),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
self.get_phi2_torch_model()
|
||||
torch_inputs = self.phi_model.prepare_inputs_for_generation(
|
||||
input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"]
|
||||
)
|
||||
return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"]
|
||||
|
||||
def dynamo_export(self, onnx_path: str):
|
||||
input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length)
|
||||
self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
|
||||
|
||||
from torch._dynamo import config # noqa: PLC0415
|
||||
|
||||
config.capture_scalar_outputs = True
|
||||
|
||||
logging.info("Exporting Phi2 torch model to ONNX...")
|
||||
torch.onnx.dynamo_export(
|
||||
self.phi_model,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
|
||||
).save(onnx_path)
|
||||
onnx.checker.check_model(onnx_path)
|
||||
onnx.shape_inference.infer_shapes_path(onnx_path)
|
||||
|
||||
def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
|
||||
from fusion_options import FusionOptions # noqa: PLC0415
|
||||
from optimizer import optimize_model # noqa: PLC0415
|
||||
|
||||
optimization_options = FusionOptions("phi")
|
||||
optimization_options.set_attention_op_type(self.attn_op_type)
|
||||
optimizer = optimize_model(
|
||||
onnx_path,
|
||||
model_type="phi",
|
||||
num_heads=self.phi_config.num_attention_heads,
|
||||
hidden_size=self.phi_config.hidden_size,
|
||||
opt_level=0,
|
||||
optimization_options=optimization_options,
|
||||
only_onnxruntime=False,
|
||||
)
|
||||
|
||||
fused_op_count = optimizer.get_fused_operator_statistics()
|
||||
if optimizer.is_fully_optimized(fused_op_count):
|
||||
logging.info("Model is fully optimized.")
|
||||
else:
|
||||
logging.info("Model is not fully optimized.")
|
||||
|
||||
if self.precision == Precision.FLOAT32:
|
||||
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
|
||||
return
|
||||
|
||||
if (
|
||||
self.precision == Precision.FLOAT16 or self.precision == Precision.INT4
|
||||
) and self.attn_op_type != AttentionOpType.MultiHeadAttention:
|
||||
# We keep the last three Attention layers in float32 or bfloat16 to avoid overflow.
|
||||
node_block_list = (
|
||||
[
|
||||
"Attention_29",
|
||||
"Attention_30",
|
||||
"Attention_31",
|
||||
]
|
||||
if self.attn_op_type != AttentionOpType.PagedAttention
|
||||
else []
|
||||
) # TODO: temp setting for paged attention
|
||||
logging.info("Converting onnx model to float16/bfloat16...")
|
||||
optimizer.convert_float_to_float16(
|
||||
keep_io_types=False,
|
||||
node_block_list=node_block_list,
|
||||
use_symbolic_shape_infer=True,
|
||||
use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention,
|
||||
)
|
||||
logging.info("Converting onnx model to float16/bfloat16 done.")
|
||||
|
||||
if self.precision == Precision.FLOAT16:
|
||||
optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True)
|
||||
return
|
||||
else:
|
||||
assert self.precision == Precision.INT4
|
||||
quant = MatMulNBitsQuantizer(
|
||||
model=optimizer.model,
|
||||
block_size=self.block_size,
|
||||
is_symmetric=True,
|
||||
accuracy_level=self.accuracy_level,
|
||||
)
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True)
|
||||
|
||||
# This function currently only works for phi2 model
|
||||
def convert_to_use_cuda_graph(self, in_onnx_path: str, out_onnx_path: str):
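# Replace the dynamic attention_mask graph input with fixed-shape seqlens_k and
# total_sequence_length inputs, remove the subgraphs that previously derived those values,
# and feed the new inputs directly to every GroupQueryAttention node so the graph can be
# captured with CUDA Graph.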
|
||||
onnx_model = OnnxModel(onnx.load(in_onnx_path, load_external_data=True))
|
||||
|
||||
from onnx import TensorProto, helper # noqa: PLC0415
|
||||
|
||||
graph = onnx_model.graph()
|
||||
new_inputs = []
|
||||
for vi in graph.input:
|
||||
if "attention_mask" in vi.name:
|
||||
vi_seqlen_k = helper.make_tensor_value_info(
|
||||
"seqlens_k",
|
||||
elem_type=TensorProto.INT32,
|
||||
shape=["batch_size"],
|
||||
)
|
||||
vi_total_seq_len = helper.make_tensor_value_info(
|
||||
"total_sequence_length",
|
||||
elem_type=TensorProto.INT32,
|
||||
shape=[1],
|
||||
)
|
||||
new_inputs.extend([vi_seqlen_k, vi_total_seq_len])
|
||||
else:
|
||||
new_inputs.append(vi)
|
||||
|
||||
graph.ClearField("input")
|
||||
graph.input.extend(new_inputs)
|
||||
|
||||
gqas = onnx_model.get_nodes_by_op_type("GroupQueryAttention")
|
||||
gqa = gqas[0]
|
||||
seqlens_path = onnx_model.match_parent_path(
|
||||
gqa,
|
||||
["Cast", "Sub", "ReduceSum", "Cast"],
|
||||
[5, 0, 0, 0],
|
||||
)
|
||||
if seqlens_path is None:
|
||||
raise RuntimeError("Failed to find seqlens path for GroupQueryAttention node.")
|
||||
total_seq_len_path = onnx_model.match_parent_path(
|
||||
gqa,
|
||||
["Cast", "Gather", "Shape"],
|
||||
[6, 0, 0],
|
||||
)
|
||||
if total_seq_len_path is None:
|
||||
raise RuntimeError("Failed to find total_seq_len path for GroupQueryAttention node.")
|
||||
onnx_model.remove_nodes(seqlens_path)
|
||||
onnx_model.remove_nodes(total_seq_len_path)
|
||||
|
||||
for gqa in gqas:
|
||||
gqa.input[5] = "seqlens_k"
|
||||
gqa.input[6] = "total_sequence_length"
|
||||
|
||||
onnx_model.save(onnx_model.model, out_onnx_path, save_as_external_data=True)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--fp32_cpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp32 ONNX model for CPU",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_cpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for CPU",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp32_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp32 ONNX model for Nvidia GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp16 ONNX model for Nvidia GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for Nvidia GPUs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_gpu_sm8x",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_gpu_sm8x",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16_vllm",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate fp16 ONNX model for ORT VLLM",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_vllm",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Generate int4 ONNX model for ORT VLLM",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_cuda_graph",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use CUDA Graph in decoding process",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Overwrite existing ONNX models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./cache",
|
||||
help="The cache directory for the pytorch model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device_id",
|
||||
required=False,
|
||||
type=int,
|
||||
default=0,
|
||||
help="The device id for the pytorch model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--run_example",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Run ORT inference example",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--run_benchmark",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Run ORT benchmark",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_export",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Skip exporting ONNX model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
help="The output directory for the ONNX models",
|
||||
default="phi2_onnx_models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--block_size",
|
||||
required=False,
|
||||
default=16,
|
||||
type=int,
|
||||
help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py for details.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--int4_accuracy_level",
|
||||
required=False,
|
||||
type=int,
|
||||
help="Accuracy level of the 4-bit quantized MatMul computation. "
|
||||
"Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
|
||||
"(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir)
|
||||
converter.set_quantization_params(args.block_size, args.int4_accuracy_level)
|
||||
|
||||
output_dir = args.output_dir
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
original_onnx_path = os.path.join(output_dir, "phi2_original.onnx")
|
||||
|
||||
if not args.skip_export:
|
||||
if not os.path.exists(original_onnx_path) or args.overwrite:
|
||||
converter.dynamo_export(original_onnx_path)
|
||||
|
||||
model_type_to_args = {
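# Each entry maps a model variant to (attention op type, precision, output ONNX path).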
|
||||
"fp32_cpu": (
|
||||
AttentionOpType.MultiHeadAttention,
|
||||
Precision.FLOAT32,
|
||||
os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"),
|
||||
),
|
||||
"int4_cpu": (
|
||||
AttentionOpType.MultiHeadAttention,
|
||||
Precision.INT4,
|
||||
os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"),
|
||||
),
|
||||
"fp32_gpu": (
|
||||
AttentionOpType.Attention,
|
||||
Precision.FLOAT32,
|
||||
os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"),
|
||||
),
|
||||
"fp16_gpu": (
|
||||
AttentionOpType.Attention,
|
||||
Precision.FLOAT16,
|
||||
os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"),
|
||||
),
|
||||
"int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")),
|
||||
"fp16_gpu_sm8x": (
|
||||
AttentionOpType.GroupQueryAttention,
|
||||
Precision.FLOAT16,
|
||||
os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"),
|
||||
),
|
||||
"int4_gpu_sm8x": (
|
||||
AttentionOpType.GroupQueryAttention,
|
||||
Precision.INT4,
|
||||
os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"),
|
||||
),
|
||||
"fp16_vllm": (
|
||||
AttentionOpType.PagedAttention,
|
||||
Precision.FLOAT16,
|
||||
os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"),
|
||||
),
|
||||
"int4_vllm": (
|
||||
AttentionOpType.PagedAttention,
|
||||
Precision.INT4,
|
||||
os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"),
|
||||
),
|
||||
}
|
||||
|
||||
if not args.skip_export:
|
||||
from multiprocessing import Process # noqa: PLC0415
|
||||
|
||||
def run_optimize_phi2_onnx(
|
||||
converter: ConvertPhi2ToONNX,
|
||||
original_onnx_path: str,
|
||||
attention_type: AttentionOpType,
|
||||
precision: Precision,
|
||||
optimized_onnx_path: str,
|
||||
):
|
||||
converter.init_attn_type_and_precision(attention_type, precision)
|
||||
converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path)
|
||||
if args.use_cuda_graph:
|
||||
assert args.fp16_gpu_sm8x or args.int4_gpu_sm8x
|
||||
converter.convert_to_use_cuda_graph(optimized_onnx_path, optimized_onnx_path)
|
||||
|
||||
processes = []
|
||||
if args.fp32_cpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_cpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.fp32_gpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.fp16_gpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_gpu:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"])
|
||||
)
|
||||
)
|
||||
|
||||
if args.fp16_gpu_sm8x:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]),
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_gpu_sm8x:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]),
|
||||
)
|
||||
)
|
||||
|
||||
if args.fp16_vllm:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]),
|
||||
)
|
||||
)
|
||||
|
||||
if args.int4_vllm:
|
||||
processes.append(
|
||||
Process(
|
||||
target=run_optimize_phi2_onnx,
|
||||
args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]),
|
||||
)
|
||||
)
|
||||
|
||||
[p.start() for p in processes]
|
||||
[p.join() for p in processes]
|
||||
|
||||
if args.run_example or args.run_benchmark:
|
||||
from inference_example import run_phi2 # noqa: PLC0415
|
||||
|
||||
if args.fp16_gpu_sm8x:
|
||||
logging.info("Running fp16_gpu_sm8x example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2],
|
||||
use_buffer_share=True,
|
||||
device_id=args.device_id,
|
||||
use_step=True,
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
run_benchmark=args.run_benchmark,
|
||||
)
|
||||
if args.int4_gpu_sm8x:
|
||||
logging.info("Running int4_gpu_sm8x example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2],
|
||||
use_buffer_share=True,
|
||||
device_id=args.device_id,
|
||||
use_step=True,
|
||||
use_cuda_graph=args.use_cuda_graph,
|
||||
run_benchmark=args.run_benchmark,
|
||||
)
|
||||
if args.fp32_gpu:
|
||||
logging.info("Running fp32_gpu example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["fp32_gpu"][2],
|
||||
use_buffer_share=False,
|
||||
device_id=args.device_id,
|
||||
packed_kv=True,
|
||||
use_fp16=False,
|
||||
run_benchmark=args.run_benchmark,
|
||||
)
|
||||
if args.fp16_gpu:
|
||||
logging.info("Running fp16_gpu example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["fp16_gpu"][2],
|
||||
use_buffer_share=False,
|
||||
device_id=args.device_id,
|
||||
packed_kv=True,
|
||||
run_benchmark=args.run_benchmark,
|
||||
)
|
||||
if args.int4_gpu:
|
||||
logging.info("Running int4_gpu example...")
|
||||
run_phi2(
|
||||
onnx_model_path=model_type_to_args["int4_gpu"][2],
|
||||
use_buffer_share=False,
|
||||
device_id=args.device_id,
|
||||
packed_kv=True,
|
||||
run_benchmark=args.run_benchmark,
|
||||
)
|
||||
if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm:
|
||||
raise NotImplementedError("CPU/vllm inference example is not implemented yet.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+414
@@ -0,0 +1,414 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
pt_to_np = {
|
||||
"torch.int32": np.int32,
|
||||
"torch.int64": np.int64,
|
||||
"torch.float32": np.float32,
|
||||
"torch.float16": np.float16,
|
||||
}
|
||||
|
||||
|
||||
def cuda_memcpy(dst, src):
|
||||
from cuda import cudart # noqa: PLC0415
|
||||
|
||||
cudart.cudaMemcpy(
|
||||
dst.data_ptr(),
|
||||
src.data_ptr(),
|
||||
src.element_size() * src.nelement(),
|
||||
cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice,
|
||||
)
|
||||
|
||||
|
||||
class ORTGenerator:
|
||||
def __init__(self, decoder_path):
|
||||
self.onnx_decoder_path = decoder_path
|
||||
self.num_heads = 32
|
||||
self.head_size = 80
|
||||
self.num_layers = 32
|
||||
self.max_sequence_length = 2048
|
||||
self.device_id = 0
|
||||
self.use_cuda_graph = False
|
||||
self.use_traced_inputs = False
|
||||
self.static_inputs_map = {}
|
||||
|
||||
def append_static_inputs(self, batch_size):
|
||||
# Only use this function with GQA and with use_cuda_graph=True
|
||||
if batch_size in self.static_inputs_map:
|
||||
return
|
||||
|
||||
cpu_device = torch.device("cpu")
|
||||
cuda_device = torch.device("cuda", self.device_id)
|
||||
|
||||
static_io = {}
|
||||
static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device)
|
||||
static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device)
|
||||
static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device)
|
||||
static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device)
|
||||
|
||||
cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size)
|
||||
for i in range(self.num_layers):
|
||||
cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16)
|
||||
static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()})
|
||||
|
||||
static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device)
|
||||
|
||||
self.static_inputs_map[batch_size] = static_io
|
||||
|
||||
def get_initial_inputs_and_outputs(self, encodings_dict):
|
||||
self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32
|
||||
|
||||
input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32)
|
||||
attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32)
|
||||
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
|
||||
self.use_traced_inputs = (
|
||||
self.use_cuda_graph
|
||||
and (batch_size in self.static_inputs_map)
|
||||
and self.use_buffer_share
|
||||
and not self.packed_kv
|
||||
)
|
||||
|
||||
step = (
|
||||
torch.tensor([0], device=self.device, dtype=torch.int64)
|
||||
if not self.use_traced_inputs
|
||||
else self.static_inputs_map[batch_size]["step"]
|
||||
)
|
||||
|
||||
seqlens_k = (
|
||||
torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32)
|
||||
if not self.use_traced_inputs
|
||||
else self.static_inputs_map[batch_size]["seqlens_k"]
|
||||
)
|
||||
cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32))
|
||||
|
||||
total_seq_length = (
|
||||
torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32)
|
||||
if not self.use_traced_inputs
|
||||
else self.static_inputs_map[batch_size]["total_sequence_length"]
|
||||
)
|
||||
total_seq_length[0] = sequence_length
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids.contiguous(),
|
||||
"attention_mask": attention_mask.contiguous(),
|
||||
}
|
||||
|
||||
if self.use_step:
|
||||
inputs["step"] = step.contiguous()
|
||||
|
||||
if self.use_cuda_graph:
|
||||
inputs["seqlens_k"] = seqlens_k.contiguous()
|
||||
inputs["total_sequence_length"] = total_seq_length.contiguous()
|
||||
del inputs["attention_mask"]
|
||||
|
||||
past_seq_length = self.max_sequence_length if self.use_buffer_share else 0
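# With buffer sharing, the past key/value buffers are pre-allocated to max_sequence_length
# and updated in place across decoding steps; otherwise the past starts empty and new
# present tensors are produced at each step.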
|
||||
past_shape = (
|
||||
(2, batch_size, self.num_heads, past_seq_length, self.head_size)
|
||||
if self.packed_kv
|
||||
else (batch_size, self.num_heads, past_seq_length, self.head_size)
|
||||
)
|
||||
|
||||
if not self.use_traced_inputs:
|
||||
for i in range(self.num_layers):
|
||||
past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype)
|
||||
(
|
||||
inputs.update({f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()})
|
||||
if not self.packed_kv
|
||||
else inputs.update({f"past_{i}": past.contiguous()})
|
||||
)
|
||||
else:
|
||||
for i in range(self.num_layers):
|
||||
inputs.update(
|
||||
{
|
||||
f"past_key_{i}": self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(),
|
||||
f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(),
|
||||
}
|
||||
)
|
||||
|
||||
logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype)
|
||||
outputs = {"logits": logits.contiguous()}
|
||||
|
||||
if not self.use_buffer_share:
|
||||
present_shape = (
|
||||
(2, batch_size, self.num_heads, sequence_length, self.head_size)
|
||||
if self.packed_kv
|
||||
else (batch_size, self.num_heads, sequence_length, self.head_size)
|
||||
)
|
||||
for i in range(self.num_layers):
|
||||
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
|
||||
(
|
||||
outputs.update(
|
||||
{f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()}
|
||||
)
|
||||
if not self.packed_kv
|
||||
else outputs.update({f"present_{i}": present.contiguous()})
|
||||
)
|
||||
|
||||
return inputs, outputs
|
||||
|
||||
def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict):
|
||||
io_binding = model.io_binding()
|
||||
device = None
|
||||
|
||||
for k, v in inputs.items():
|
||||
io_binding.bind_input(
|
||||
name=k,
|
||||
device_type=v.device.type,
|
||||
device_id=0 if v.device.type == "cpu" else v.device.index,
|
||||
element_type=pt_to_np[repr(v.dtype)],
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
device = v.device
|
||||
|
||||
for output in model.get_outputs():
|
||||
name = output.name
|
||||
if self.use_buffer_share and "present" in name:
|
||||
v = inputs[name.replace("present", "past")]
|
||||
io_binding.bind_output(
|
||||
name=name,
|
||||
device_type=v.device.type,
|
||||
device_id=v.device.index,
|
||||
element_type=(np.float16 if self.use_fp16 else np.float32),
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
else:
|
||||
v = outputs[name]
|
||||
io_binding.bind_output(
|
||||
name=name,
|
||||
device_type=device.type,
|
||||
device_id=0 if device.type == "cpu" else device.index,
|
||||
element_type=(np.float16 if self.use_fp16 else np.float32),
|
||||
shape=tuple(v.shape),
|
||||
buffer_ptr=v.data_ptr(),
|
||||
)
|
||||
|
||||
return io_binding
|
||||
|
||||
def create_session(
|
||||
self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False
|
||||
):
|
||||
self.device_id = device_id
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.log_verbosity_level = 4
|
||||
sess_options.log_severity_level = 4
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
ep = (
|
||||
("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph})
|
||||
if self.device_id >= 0
|
||||
else "CPUExecutionProvider"
|
||||
)
|
||||
self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep])
|
||||
self.ro = ort.RunOptions()
|
||||
|
||||
self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu")
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_buffer_share = use_buffer_share
|
||||
self.packed_kv = packed_kv
|
||||
self.use_step = use_step
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
|
||||
self.tokenizer.pad_token = "[PAD]"
|
||||
|
||||
def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False):
|
||||
inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict)
|
||||
|
||||
all_token_ids = inputs["input_ids"].clone()
|
||||
batch_size, sequence_length = all_token_ids.shape
|
||||
|
||||
current_length = sequence_length
|
||||
has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool)
|
||||
|
||||
if benchmark:
|
||||
latency = []
|
||||
|
||||
prompt_run = True
|
||||
while current_length < max_length:
|
||||
io_binding = self.apply_io_binding(self.sess, inputs, outputs)
|
||||
|
||||
if benchmark:
|
||||
start = time.time()
|
||||
|
||||
io_binding.synchronize_inputs()
|
||||
if prompt_run:
|
||||
if self.use_cuda_graph:
|
||||
# Disable CUDA graph for the prompt run
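# (the prompt step has a data-dependent sequence length, so it runs without graph capture;
# only the fixed-shape decoding steps use the captured graph)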
|
||||
self.ro.add_run_config_entry("gpu_graph_id", "-1")
|
||||
self.sess.run_with_iobinding(io_binding, self.ro)
|
||||
if self.use_cuda_graph:
|
||||
# Enable CUDA graph for the decoding run
|
||||
self.ro.add_run_config_entry(
|
||||
"gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1"
|
||||
)
|
||||
prompt_run = False
|
||||
else:
|
||||
self.sess.run_with_iobinding(io_binding, self.ro)
|
||||
io_binding.synchronize_outputs()
|
||||
|
||||
if benchmark:
|
||||
end = time.time()
|
||||
latency.append(end - start)
|
||||
|
||||
# Sample with argmax (greedy search)
|
||||
next_token_logits = outputs["logits"][:, -1, :]
|
||||
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
||||
|
||||
# Check if we previously reached EOS token id or if generated token id is EOS token id
|
||||
has_eos = has_eos | (next_tokens == self.tokenizer.eos_token_id)
|
||||
|
||||
# Determine which new tokens to add to list of all token ids
|
||||
# Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't)
|
||||
tokens_to_add = next_tokens.masked_fill(has_eos, self.tokenizer.eos_token_id).reshape([batch_size, 1])
|
||||
all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
|
||||
|
||||
# Return early if all batch entries have reached EOS token id
|
||||
if torch.all(has_eos):
|
||||
break
|
||||
|
||||
# Update inputs for next inference run
|
||||
current_length += 1
|
||||
|
||||
inputs["input_ids"] = tokens_to_add.to(torch.int32)
|
||||
if self.use_traced_inputs:
|
||||
cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"])
|
||||
inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"]
|
||||
|
||||
if self.use_step:
|
||||
inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64)
|
||||
if self.use_traced_inputs:
|
||||
cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"])
|
||||
inputs["step"] = self.static_inputs_map[batch_size]["step"]
|
||||
|
||||
if self.use_cuda_graph:
|
||||
previous_seqlens_k = inputs["seqlens_k"]
|
||||
inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32)
|
||||
inputs["total_sequence_length"][0] = current_length
|
||||
if self.use_traced_inputs:
|
||||
cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"])
|
||||
inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"]
|
||||
self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0]
|
||||
inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"]
|
||||
else:
|
||||
inputs["attention_mask"] = torch.cat(
|
||||
[inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1
|
||||
).to(torch.int32)
|
||||
|
||||
# Set logits to zeros for next inference run and re-use memory buffer
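# (the prompt run fills a [batch, prompt_length, vocab] logits buffer; decoding steps only
# need [batch, 1, vocab], so the buffer is shrunk once and then reused)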
|
||||
if outputs["logits"].shape[1] != 1:
|
||||
outputs["logits"] = outputs["logits"][:, :1, :].contiguous()
|
||||
if self.use_traced_inputs:
|
||||
outputs["logits"] = self.static_inputs_map[batch_size]["logits"]
|
||||
outputs["logits"].zero_()
|
||||
|
||||
if not self.use_buffer_share:
|
||||
for i in range(self.num_layers):
|
||||
if not self.packed_kv:
|
||||
inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"]
|
||||
inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"]
|
||||
else:
|
||||
inputs[f"past_{i}"] = outputs[f"present_{i}"]
|
||||
|
||||
new_sequence_length = inputs["attention_mask"].shape[1]
|
||||
present_shape = (
|
||||
(2, batch_size, self.num_heads, new_sequence_length, self.head_size)
|
||||
if self.packed_kv
|
||||
else (batch_size, self.num_heads, new_sequence_length, self.head_size)
|
||||
)
|
||||
for i in range(self.num_layers):
|
||||
present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype)
|
||||
(
|
||||
outputs.update(
|
||||
{
|
||||
f"present_key_{i}": present.contiguous(),
|
||||
f"present_value_{i}": present.clone().contiguous(),
|
||||
}
|
||||
)
|
||||
if not self.packed_kv
|
||||
else outputs.update({f"present_{i}": present.contiguous()})
|
||||
)
|
||||
|
||||
if benchmark:
|
||||
print(
|
||||
f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}"
|
||||
)
|
||||
print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms")
|
||||
return
|
||||
|
||||
texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True)
|
||||
return texts
|
||||
|
||||
def generate(self, prompt, max_length, cuda_graph_annotation):
|
||||
encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True)
|
||||
|
||||
return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation)
|
||||
|
||||
def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation):
|
||||
batch_size, sequence_length = prompt_shape
|
||||
max_length = sequence_length + token_num
|
||||
|
||||
encodings_dict = {}
|
||||
encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist()
|
||||
encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist()
|
||||
|
||||
# Warm up run
|
||||
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False)
|
||||
|
||||
# Benchmark run
|
||||
self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True)
|
||||
|
||||
|
||||
def run_phi2(
|
||||
onnx_model_path,
|
||||
use_buffer_share,
|
||||
device_id,
|
||||
packed_kv=False,
|
||||
use_fp16=True,
|
||||
use_step=False,
|
||||
use_cuda_graph=False,
|
||||
run_benchmark=False,
|
||||
):
|
||||
generator = ORTGenerator(onnx_model_path)
|
||||
generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph)
|
||||
|
||||
def simple_run(prompt):
|
||||
example_batch_size = len(prompt)
|
||||
if use_cuda_graph:
|
||||
generator.append_static_inputs(batch_size=example_batch_size)
|
||||
texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size)
|
||||
|
||||
for i in range(len(texts)):
|
||||
print("Prompt: ", prompt[i])
|
||||
print("Texts: ", texts[i])
|
||||
|
||||
prompt = [
|
||||
'''```python
|
||||
def print_prime(n):
|
||||
"""
|
||||
Print all primes between 1 and n
|
||||
"""'''
|
||||
]
|
||||
|
||||
if not run_benchmark:
|
||||
simple_run(prompt)
|
||||
|
||||
# Run simple benchmark. Time the decoder only.
|
||||
if run_benchmark:
|
||||
token_num = 32
|
||||
for batch_size in [1, 2, 4, 8]:
|
||||
generator.append_static_inputs(batch_size)
|
||||
for sequence_length in [16, 512]:
|
||||
prompt_shape = (batch_size, sequence_length)
|
||||
generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size)
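# A hypothetical call for reference only (the ONNX path and flag combination below are assumed,
# not taken from this change):
#   run_phi2(
#       onnx_model_path="phi2_decoder_fp16_gpu.onnx",
#       use_buffer_share=True,
#       device_id=0,
#       use_fp16=True,
#       use_step=True,
#       use_cuda_graph=True,
#       run_benchmark=True,
#   )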
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.append(transformers_dir)
|
||||
+638
@@ -0,0 +1,638 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Benchmark performance of the SAM2 image encoder or decoder with ORT or PyTorch. See benchmark_sam2.sh for usage.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import statistics
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from image_decoder import SAM2ImageDecoder
|
||||
from image_encoder import SAM2ImageEncoder
|
||||
from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model
|
||||
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession
|
||||
|
||||
|
||||
class TestConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
onnx_path: str,
|
||||
sam2_dir: str,
|
||||
device: torch.device,
|
||||
component: str = "image_encoder",
|
||||
provider="CPUExecutionProvider",
|
||||
torch_compile_mode="max-autotune",
|
||||
batch_size: int = 1,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
num_labels: int = 1,
|
||||
num_points: int = 1,
|
||||
num_masks: int = 1,
|
||||
multi_mask_output: bool = False,
|
||||
use_tf32: bool = True,
|
||||
enable_cuda_graph: bool = False,
|
||||
dtype=torch.float32,
|
||||
prefer_nhwc: bool = False,
|
||||
warm_up: int = 5,
|
||||
enable_nvtx_profile: bool = False,
|
||||
enable_ort_profile: bool = False,
|
||||
enable_torch_profile: bool = False,
|
||||
repeats: int = 1000,
|
||||
verbose: bool = False,
|
||||
):
|
||||
assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
|
||||
assert height >= 160 and height <= 4096
|
||||
assert width >= 160 and width <= 4096
|
||||
|
||||
self.model_type = model_type
|
||||
self.onnx_path = onnx_path
|
||||
self.sam2_dir = sam2_dir
|
||||
self.component = component
|
||||
self.provider = provider
|
||||
self.torch_compile_mode = torch_compile_mode
|
||||
self.batch_size = batch_size
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.num_labels = num_labels
|
||||
self.num_points = num_points
|
||||
self.num_masks = num_masks
|
||||
self.multi_mask_output = multi_mask_output
|
||||
self.device = device
|
||||
self.use_tf32 = use_tf32
|
||||
self.enable_cuda_graph = enable_cuda_graph
|
||||
self.dtype = dtype
|
||||
self.prefer_nhwc = prefer_nhwc
|
||||
self.warm_up = warm_up
|
||||
self.enable_nvtx_profile = enable_nvtx_profile
|
||||
self.enable_ort_profile = enable_ort_profile
|
||||
self.enable_torch_profile = enable_torch_profile
|
||||
self.repeats = repeats
|
||||
self.verbose = verbose
|
||||
|
||||
if self.component == "image_encoder":
|
||||
assert self.height == 1024 and self.width == 1024, "Only image size 1024x1024 is allowed for image encoder."
|
||||
|
||||
def __repr__(self):
|
||||
return f"{vars(self)}"
|
||||
|
||||
def shape_dict(self) -> Mapping[str, list[int]]:
|
||||
if self.component == "image_encoder":
|
||||
return encoder_shape_dict(self.batch_size, self.height, self.width)
|
||||
else:
|
||||
return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)
|
||||
|
||||
def random_inputs(self) -> Mapping[str, torch.Tensor]:
|
||||
dtype = self.dtype
|
||||
if self.component == "image_encoder":
|
||||
return {"image": torch.randn(self.batch_size, 3, self.height, self.width, dtype=dtype, device=self.device)}
|
||||
else:
|
||||
return {
|
||||
"image_features_0": torch.rand(1, 32, 256, 256, dtype=dtype, device=self.device),
|
||||
"image_features_1": torch.rand(1, 64, 128, 128, dtype=dtype, device=self.device),
|
||||
"image_embeddings": torch.rand(1, 256, 64, 64, dtype=dtype, device=self.device),
|
||||
"point_coords": torch.randint(
|
||||
0, 1024, (self.num_labels, self.num_points, 2), dtype=dtype, device=self.device
|
||||
),
|
||||
"point_labels": torch.randint(
|
||||
0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device
|
||||
),
|
||||
"input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=dtype, device=self.device),
|
||||
"has_input_masks": torch.ones(self.num_labels, dtype=dtype, device=self.device),
|
||||
"original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device),
|
||||
}
|
||||
|
||||
|
||||
def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
|
||||
if config.verbose:
|
||||
print(f"create session for {vars(config)}")
|
||||
|
||||
if config.provider == "CUDAExecutionProvider":
|
||||
device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
|
||||
provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
|
||||
provider_options["use_tf32"] = int(config.use_tf32)
|
||||
if config.prefer_nhwc:
|
||||
provider_options["prefer_nhwc"] = 1
|
||||
providers = [(config.provider, provider_options), "CPUExecutionProvider"]
|
||||
else:
|
||||
providers = ["CPUExecutionProvider"]
|
||||
|
||||
ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
|
||||
return ort_session
|
||||
|
||||
|
||||
def create_session(config: TestConfig, session_options=None) -> CudaSession:
|
||||
ort_session = create_ort_session(config, session_options)
|
||||
cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
|
||||
cuda_session.allocate_buffers(config.shape_dict())
|
||||
return cuda_session
|
||||
|
||||
|
||||
class OrtTestSession:
|
||||
"""A wrapper of ORT session to test relevance and performance."""
|
||||
|
||||
def __init__(self, config: TestConfig, session_options=None):
|
||||
self.ort_session = create_session(config, session_options)
|
||||
self.feed_dict = config.random_inputs()
|
||||
|
||||
def infer(self):
|
||||
return self.ort_session.infer(self.feed_dict)
|
||||
|
||||
|
||||
def measure_latency(cuda_session: CudaSession, input_dict):
|
||||
start = time.time()
|
||||
_ = cuda_session.infer(input_dict)
|
||||
end = time.time()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_torch(config: TestConfig):
|
||||
device_type = config.device.type
|
||||
is_cuda = device_type == "cuda"
|
||||
|
||||
# Turn on TF32 for Ampere GPUs which could help when data type is float32.
|
||||
if is_cuda and torch.cuda.get_device_properties(0).major >= 8 and config.use_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
enabled_auto_cast = is_cuda and config.dtype != torch.float32
|
||||
ort_inputs = config.random_inputs()
|
||||
|
||||
with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
|
||||
sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
|
||||
if config.component == "image_encoder":
|
||||
if is_cuda and config.torch_compile_mode != "none":
|
||||
sam2_model.image_encoder.forward = torch.compile(
|
||||
sam2_model.image_encoder.forward,
|
||||
mode=config.torch_compile_mode, # Use "reduce-overhead" to reduce the latency of the first run.
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
image_shape = config.shape_dict()["image"]
|
||||
img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
|
||||
sam2_encoder = SAM2ImageEncoder(sam2_model)
|
||||
|
||||
if is_cuda and config.torch_compile_mode != "none":
|
||||
print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")
|
||||
|
||||
for _ in range(config.warm_up):
|
||||
_image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
|
||||
|
||||
if is_cuda and config.enable_nvtx_profile:
|
||||
import nvtx # noqa: PLC0415
|
||||
from cuda import cudart # noqa: PLC0415
|
||||
|
||||
cudart.cudaProfilerStart()
|
||||
print("Start nvtx profiling on encoder ...")
|
||||
with nvtx.annotate("one_run"):
|
||||
sam2_encoder(img, enable_nvtx_profile=True)
|
||||
cudart.cudaProfilerStop()
|
||||
|
||||
if is_cuda and config.enable_torch_profile:
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
) as prof:
|
||||
print("Start torch profiling on encoder ...")
|
||||
with torch.profiler.record_function("encoder"):
|
||||
sam2_encoder(img)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
prof.export_chrome_trace("torch_image_encoder.json")
|
||||
|
||||
if config.repeats == 0:
|
||||
return
|
||||
|
||||
print(f"Start {config.repeats} runs of performance tests...")
|
||||
start = time.time()
|
||||
for _ in range(config.repeats):
|
||||
_image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
|
||||
if is_cuda:
|
||||
torch.cuda.synchronize()
|
||||
else:
|
||||
torch_inputs = (
|
||||
ort_inputs["image_features_0"],
|
||||
ort_inputs["image_features_1"],
|
||||
ort_inputs["image_embeddings"],
|
||||
ort_inputs["point_coords"],
|
||||
ort_inputs["point_labels"],
|
||||
ort_inputs["input_masks"],
|
||||
ort_inputs["has_input_masks"],
|
||||
ort_inputs["original_image_size"],
|
||||
)
|
||||
|
||||
sam2_decoder = SAM2ImageDecoder(
|
||||
sam2_model,
|
||||
multimask_output=config.multi_mask_output,
|
||||
)
|
||||
|
||||
if is_cuda and config.torch_compile_mode != "none":
|
||||
sam2_decoder.forward = torch.compile(
|
||||
sam2_decoder.forward,
|
||||
mode=config.torch_compile_mode,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
# warm up
|
||||
for _ in range(config.warm_up):
|
||||
_masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
|
||||
|
||||
if is_cuda and config.enable_nvtx_profile:
|
||||
import nvtx # noqa: PLC0415
|
||||
from cuda import cudart # noqa: PLC0415
|
||||
|
||||
cudart.cudaProfilerStart()
|
||||
print("Start nvtx profiling on decoder...")
|
||||
with nvtx.annotate("one_run"):
|
||||
sam2_decoder(*torch_inputs, enable_nvtx_profile=True)
|
||||
cudart.cudaProfilerStop()
|
||||
|
||||
if is_cuda and config.enable_torch_profile:
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
) as prof:
|
||||
print("Start torch profiling on decoder ...")
|
||||
with torch.profiler.record_function("decoder"):
|
||||
sam2_decoder(*torch_inputs)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
prof.export_chrome_trace("torch_image_decoder.json")
|
||||
|
||||
if config.repeats == 0:
|
||||
return
|
||||
|
||||
print(f"Start {config.repeats} runs of performance tests...")
|
||||
start = time.time()
|
||||
for _ in range(config.repeats):
|
||||
_masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
|
||||
if is_cuda:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end = time.time()
|
||||
return (end - start) / config.repeats
|
||||
|
||||
|
||||
def run_test(
|
||||
args: argparse.Namespace,
|
||||
csv_writer: csv.DictWriter | None = None,
|
||||
):
|
||||
use_gpu: bool = args.use_gpu
|
||||
enable_cuda_graph: bool = args.use_cuda_graph
|
||||
repeats: int = args.repeats
|
||||
|
||||
if use_gpu:
|
||||
device_id = torch.cuda.current_device()
|
||||
device = torch.device("cuda", device_id)
|
||||
provider = "CUDAExecutionProvider"
|
||||
else:
|
||||
device_id = 0
|
||||
device = torch.device("cpu")
|
||||
enable_cuda_graph = False
|
||||
provider = "CPUExecutionProvider"
|
||||
|
||||
dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
config = TestConfig(
|
||||
model_type=args.model_type,
|
||||
onnx_path=args.onnx_path,
|
||||
sam2_dir=args.sam2_dir,
|
||||
component=args.component,
|
||||
provider=provider,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
device=device,
|
||||
use_tf32=True,
|
||||
enable_cuda_graph=enable_cuda_graph,
|
||||
dtype=dtypes[args.dtype],
|
||||
prefer_nhwc=args.prefer_nhwc,
|
||||
repeats=args.repeats,
|
||||
warm_up=args.warm_up,
|
||||
enable_nvtx_profile=args.enable_nvtx_profile,
|
||||
enable_ort_profile=args.enable_ort_profile,
|
||||
enable_torch_profile=args.enable_torch_profile,
|
||||
torch_compile_mode=args.torch_compile_mode,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
if args.engine == "ort":
|
||||
sess_options = SessionOptions()
|
||||
sess_options.intra_op_num_threads = args.intra_op_num_threads
|
||||
if config.enable_ort_profile:
|
||||
sess_options.enable_profiling = True
|
||||
sess_options.log_severity_level = 4
|
||||
sess_options.log_verbosity_level = 0
|
||||
|
||||
session = create_session(config, sess_options)
|
||||
input_dict = config.random_inputs()
|
||||
|
||||
# warm up session
|
||||
try:
|
||||
for _ in range(config.warm_up):
|
||||
_ = measure_latency(session, input_dict)
|
||||
except Exception as e:
|
||||
print(f"Failed to run {config=}. Exception: {e}")
|
||||
return
|
||||
|
||||
if config.enable_nvtx_profile:
|
||||
import nvtx # noqa: PLC0415
|
||||
from cuda import cudart # noqa: PLC0415
|
||||
|
||||
cudart.cudaProfilerStart()
|
||||
with nvtx.annotate("one_run"):
|
||||
_ = session.infer(input_dict)
|
||||
cudart.cudaProfilerStop()
|
||||
|
||||
if config.enable_ort_profile:
|
||||
session.ort_session.end_profiling()
|
||||
|
||||
if repeats == 0:
|
||||
return
|
||||
|
||||
latency_list = []
|
||||
for _ in range(repeats):
|
||||
latency = measure_latency(session, input_dict)
|
||||
latency_list.append(latency)
|
||||
average_latency = statistics.mean(latency_list)
|
||||
|
||||
del session
|
||||
else: # torch
|
||||
with torch.no_grad():
|
||||
try:
|
||||
average_latency = run_torch(config)
|
||||
except Exception as e:
|
||||
print(f"Failed to run {config=}. Exception: {e}")
|
||||
return
|
||||
|
||||
if repeats == 0:
|
||||
return
|
||||
|
||||
engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
|
||||
row = {
|
||||
"model_type": args.model_type,
|
||||
"component": args.component,
|
||||
"dtype": args.dtype,
|
||||
"use_gpu": use_gpu,
|
||||
"enable_cuda_graph": enable_cuda_graph,
|
||||
"prefer_nhwc": config.prefer_nhwc,
|
||||
"use_tf32": config.use_tf32,
|
||||
"batch_size": args.batch_size,
|
||||
"height": args.height,
|
||||
"width": args.width,
|
||||
"multi_mask_output": args.multimask_output,
|
||||
"num_labels": config.num_labels,
|
||||
"num_points": config.num_points,
|
||||
"num_masks": config.num_masks,
|
||||
"intra_op_num_threads": args.intra_op_num_threads,
|
||||
"warm_up": config.warm_up,
|
||||
"repeats": repeats,
|
||||
"enable_nvtx_profile": args.enable_nvtx_profile,
|
||||
"torch_compile_mode": args.torch_compile_mode,
|
||||
"engine": engine,
|
||||
"average_latency": average_latency,
|
||||
}
|
||||
|
||||
if csv_writer is not None:
|
||||
csv_writer.writerow(row)
|
||||
|
||||
print(f"{vars(config)}")
|
||||
print(f"{row}")
|
||||
|
||||
|
||||
def run_perf_test(args):
|
||||
features = "gpu" if args.use_gpu else "cpu"
|
||||
csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
|
||||
features,
|
||||
args.engine,
|
||||
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
||||
)
|
||||
with open(csv_filename, mode="a", newline="") as csv_file:
|
||||
column_names = [
|
||||
"model_type",
|
||||
"component",
|
||||
"dtype",
|
||||
"use_gpu",
|
||||
"enable_cuda_graph",
|
||||
"prefer_nhwc",
|
||||
"use_tf32",
|
||||
"batch_size",
|
||||
"height",
|
||||
"width",
|
||||
"multi_mask_output",
|
||||
"num_labels",
|
||||
"num_points",
|
||||
"num_masks",
|
||||
"intra_op_num_threads",
|
||||
"warm_up",
|
||||
"repeats",
|
||||
"enable_nvtx_profile",
|
||||
"torch_compile_mode",
|
||||
"engine",
|
||||
"average_latency",
|
||||
]
|
||||
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
|
||||
csv_writer.writeheader()
|
||||
|
||||
run_test(args, csv_writer)
|
||||
|
||||
|
||||
def _parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="Benchmark SMA2 for ONNX Runtime and PyTorch.")
|
||||
|
||||
parser.add_argument(
|
||||
"--component",
|
||||
required=False,
|
||||
choices=["image_encoder", "image_decoder"],
|
||||
default="image_encoder",
|
||||
help="component to benchmark. Choices are image_encoder and image_decoder.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use GPU for inference.",
|
||||
)
|
||||
parser.set_defaults(use_gpu=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_cuda_graph",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use cuda graph in onnxruntime.",
|
||||
)
|
||||
parser.set_defaults(use_cuda_graph=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--intra_op_num_threads",
|
||||
required=False,
|
||||
type=int,
|
||||
choices=[0, 1, 2, 4, 8, 16],
|
||||
default=0,
|
||||
help="intra_op_num_threads for onnxruntime. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
required=False,
|
||||
type=int,
|
||||
default=1,
|
||||
help="batch size",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
required=False,
|
||||
type=int,
|
||||
default=1024,
|
||||
help="image height",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
required=False,
|
||||
type=int,
|
||||
default=1024,
|
||||
help="image width",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repeats",
|
||||
required=False,
|
||||
type=int,
|
||||
default=1000,
|
||||
help="number of repeats for performance test. Default is 1000.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warm_up",
|
||||
required=False,
|
||||
type=int,
|
||||
default=5,
|
||||
help="number of runs for warm up. Default is 5.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--engine",
|
||||
required=False,
|
||||
type=str,
|
||||
default="ort",
|
||||
choices=["ort", "torch"],
|
||||
help="engine for inference",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--multimask_output",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Export mask_decoder or image_decoder with multimask_output",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prefer_nhwc",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_nvtx_profile",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_ort_profile",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Enable ORT profiling.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_torch_profile",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
required=False,
|
||||
type=str,
|
||||
default="sam2_hiera_large",
|
||||
choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
|
||||
help="sam2 model name",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sam2_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./segment-anything-2",
|
||||
help="The directory of segment-anything-2 git root directory",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--onnx_path",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
|
||||
help="path of onnx model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--torch_compile_mode",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
|
||||
help="torch compile mode. none will disable torch compile.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = _parse_arguments()
|
||||
print(f"arguments:{args}")
|
||||
|
||||
if args.torch_compile_mode is None:
|
||||
# image decoder will fail with compile modes other than "none".
|
||||
args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"
|
||||
|
||||
if args.use_gpu:
|
||||
assert torch.cuda.is_available()
|
||||
if args.engine == "ort":
|
||||
assert "CUDAExecutionProvider" in get_available_providers()
|
||||
args.enable_torch_profile = False
|
||||
else:
|
||||
# Only support cuda profiling for now.
|
||||
assert not args.enable_nvtx_profile
|
||||
assert not args.enable_torch_profile
|
||||
|
||||
if args.enable_nvtx_profile or args.enable_torch_profile:
|
||||
run_test(args)
|
||||
else:
|
||||
run_perf_test(args)
|
||||
+270
@@ -0,0 +1,270 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from image_decoder import export_decoder_onnx, test_decoder_onnx
|
||||
from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx
|
||||
from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx
|
||||
from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx
|
||||
from sam2_demo import run_demo, show_all_images
|
||||
from sam2_utils import load_sam2_model, sam2_onnx_path, setup_logger
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX")
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
required=False,
|
||||
type=str,
|
||||
choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
|
||||
default="sam2_hiera_large",
|
||||
help="The model type to export",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--components",
|
||||
required=False,
|
||||
nargs="+",
|
||||
choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"],
|
||||
default=["image_encoder", "image_decoder"],
|
||||
help="Type of ONNX models to export. "
|
||||
"Note that image_decoder is a combination of prompt_encoder and mask_decoder",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
help="The output directory for the ONNX models",
|
||||
default="sam2_onnx_models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic_batch_axes",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Export image_encoder with dynamic batch axes",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--multimask_output",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Export mask_decoder or image_decoder with multimask_output",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable_dynamic_multimask_via_stability",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Disable mask_decoder dynamic_multimask_via_stability, and output first mask only."
|
||||
"This option will be ignored when multimask_output is True",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sam2_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default="./segment-anything-2",
|
||||
help="The directory of segment-anything-2 git repository",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Overwrite onnx model file if exists.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--demo",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Run demo with the exported ONNX models.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--optimize",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Optimize onnx models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype", required=False, choices=["fp32", "fp16"], default="fp32", help="Data type for inference."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_gpu",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Optimize onnx models for GPU",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamo",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Use dynamo for exporting onnx model. Only image_encoder supports dynamo right now.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
required=False,
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Print verbose information",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
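# A hypothetical invocation for illustration (the script file name is assumed; the arguments
# match the parser above):
#   python convert_to_onnx.py --model_type sam2_hiera_large --components image_encoder image_decoder \
#       --output_dir sam2_onnx_models --optimize --dtype fp16 --use_gpu --demo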
|
||||
|
||||
|
||||
def optimize_sam2_model(onnx_model_path, optimized_model_path, float16: bool, use_gpu: bool):
|
||||
print(f"Optimizing {onnx_model_path} to {optimized_model_path} with float16={float16} and use_gpu={use_gpu}...")
|
||||
|
||||
# Import from source directory.
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.insert(0, transformers_dir)
|
||||
from optimizer import optimize_model # noqa: PLC0415
|
||||
|
||||
optimized_model = optimize_model(onnx_model_path, model_type="sam2", opt_level=1, use_gpu=use_gpu)
|
||||
if float16:
|
||||
optimized_model.convert_float_to_float16(keep_io_types=False)
|
||||
optimized_model.save_model_to_file(optimized_model_path)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
sam2_model = load_sam2_model(args.sam2_dir, args.model_type, device="cpu")
|
||||
|
||||
pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for component in args.components:
|
||||
onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output)
|
||||
if component == "image_encoder":
|
||||
if args.overwrite or not os.path.exists(onnx_model_path):
|
||||
export_image_encoder_onnx(
|
||||
sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose, args.dynamo
|
||||
)
|
||||
test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=args.dynamic_batch_axes)
|
||||
|
||||
elif component == "mask_decoder":
|
||||
if args.overwrite or not os.path.exists(onnx_model_path):
|
||||
export_mask_decoder_onnx(
|
||||
sam2_model,
|
||||
onnx_model_path,
|
||||
args.multimask_output,
|
||||
not args.disable_dynamic_multimask_via_stability,
|
||||
args.verbose,
|
||||
)
|
||||
test_mask_decoder_onnx(
|
||||
sam2_model,
|
||||
onnx_model_path,
|
||||
args.multimask_output,
|
||||
not args.disable_dynamic_multimask_via_stability,
|
||||
)
|
||||
elif component == "prompt_encoder":
|
||||
if args.overwrite or not os.path.exists(onnx_model_path):
|
||||
export_prompt_encoder_onnx(sam2_model, onnx_model_path)
|
||||
test_prompt_encoder_onnx(sam2_model, onnx_model_path)
|
||||
else:
|
||||
assert component == "image_decoder"
|
||||
if args.overwrite or not os.path.exists(onnx_model_path):
|
||||
export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
|
||||
test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
|
||||
|
||||
suffix = ""
|
||||
convert_to_fp16 = args.dtype == "fp16"
|
||||
if args.optimize:
|
||||
suffix = f"_{args.dtype}_" + ("gpu" if args.use_gpu else "cpu")
|
||||
for component in args.components:
|
||||
onnx_model_path = sam2_onnx_path(args.output_dir, args.model_type, component, args.multimask_output)
|
||||
optimized_model_path = sam2_onnx_path(
|
||||
args.output_dir, args.model_type, component, args.multimask_output, suffix
|
||||
)
|
||||
optimize_sam2_model(onnx_model_path, optimized_model_path, convert_to_fp16, args.use_gpu)
|
||||
|
||||
if args.demo:
|
||||
# Export required ONNX models for demo if not already exported.
|
||||
image_encoder_onnx_path = sam2_onnx_path(
|
||||
args.output_dir, args.model_type, "image_encoder", args.multimask_output
|
||||
)
|
||||
if not os.path.exists(image_encoder_onnx_path):
|
||||
export_image_encoder_onnx(sam2_model, image_encoder_onnx_path, args.dynamic_batch_axes, args.verbose)
|
||||
|
||||
image_decoder_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", False)
|
||||
if not os.path.exists(image_decoder_onnx_path):
|
||||
export_decoder_onnx(sam2_model, image_decoder_onnx_path, False)
|
||||
|
||||
image_decoder_multi_onnx_path = sam2_onnx_path(args.output_dir, args.model_type, "image_decoder", True)
|
||||
if not os.path.exists(image_decoder_multi_onnx_path):
|
||||
export_decoder_onnx(sam2_model, image_decoder_multi_onnx_path, True)
|
||||
|
||||
dtype = torch.float32 if args.dtype == "fp32" else torch.float16
|
||||
if suffix:
|
||||
optimized_image_encoder_onnx_path = image_encoder_onnx_path.replace(".onnx", f"{suffix}.onnx")
|
||||
if not os.path.exists(optimized_image_encoder_onnx_path):
|
||||
optimize_sam2_model(
|
||||
image_encoder_onnx_path, optimized_image_encoder_onnx_path, convert_to_fp16, args.use_gpu
|
||||
)
|
||||
|
||||
optimized_image_decoder_onnx_path = image_decoder_onnx_path.replace(".onnx", f"{suffix}.onnx")
|
||||
if not os.path.exists(optimized_image_decoder_onnx_path):
|
||||
optimize_sam2_model(
|
||||
image_decoder_onnx_path, optimized_image_decoder_onnx_path, convert_to_fp16, args.use_gpu
|
||||
)
|
||||
|
||||
optimized_image_decoder_multi_onnx_path = image_decoder_multi_onnx_path.replace(".onnx", f"{suffix}.onnx")
|
||||
if not os.path.exists(optimized_image_decoder_multi_onnx_path):
|
||||
optimize_sam2_model(
|
||||
image_decoder_multi_onnx_path,
|
||||
optimized_image_decoder_multi_onnx_path,
|
||||
convert_to_fp16,
|
||||
args.use_gpu,
|
||||
)
|
||||
|
||||
# Use optimized models to run demo.
|
||||
image_encoder_onnx_path = optimized_image_encoder_onnx_path
|
||||
image_decoder_onnx_path = optimized_image_decoder_onnx_path
|
||||
image_decoder_multi_onnx_path = optimized_image_decoder_multi_onnx_path
|
||||
|
||||
ort_image_files = run_demo(
|
||||
args.sam2_dir,
|
||||
args.model_type,
|
||||
engine="ort",
|
||||
dtype=dtype,
|
||||
image_encoder_onnx_path=image_encoder_onnx_path,
|
||||
image_decoder_onnx_path=image_decoder_onnx_path,
|
||||
image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
|
||||
use_gpu=args.use_gpu,
|
||||
)
|
||||
print("demo output files for ONNX Runtime:", ort_image_files)
|
||||
|
||||
# Get results from torch engine to compare.
|
||||
torch_image_files = run_demo(args.sam2_dir, args.model_type, engine="torch", dtype=dtype, use_gpu=args.use_gpu)
|
||||
print("demo output files for PyTorch:", torch_image_files)
|
||||
|
||||
show_all_images(ort_image_files, torch_image_files, suffix)
|
||||
print(f"Combined demo output: sam2_demo{suffix}.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logger(verbose=False)
|
||||
with torch.no_grad():
|
||||
main()
|
||||
+272
@@ -0,0 +1,272 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
|
||||
from mask_decoder import SAM2MaskDecoder
|
||||
from prompt_encoder import SAM2PromptEncoder
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from sam2_utils import compare_tensors_with_tolerance
|
||||
from torch import nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SAM2ImageDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sam_model: SAM2Base,
|
||||
multimask_output: bool,
|
||||
dynamic_multimask_via_stability: bool = True,
|
||||
return_logits: bool = False,
|
||||
mask_threshold: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.prompt_encoder = SAM2PromptEncoder(sam_model)
|
||||
self.mask_decoder = SAM2MaskDecoder(sam_model, multimask_output, dynamic_multimask_via_stability)
|
||||
self.return_logits = return_logits
|
||||
self.mask_threshold = mask_threshold
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
image_features_0: torch.Tensor,
|
||||
image_features_1: torch.Tensor,
|
||||
image_embeddings: torch.Tensor,
|
||||
point_coords: torch.Tensor,
|
||||
point_labels: torch.Tensor,
|
||||
input_masks: torch.Tensor,
|
||||
has_input_masks: torch.Tensor,
|
||||
original_image_size: torch.Tensor,
|
||||
enable_nvtx_profile: bool = False,
|
||||
):
|
||||
"""
|
||||
Decode masks from image features and prompts. Batched images are not supported. H=W=1024.
|
||||
|
||||
Args:
|
||||
image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
|
||||
image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
|
||||
image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
|
||||
point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype, containing the absolute pixel
|
||||
coordinates in (x, y) format of the P input points in an image of size 1024x1024.
|
||||
point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
|
||||
positive (foreground), 0 means negative (background), -1 means padding,
|
||||
2 (box left upper corner), 3 (box right bottom corner).
|
||||
input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
|
||||
Typically coming from a previous iteration.
|
||||
has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
|
||||
original_image_size(torch.Tensor): [2]. original image size H_o, W_o.
|
||||
enable_nvtx_profile (bool): enable NVTX profiling.
|
||||
|
||||
Returns:
|
||||
masks (torch.Tensor): [1, M, H_o, W_o] where M=3 or 1. Masks of original image size.
|
||||
iou_predictions (torch.Tensor): [1, M]. scores for M masks.
|
||||
low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
|
||||
"""
|
||||
nvtx_helper = None
|
||||
if enable_nvtx_profile:
|
||||
from nvtx_helper import NvtxHelper # noqa: PLC0415
|
||||
|
||||
nvtx_helper = NvtxHelper(["prompt_encoder", "mask_decoder", "post_process"])
|
||||
|
||||
if nvtx_helper is not None:
|
||||
nvtx_helper.start_profile("prompt_encoder", color="blue")
|
||||
|
||||
sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder(
|
||||
point_coords, point_labels, input_masks, has_input_masks
|
||||
)
|
||||
|
||||
if nvtx_helper is not None:
|
||||
nvtx_helper.stop_profile("prompt_encoder")
|
||||
nvtx_helper.start_profile("mask_decoder", color="red")
|
||||
|
||||
low_res_masks, iou_predictions = self.mask_decoder(
|
||||
image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings
|
||||
)
|
||||
|
||||
if nvtx_helper is not None:
|
||||
nvtx_helper.stop_profile("mask_decoder")
|
||||
nvtx_helper.start_profile("post_process", color="green")
|
||||
|
||||
# Interpolate the low resolution masks back to the original image size.
|
||||
masks = F.interpolate(
|
||||
low_res_masks,
|
||||
(original_image_size[0], original_image_size[1]),
|
||||
mode="bilinear",
|
||||
align_corners=False, # Note that align_corners=True has fewer mismatches when comparing ORT and PyTorch.
|
||||
)
|
||||
|
||||
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
|
||||
if not self.return_logits:
|
||||
masks = masks > self.mask_threshold
|
||||
|
||||
if nvtx_helper is not None:
|
||||
nvtx_helper.stop_profile("post_process")
|
||||
nvtx_helper.print_latency()
|
||||
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
|
||||
def export_decoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
multimask_output: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
batch_size = 1
|
||||
image = random_sam2_input_image(batch_size)
|
||||
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
|
||||
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
|
||||
|
||||
logger.info("image_features_0.shape: %s", image_features_0.shape)
|
||||
logger.info("image_features_1.shape: %s", image_features_1.shape)
|
||||
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
|
||||
|
||||
sam2_decoder = SAM2ImageDecoder(
|
||||
sam2_model,
|
||||
multimask_output=multimask_output,
|
||||
dynamic_multimask_via_stability=True,
|
||||
).cpu()
|
||||
|
||||
num_labels = 2
|
||||
num_points = 3
|
||||
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
|
||||
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
|
||||
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
|
||||
has_input_masks = torch.ones(1, dtype=torch.float)
|
||||
original_image_size = torch.tensor([1200, 1800], dtype=torch.int32)
|
||||
|
||||
example_inputs = (
|
||||
image_features_0,
|
||||
image_features_1,
|
||||
image_embeddings,
|
||||
point_coords,
|
||||
point_labels,
|
||||
input_masks,
|
||||
has_input_masks,
|
||||
original_image_size,
|
||||
)
|
||||
|
||||
logger.info("point_coords.shape: %s", point_coords.shape)
|
||||
logger.info("point_labels.shape: %s", point_labels.shape)
|
||||
logger.info("input_masks.shape: %s", input_masks.shape)
|
||||
logger.info("has_input_masks.shape: %s", has_input_masks.shape)
|
||||
logger.info("original_image_size.shape: %s", original_image_size.shape)
|
||||
|
||||
if verbose:
|
||||
masks, iou_predictions, low_res_masks = sam2_decoder(*example_inputs)
|
||||
logger.info("masks.shape: %s", masks.shape)
|
||||
logger.info("iou_predictions.shape: %s", iou_predictions.shape)
|
||||
logger.info("low_res_masks.shape: %s", low_res_masks.shape)
|
||||
|
||||
input_names = [
|
||||
"image_features_0",
|
||||
"image_features_1",
|
||||
"image_embeddings",
|
||||
"point_coords",
|
||||
"point_labels",
|
||||
"input_masks",
|
||||
"has_input_masks",
|
||||
"original_image_size",
|
||||
]
|
||||
|
||||
output_names = ["masks", "iou_predictions", "low_res_masks"]
|
||||
|
||||
dynamic_axes = {
|
||||
"point_coords": {0: "num_labels", 1: "num_points"},
|
||||
"point_labels": {0: "num_labels", 1: "num_points"},
|
||||
"input_masks": {0: "num_labels"},
|
||||
"has_input_masks": {0: "num_labels"},
|
||||
"masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"},
|
||||
"low_res_masks": {0: "num_labels"},
|
||||
"iou_predictions": {0: "num_labels"},
|
||||
}
|
||||
|
||||
with warnings.catch_warnings():
|
||||
if not verbose:
|
||||
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
torch.onnx.export(
|
||||
sam2_decoder,
|
||||
example_inputs,
|
||||
onnx_model_path,
|
||||
export_params=True,
|
||||
opset_version=16,
|
||||
do_constant_folding=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
)
|
||||
|
||||
logger.info("decoder onnx model saved to %s", onnx_model_path)
|
||||
|
||||
|
||||
def test_decoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
multimask_output=False,
|
||||
):
|
||||
batch_size = 1
|
||||
image = random_sam2_input_image(batch_size)
|
||||
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
|
||||
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
|
||||
|
||||
sam2_image_decoder = SAM2ImageDecoder(
|
||||
sam2_model,
|
||||
multimask_output=multimask_output,
|
||||
dynamic_multimask_via_stability=True,
|
||||
).cpu()
|
||||
|
||||
num_labels = 1
|
||||
num_points = 5
|
||||
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
|
||||
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
|
||||
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
|
||||
has_input_masks = torch.zeros(1, dtype=torch.float)
|
||||
original_image_size = torch.tensor([1500, 1500], dtype=torch.int32)
|
||||
|
||||
example_inputs = (
|
||||
image_features_0,
|
||||
image_features_1,
|
||||
image_embeddings,
|
||||
point_coords,
|
||||
point_labels,
|
||||
input_masks,
|
||||
has_input_masks,
|
||||
original_image_size,
|
||||
)
|
||||
|
||||
masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs)
|
||||
|
||||
import onnxruntime # noqa: PLC0415
|
||||
|
||||
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
|
||||
|
||||
model_inputs = ort_session.get_inputs()
|
||||
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
|
||||
logger.info("input_names: %s", input_names)
|
||||
|
||||
model_outputs = ort_session.get_outputs()
|
||||
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
|
||||
logger.info("output_names: %s", output_names)
|
||||
inputs = {model_inputs[i].name: example_inputs[i].numpy() for i in range(len(model_inputs))}
|
||||
outputs = ort_session.run(output_names, inputs)
|
||||
|
||||
for i, output_name in enumerate(output_names):
|
||||
logger.info(f"{output_name}.shape: %s", outputs[i].shape)
|
||||
|
||||
ort_masks, ort_iou_predictions, ort_low_res_masks = outputs
|
||||
if (
|
||||
compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float())
|
||||
and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions))
|
||||
and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks))
|
||||
):
|
||||
print("onnx model has been verified:", onnx_model_path)
|
||||
else:
|
||||
print("onnx model verification failed:", onnx_model_path)
|
||||
+236
@@ -0,0 +1,236 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
|
||||
from torch import nn
|
||||
|
||||
import onnxruntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SAM2ImageEncoder(nn.Module):
|
||||
def __init__(self, sam_model: SAM2Base) -> None:
|
||||
super().__init__()
|
||||
self.model = sam_model
|
||||
self.image_encoder = sam_model.image_encoder
|
||||
self.no_mem_embed = sam_model.no_mem_embed
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
enable_nvtx_profile: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Encodes images into features.
|
||||
|
||||
Only supports H=W=1024. If you want to use different image sizes like 512x512,
|
||||
see https://github.com/facebookresearch/segment-anything-2/issues/138.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
|
||||
enable_nvtx_profile (bool): enable NVTX profiling.
|
||||
|
||||
Returns:
|
||||
image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
|
||||
image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
|
||||
image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
|
||||
"""
|
||||
nvtx_helper = None
|
||||
if enable_nvtx_profile:
|
||||
from nvtx_helper import NvtxHelper # noqa: PLC0415
|
||||
|
||||
nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
|
||||
|
||||
if nvtx_helper is not None:
|
||||
nvtx_helper.start_profile("image_encoder")
|
||||
|
||||
backbone_out = self.image_encoder(image)
|
||||
|
||||
if nvtx_helper is not None:
|
||||
nvtx_helper.stop_profile("image_encoder")
|
||||
nvtx_helper.start_profile("post_process")
|
||||
|
||||
# precompute projected level 0 and level 1 features in SAM decoder
|
||||
# to avoid running it again on every SAM click
|
||||
backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
|
||||
backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
|
||||
|
||||
# Prepare and flatten visual features.
|
||||
feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
|
||||
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
|
||||
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
|
||||
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
# TODO: we should avoid this transpose since it will be transposed back to NCHW later.
|
||||
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
|
||||
|
||||
vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
|
||||
|
||||
feats = [
|
||||
feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
|
||||
for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1], strict=False)
|
||||
][::-1]
|
||||
|
||||
if nvtx_helper is not None:
|
||||
nvtx_helper.stop_profile("post_process")
|
||||
nvtx_helper.print_latency()
|
||||
|
||||
return feats[0], feats[1], feats[2]
|
||||
|
||||
|
||||
def export_image_encoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
dynamic_batch_axes: bool = False,
|
||||
verbose: bool = False,
|
||||
dynamo: bool = False,
|
||||
clear_dynamo_metadata: bool = False,
|
||||
):
|
||||
image = random_sam2_input_image()
|
||||
|
||||
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
|
||||
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
|
||||
logger.info("image.shape: %s", image.shape)
|
||||
logger.info("image_features_0.shape: %s", image_features_0.shape)
|
||||
logger.info("image_features_1.shape: %s", image_features_1.shape)
|
||||
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
|
||||
|
||||
dynamic_axes = None
|
||||
if dynamic_batch_axes:
|
||||
dynamic_axes = {
|
||||
"image": {0: "batch_size"},
|
||||
"image_features_0": {0: "batch_size"},
|
||||
"image_features_1": {0: "batch_size"},
|
||||
"image_embeddings": {0: "batch_size"},
|
||||
}
|
||||
|
||||
with warnings.catch_warnings():
|
||||
if not verbose:
|
||||
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
if not dynamo:
|
||||
torch.onnx.export(
|
||||
sam2_encoder,
|
||||
image,
|
||||
onnx_model_path,
|
||||
export_params=True,
|
||||
opset_version=17,
|
||||
do_constant_folding=True,
|
||||
input_names=["image"],
|
||||
output_names=["image_features_0", "image_features_1", "image_embeddings"],
|
||||
dynamic_axes=dynamic_axes,
|
||||
)
|
||||
else:
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
ep = torch.export.export(
|
||||
sam2_encoder,
|
||||
args=(image,),
|
||||
strict=False,
|
||||
dynamic_shapes=[
|
||||
{0: torch.export.Dim.AUTO},
|
||||
],
|
||||
)
|
||||
|
||||
onnx_program = torch.onnx.export(
|
||||
ep,
|
||||
(),
|
||||
opset_version=17,
|
||||
input_names=["image"],
|
||||
output_names=["image_features_0", "image_features_1", "image_embeddings"],
|
||||
dynamo=True,
|
||||
)
|
||||
onnx_program.optimize()
|
||||
onnx_program.save(onnx_model_path + ".dynamo.onnx", external_data=False)
|
||||
import onnx # noqa: PLC0415
|
||||
|
||||
from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper # noqa: PLC0415
|
||||
|
||||
onnx_model = onnx.load_model(onnx_model_path + ".dynamo.onnx", load_external_data=True)
|
||||
if dynamic_batch_axes:
|
||||
# Fix labels of dynamic axes since they can't be specified during Dynamo export currently
|
||||
onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size"
|
||||
for i in range(3):
|
||||
onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size"
|
||||
|
||||
onnx_model_helper = DynamoOnnxHelper(onnx_model)
|
||||
onnx_model_helper.convert_constants_to_initializers()
|
||||
if clear_dynamo_metadata:
|
||||
onnx_model_helper.clear_metadata()
|
||||
|
||||
import os # noqa: PLC0415
|
||||
|
||||
if os.path.exists(onnx_model_path):
|
||||
os.remove(onnx_model_path)
|
||||
if os.path.exists(onnx_model_path + ".data"):
|
||||
os.remove(onnx_model_path + ".data")
|
||||
onnx_model_helper.model.save_model_to_file(
|
||||
onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True
|
||||
)
|
||||
|
||||
print("encoder onnx model saved to", onnx_model_path)
|
||||
|
||||
|
||||
def test_image_encoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
dynamic_batch_axes=False,
|
||||
):
|
||||
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
|
||||
|
||||
model_inputs = ort_session.get_inputs()
|
||||
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
|
||||
logger.info("input_names: %s", input_names)
|
||||
|
||||
model_outputs = ort_session.get_outputs()
|
||||
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
|
||||
logger.info("output_names: %s", output_names)
|
||||
|
||||
batch_sizes = [1, 2] if dynamic_batch_axes else [1]
|
||||
for batch_size in batch_sizes:
|
||||
image = random_sam2_input_image(batch_size)
|
||||
|
||||
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
|
||||
image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
|
||||
|
||||
logger.info("image.shape: %s", image.shape)
|
||||
logger.info("image_features_0.shape: %s", image_features_0.shape)
|
||||
logger.info("image_features_1.shape: %s", image_features_1.shape)
|
||||
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
|
||||
|
||||
outputs = ort_session.run(output_names, {"image": image.numpy()})
|
||||
for i, output_name in enumerate(output_names):
|
||||
logger.info("output %s shape %s", output_name, outputs[i].shape)
|
||||
ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs
|
||||
|
||||
# ONNX Runtime and PyTorch have about 0.75% mismatched elements, but this does not seem to impact segmentation results.
|
||||
if (
|
||||
compare_tensors_with_tolerance(
|
||||
"image_features_0",
|
||||
image_features_0,
|
||||
torch.tensor(ort_image_features_0),
|
||||
mismatch_percentage_tolerance=1,
|
||||
)
|
||||
and compare_tensors_with_tolerance(
|
||||
"image_features_1",
|
||||
image_features_1,
|
||||
torch.tensor(ort_image_features_1),
|
||||
mismatch_percentage_tolerance=1,
|
||||
)
|
||||
and compare_tensors_with_tolerance(
|
||||
"image_embeddings",
|
||||
image_embeddings,
|
||||
torch.tensor(ort_image_embeddings),
|
||||
mismatch_percentage_tolerance=1,
|
||||
)
|
||||
):
|
||||
print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
|
||||
else:
|
||||
print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")
|
||||
+208
@@ -0,0 +1,208 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
|
||||
from prompt_encoder import SAM2PromptEncoder
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from torch import nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SAM2MaskDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
sam_model: SAM2Base,
|
||||
multimask_output: bool,
|
||||
dynamic_multimask_via_stability: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.mask_decoder = sam_model.sam_mask_decoder
|
||||
self.prompt_encoder = sam_model.sam_prompt_encoder
|
||||
self.model = sam_model
|
||||
self.multimask_output = multimask_output
|
||||
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
image_features_0: torch.Tensor,
|
||||
image_features_1: torch.Tensor,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_embeddings: torch.Tensor,
|
||||
dense_embeddings: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Decode masks from image and prompt embeddings. Only supports H=W=1024.
|
||||
|
||||
Args:
|
||||
image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
|
||||
image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
|
||||
image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
|
||||
image_pe (torch.Tensor): [1, 256, H/16, W/16]. image positional encoding.
|
||||
sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
|
||||
dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. embedding for input masks.
|
||||
|
||||
Returns:
|
||||
low_res_masks (torch.Tensor): [L, M, H/4, W/4]. low resolution masks.
|
||||
iou_predictions (torch.Tensor): [L, M]. scores for M masks.
|
||||
"""
|
||||
low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
repeat_image=sparse_embeddings.shape[0] > 1, # batch mode
|
||||
high_res_features=[image_features_0, image_features_1],
|
||||
)
|
||||
|
||||
if self.multimask_output:
|
||||
low_res_masks = low_res_masks[:, 1:, :, :]
|
||||
iou_predictions = iou_predictions[:, 1:]
|
||||
elif self.dynamic_multimask_via_stability:
|
||||
# When outputting a single mask, if the stability score from the current single-mask
|
||||
# output (based on output token 0) falls below a threshold, we instead select from
|
||||
# multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score.
|
||||
low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
|
||||
low_res_masks, iou_predictions
|
||||
)
|
||||
else:
|
||||
low_res_masks = low_res_masks[:, 0:1, :, :]
|
||||
iou_predictions = iou_predictions[:, 0:1]
|
||||
|
||||
return low_res_masks, iou_predictions
|
||||
|
||||
|
||||
def export_mask_decoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
multimask_output: bool,
|
||||
dynamic_multimask_via_stability: bool = True,
|
||||
verbose=False,
|
||||
):
|
||||
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
|
||||
|
||||
image = random_sam2_input_image()
|
||||
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
|
||||
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
|
||||
logger.info("image_features_0.shape: %s", image_features_0.shape)
|
||||
logger.info("image_features_1.shape: %s", image_features_1.shape)
|
||||
logger.info("image_embeddings.shape: %s", image_embeddings.shape)
|
||||
|
||||
# encode a random prompt
|
||||
num_labels = 2
|
||||
num_points = 3
|
||||
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
|
||||
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
|
||||
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
|
||||
has_input_masks = torch.ones(1, dtype=torch.float)
|
||||
|
||||
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
|
||||
point_coords, point_labels, input_masks, has_input_masks
|
||||
)
|
||||
|
||||
logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
|
||||
logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
|
||||
logger.info("image_pe.shape: %s", image_pe.shape)
|
||||
|
||||
sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
|
||||
inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
|
||||
low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
|
||||
logger.info("low_res_masks.shape: %s", low_res_masks.shape)
|
||||
logger.info("iou_predictions.shape: %s", iou_predictions.shape)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
if not verbose:
|
||||
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
torch.onnx.export(
|
||||
sam2_mask_decoder,
|
||||
inputs,
|
||||
onnx_model_path,
|
||||
export_params=True,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
input_names=[
|
||||
"image_features_0",
|
||||
"image_features_1",
|
||||
"image_embeddings",
|
||||
"image_pe",
|
||||
"sparse_embeddings",
|
||||
"dense_embeddings",
|
||||
],
|
||||
output_names=["low_res_masks", "iou_predictions"],
|
||||
dynamic_axes={
|
||||
"sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
|
||||
"dense_embeddings": {0: "num_labels"},
|
||||
"low_res_masks": {0: "num_labels"},
|
||||
"iou_predictions": {0: "num_labels"},
|
||||
},
|
||||
)
|
||||
|
||||
print("mask decoder onnx model saved to", onnx_model_path)
|
||||
|
||||
|
||||
def test_mask_decoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
multimask_output: bool,
|
||||
dynamic_multimask_via_stability: bool,
|
||||
):
|
||||
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
|
||||
|
||||
image = random_sam2_input_image()
|
||||
sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
|
||||
image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
|
||||
|
||||
num_labels = 1
|
||||
num_points = 5
|
||||
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
|
||||
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
|
||||
input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
|
||||
has_input_masks = torch.ones(1, dtype=torch.float)
|
||||
|
||||
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
|
||||
point_coords, point_labels, input_masks, has_input_masks
|
||||
)
|
||||
|
||||
sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
|
||||
inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
|
||||
low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
|
||||
|
||||
import onnxruntime # noqa: PLC0415
|
||||
|
||||
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
|
||||
|
||||
model_inputs = ort_session.get_inputs()
|
||||
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
|
||||
logger.info("input_names: %s", input_names)
|
||||
|
||||
model_outputs = ort_session.get_outputs()
|
||||
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
|
||||
logger.info("output_names: %s", output_names)
|
||||
|
||||
outputs = ort_session.run(
|
||||
output_names,
|
||||
{
|
||||
"image_features_0": image_features_0.numpy(),
|
||||
"image_features_1": image_features_1.numpy(),
|
||||
"image_embeddings": image_embeddings.numpy(),
|
||||
"image_pe": image_pe.numpy(),
|
||||
"sparse_embeddings": sparse_embeddings.numpy(),
|
||||
"dense_embeddings": dense_embeddings.numpy(),
|
||||
},
|
||||
)
|
||||
|
||||
for i, output_name in enumerate(output_names):
|
||||
logger.info("output %s shape: %s", output_name, outputs[i].shape)
|
||||
|
||||
ort_low_res_masks, ort_iou_predictions = outputs
|
||||
torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4)
|
||||
torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4)
|
||||
print(f"onnx model has been verified: {onnx_model_path}")
|
||||
+33
@@ -0,0 +1,33 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import nvtx
|
||||
from cuda import cudart
|
||||
|
||||
|
||||
class NvtxHelper:
|
||||
def __init__(self, stages):
|
||||
self.stages = stages
|
||||
self.events = {}
|
||||
for stage in stages:
|
||||
for marker in ["start", "stop"]:
|
||||
self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
|
||||
self.markers = {}
|
||||
|
||||
def start_profile(self, stage, color="blue"):
|
||||
self.markers[stage] = nvtx.start_range(message=stage, color=color)
|
||||
event_name = stage + "-start"
|
||||
if event_name in self.events:
|
||||
cudart.cudaEventRecord(self.events[event_name], 0)
|
||||
|
||||
def stop_profile(self, stage):
|
||||
event_name = stage + "-stop"
|
||||
if event_name in self.events:
|
||||
cudart.cudaEventRecord(self.events[event_name], 0)
|
||||
nvtx.end_range(self.markers[stage])
|
||||
|
||||
def print_latency(self):
|
||||
for stage in self.stages:
|
||||
latency = cudart.cudaEventElapsedTime(self.events[f"{stage}-start"], self.events[f"{stage}-stop"])[1]
|
||||
print(f"{stage}: {latency:.2f} ms")
|
||||
+189
@@ -0,0 +1,189 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from sam2_utils import compare_tensors_with_tolerance
|
||||
from torch import nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SAM2PromptEncoder(nn.Module):
|
||||
def __init__(self, sam_model: SAM2Base):
|
||||
super().__init__()
|
||||
self.prompt_encoder = sam_model.sam_prompt_encoder
|
||||
self.model = sam_model
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
point_coords: torch.Tensor,
|
||||
point_labels: torch.Tensor,
|
||||
input_masks: torch.Tensor,
|
||||
has_input_masks: torch.Tensor,
|
||||
):
|
||||
"""Encode prompts.
|
||||
|
||||
Args:
|
||||
point_coords (torch.Tensor): shape [L, P, 2] and float32 dtype, containing the absolute pixel
|
||||
coordinates in (x, y) format of the P input points in an image of size 1024x1024.
|
||||
point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
|
||||
positive (foreground), 0 means negative (background), -1 means padding,
|
||||
2 means a box top-left corner, and 3 means a box bottom-right corner.
|
||||
input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
|
||||
Typically coming from a previous iteration.
|
||||
has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
|
||||
Returns:
|
||||
sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
|
||||
dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
|
||||
image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
|
||||
"""
|
||||
sparse_embeddings = self._embed_points(point_coords, point_labels)
|
||||
dense_embeddings = self._embed_masks(input_masks, has_input_masks)
|
||||
image_pe = self.prompt_encoder.get_dense_pe()
|
||||
|
||||
return sparse_embeddings, dense_embeddings, image_pe
|
||||
|
||||
def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
|
||||
point_coords = point_coords + 0.5
|
||||
|
||||
padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
|
||||
padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
|
||||
point_coords = torch.cat([point_coords, padding_point], dim=1)
|
||||
point_labels = torch.cat([point_labels, padding_label], dim=1)
|
||||
|
||||
# Note that the input coordinates are based on image size 1024x1024. Here we normalize them to [0.0, 1.0).
|
||||
point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
|
||||
point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
|
||||
|
||||
point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
|
||||
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
|
||||
|
||||
point_embedding = point_embedding * (point_labels != -1)
|
||||
point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
|
||||
|
||||
for i in range(self.prompt_encoder.num_point_embeddings):
|
||||
point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)
|
||||
|
||||
return point_embedding
|
||||
|
||||
def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
|
||||
mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
|
||||
no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
|
||||
logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
|
||||
mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
|
||||
logger.info("mask_embedding.shape: %s", mask_embedding.shape)
|
||||
return mask_embedding
|
||||
|
||||
|
||||
def export_prompt_encoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
):
|
||||
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
|
||||
|
||||
num_labels = 2
|
||||
num_points = 3
|
||||
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
|
||||
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
|
||||
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
|
||||
has_input_masks = torch.ones(1, dtype=torch.float)
|
||||
|
||||
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
|
||||
point_coords, point_labels, input_masks, has_input_masks
|
||||
)
|
||||
|
||||
logger.info("point_coords.shape: %s", point_coords.shape)
|
||||
logger.info("point_labels.shape: %s", point_labels.shape)
|
||||
logger.info("input_masks.shape: %s", input_masks.shape)
|
||||
logger.info("has_input_masks.shape: %s", has_input_masks.shape)
|
||||
|
||||
logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
|
||||
logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
|
||||
logger.info("image_pe.shape: %s", image_pe.shape)
|
||||
|
||||
torch.onnx.export(
|
||||
sam2_prompt_encoder,
|
||||
(point_coords, point_labels, input_masks, has_input_masks),
|
||||
onnx_model_path,
|
||||
export_params=True,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
|
||||
output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
|
||||
dynamic_axes={
|
||||
"point_coords": {0: "num_labels", 1: "num_points"},
|
||||
"point_labels": {0: "num_labels", 1: "num_points"},
|
||||
"input_masks": {0: "num_labels"},
|
||||
"sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
|
||||
"dense_embeddings": {0: "num_labels"},
|
||||
},
|
||||
)
|
||||
|
||||
print("prompt encoder onnx model saved to ", onnx_model_path)
|
||||
|
||||
|
||||
def test_prompt_encoder_onnx(
|
||||
sam2_model: SAM2Base,
|
||||
onnx_model_path: str,
|
||||
):
|
||||
sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
|
||||
|
||||
num_labels = 1
|
||||
num_points = 5
|
||||
point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
|
||||
point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
|
||||
input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
|
||||
has_input_masks = torch.ones(1, dtype=torch.float)
|
||||
|
||||
sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
|
||||
point_coords, point_labels, input_masks, has_input_masks
|
||||
)
|
||||
|
||||
import onnxruntime # noqa: PLC0415
|
||||
|
||||
ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
|
||||
|
||||
model_inputs = ort_session.get_inputs()
|
||||
input_names = [model_inputs[i].name for i in range(len(model_inputs))]
|
||||
logger.info("input_names: %s", input_names)
|
||||
|
||||
model_outputs = ort_session.get_outputs()
|
||||
output_names = [model_outputs[i].name for i in range(len(model_outputs))]
|
||||
logger.info("output_names: %s", output_names)
|
||||
|
||||
outputs = ort_session.run(
|
||||
output_names,
|
||||
{
|
||||
"point_coords": point_coords.numpy(),
|
||||
"point_labels": point_labels.numpy(),
|
||||
"input_masks": input_masks.numpy(),
|
||||
"has_input_masks": has_input_masks.numpy(),
|
||||
},
|
||||
)
|
||||
|
||||
for i, output_name in enumerate(output_names):
|
||||
logger.info("output %s shape: %s", output_name, outputs[i].shape)
|
||||
|
||||
ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
|
||||
if (
|
||||
compare_tensors_with_tolerance(
|
||||
"sparse_embeddings",
|
||||
sparse_embeddings,
|
||||
torch.tensor(ort_sparse_embeddings),
|
||||
mismatch_percentage_tolerance=0.2,
|
||||
)
|
||||
and compare_tensors_with_tolerance(
|
||||
"dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
|
||||
)
|
||||
and compare_tensors_with_tolerance(
|
||||
"image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
|
||||
)
|
||||
):
|
||||
print(f"onnx model has been verified: {onnx_model_path}")
|
||||
else:
|
||||
print(f"onnx model verification failed: {onnx_model_path}")
|
||||
+321
@@ -0,0 +1,321 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.patches import Rectangle
|
||||
from PIL import Image
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
|
||||
from sam2_utils import load_sam2_model
|
||||
|
||||
import onnxruntime
|
||||
|
||||
|
||||
def show_mask(mask, ax, random_color=False, borders=True):
|
||||
if random_color:
|
||||
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
||||
else:
|
||||
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
||||
h, w = mask.shape[-2:]
|
||||
mask = mask.astype(np.uint8)
|
||||
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
||||
if borders:
|
||||
import cv2 # noqa: PLC0415
|
||||
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
||||
# Try to smooth contours
|
||||
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
|
||||
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
|
||||
ax.imshow(mask_image)
|
||||
|
||||
|
||||
def show_points(coords, labels, ax, marker_size=375):
|
||||
pos_points = coords[labels == 1]
|
||||
neg_points = coords[labels == 0]
|
||||
ax.scatter(
|
||||
pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
|
||||
)
|
||||
ax.scatter(
|
||||
neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
|
||||
)
|
||||
|
||||
|
||||
def show_box(box, ax):
|
||||
x0, y0 = box[0], box[1]
|
||||
w, h = box[2] - box[0], box[3] - box[1]
|
||||
ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
|
||||
|
||||
|
||||
def show_masks(
|
||||
image,
|
||||
masks,
|
||||
scores,
|
||||
point_coords=None,
|
||||
box_coords=None,
|
||||
input_labels=None,
|
||||
borders=True,
|
||||
output_image_file_prefix=None,
|
||||
image_files=None,
|
||||
):
|
||||
for i, (mask, score) in enumerate(zip(masks, scores, strict=False)):
|
||||
plt.figure(figsize=(10, 10))
|
||||
plt.imshow(image)
|
||||
show_mask(mask, plt.gca(), borders=borders)
|
||||
if point_coords is not None:
|
||||
assert input_labels is not None
|
||||
show_points(point_coords, input_labels, plt.gca())
|
||||
|
||||
if box_coords is not None:
|
||||
show_box(box_coords, plt.gca())
|
||||
|
||||
if len(scores) > 1:
|
||||
plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
|
||||
|
||||
plt.axis("off")
|
||||
if output_image_file_prefix:
|
||||
filename = f"{output_image_file_prefix}_{i}.png"
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
|
||||
if isinstance(image_files, list):
|
||||
image_files.append(filename)
|
||||
plt.show(block=False)
|
||||
plt.close()
|
||||
|
||||
|
||||
def get_predictor(
|
||||
sam2_dir: str,
|
||||
device: str | torch.device,
|
||||
dtype: torch.dtype,
|
||||
model_type="sam2_hiera_large",
|
||||
engine="torch",
|
||||
image_encoder_onnx_path: str = "",
|
||||
image_decoder_onnx_path: str = "",
|
||||
image_decoder_multi_onnx_path: str = "",
|
||||
provider: str = "CUDAExecutionProvider",
|
||||
):
|
||||
sam2_model = load_sam2_model(sam2_dir, model_type, device=device)
|
||||
if engine == "torch":
|
||||
predictor = SAM2ImagePredictor(sam2_model)
|
||||
else:
|
||||
predictor = SAM2ImageOnnxPredictor(
|
||||
sam2_model,
|
||||
image_encoder_onnx_path=image_encoder_onnx_path,
|
||||
image_decoder_onnx_path=image_decoder_onnx_path,
|
||||
image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
|
||||
provider=provider,
|
||||
device=device,
|
||||
onnx_dtype=dtype,
|
||||
)
|
||||
return predictor
|
||||
|
||||
|
||||
def run_demo(
|
||||
sam2_dir: str,
|
||||
model_type: str = "sam2_hiera_large",
|
||||
engine: str = "torch",
|
||||
dtype: torch.dtype = torch.float32,
|
||||
image_encoder_onnx_path: str = "",
|
||||
image_decoder_onnx_path: str = "",
|
||||
image_decoder_multi_onnx_path: str = "",
|
||||
use_gpu: bool = True,
|
||||
enable_batch: bool = False,
|
||||
):
|
||||
if use_gpu:
|
||||
assert torch.cuda.is_available()
|
||||
assert "CUDAExecutionProvider" in onnxruntime.get_available_providers()
|
||||
provider = "CUDAExecutionProvider"
|
||||
else:
|
||||
provider = "CPUExecutionProvider"
|
||||
|
||||
device = torch.device("cuda" if use_gpu else "cpu")
|
||||
|
||||
if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8:
|
||||
# Turn on tfloat32 for Ampere GPUs.
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
np.random.seed(3)
|
||||
image = Image.open("truck.jpg")
|
||||
image = np.array(image.convert("RGB"))
|
||||
|
||||
predictor = get_predictor(
|
||||
sam2_dir,
|
||||
device,
|
||||
dtype,
|
||||
model_type,
|
||||
engine,
|
||||
image_encoder_onnx_path,
|
||||
image_decoder_onnx_path,
|
||||
image_decoder_multi_onnx_path,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
predictor.set_image(image)
|
||||
prefix = f"sam2_demo_{engine}_"
|
||||
|
||||
# The model returns masks, quality predictions for those masks,
|
||||
# and low resolution mask logits that can be passed to the next iteration of prediction.
|
||||
# With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
|
||||
# scores gives the model's own estimation of the quality of these masks.
|
||||
# For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
|
||||
# even if only a single mask is desired;
|
||||
input_point = np.array([[500, 375]])
|
||||
input_label = np.array([1])
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
multimask_output=True,
|
||||
)
|
||||
|
||||
sorted_ind = np.argsort(scores)[::-1]
|
||||
masks = masks[sorted_ind]
|
||||
scores = scores[sorted_ind]
|
||||
logits = logits[sorted_ind]
|
||||
|
||||
image_files = []
|
||||
show_masks(
|
||||
image,
|
||||
masks,
|
||||
scores,
|
||||
point_coords=input_point,
|
||||
input_labels=input_label,
|
||||
borders=True,
|
||||
output_image_file_prefix=prefix + "multimask",
|
||||
image_files=image_files,
|
||||
)
|
||||
|
||||
# Multiple points.
|
||||
input_point = np.array([[500, 375], [1125, 625]])
|
||||
input_label = np.array([1, 1])
|
||||
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
mask_input=mask_input[None, :, :],
|
||||
multimask_output=False,
|
||||
)
|
||||
show_masks(
|
||||
image,
|
||||
masks,
|
||||
scores,
|
||||
point_coords=input_point,
|
||||
input_labels=input_label,
|
||||
output_image_file_prefix=prefix + "multi_points",
|
||||
image_files=image_files,
|
||||
)
|
||||
|
||||
# Specify a window and a background point.
|
||||
input_point = np.array([[500, 375], [1125, 625]])
|
||||
input_label = np.array([1, 0])
|
||||
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
mask_input=mask_input[None, :, :],
|
||||
multimask_output=False,
|
||||
)
|
||||
show_masks(
|
||||
image,
|
||||
masks,
|
||||
scores,
|
||||
point_coords=input_point,
|
||||
input_labels=input_label,
|
||||
output_image_file_prefix=prefix + "background_point",
|
||||
image_files=image_files,
|
||||
)
|
||||
|
||||
# Take a box as input
|
||||
input_box = np.array([425, 600, 700, 875])
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=input_box[None, :],
|
||||
multimask_output=False,
|
||||
)
|
||||
show_masks(
|
||||
image,
|
||||
masks,
|
||||
scores,
|
||||
box_coords=input_box,
|
||||
output_image_file_prefix=prefix + "box",
|
||||
image_files=image_files,
|
||||
)
|
||||
|
||||
# Combining points and boxes
|
||||
input_box = np.array([425, 600, 700, 875])
|
||||
input_point = np.array([[575, 750]])
|
||||
input_label = np.array([0])
|
||||
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=input_point,
|
||||
point_labels=input_label,
|
||||
box=input_box,
|
||||
multimask_output=False,
|
||||
)
|
||||
show_masks(
|
||||
image,
|
||||
masks,
|
||||
scores,
|
||||
box_coords=input_box,
|
||||
point_coords=input_point,
|
||||
input_labels=input_label,
|
||||
output_image_file_prefix=prefix + "box_and_point",
|
||||
image_files=image_files,
|
||||
)
|
||||
|
||||
# TODO: support batched prompt inputs
|
||||
if enable_batch:
|
||||
input_boxes = np.array(
|
||||
[
|
||||
[75, 275, 1725, 850],
|
||||
[425, 600, 700, 875],
|
||||
[1375, 550, 1650, 800],
|
||||
[1240, 675, 1400, 750],
|
||||
]
|
||||
)
|
||||
masks, scores, _ = predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=input_boxes,
|
||||
multimask_output=False,
|
||||
)
|
||||
plt.figure(figsize=(10, 10))
|
||||
plt.imshow(image)
|
||||
for mask in masks:
|
||||
show_mask(mask.squeeze(0), plt.gca(), random_color=True)
|
||||
for box in input_boxes:
|
||||
show_box(box, plt.gca())
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
plt.savefig(prefix + "batch_prompt.png")
|
||||
image_files.append(prefix + "batch_prompt.png")
|
||||
return image_files
|
||||
|
||||
|
||||
def show_all_images(left_images, right_images, suffix=""):
|
||||
# Show images in two rows since most display screens are wider than they are tall.
|
||||
fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
|
||||
for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images, strict=False)):
|
||||
left_img = mpimg.imread(left_img_path)
|
||||
right_img = mpimg.imread(right_img_path)
|
||||
|
||||
axes[0, i].imshow(left_img)
|
||||
axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
|
||||
axes[0, i].axis("off")
|
||||
axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])
|
||||
|
||||
axes[1, i].imshow(right_img)
|
||||
axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
|
||||
axes[1, i].axis("off")
|
||||
axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000)
|
||||
plt.show()
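A hedged sketch of comparing the PyTorch and ONNX Runtime pipelines side by side with the helpers above (the module name sam2_demo and the onnx file names are assumptions; truck.jpg must exist in the working directory, and a CUDA device is required by default):

# Hypothetical comparison run; all paths and file names are placeholders.
import torch

from sam2_demo import run_demo, show_all_images

torch_images = run_demo("/path/to/segment-anything-2", engine="torch", dtype=torch.float32)
ort_images = run_demo(
    "/path/to/segment-anything-2",
    engine="ort",
    dtype=torch.float32,
    image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",
    image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",
    image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",
)
show_all_images(torch_images, ort_images, suffix="_torch_vs_ort")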
|
||||
+279
@@ -0,0 +1,279 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
from sam2_utils import decoder_shape_dict, encoder_shape_dict
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_ort_session(
|
||||
onnx_path: str,
|
||||
session_options=None,
|
||||
provider="CUDAExecutionProvider",
|
||||
enable_cuda_graph=False,
|
||||
use_tf32=True,
|
||||
) -> InferenceSession:
|
||||
if provider == "CUDAExecutionProvider":
|
||||
device_id = torch.cuda.current_device()
|
||||
provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph)
|
||||
provider_options["use_tf32"] = int(use_tf32)
|
||||
providers = [(provider, provider_options), "CPUExecutionProvider"]
|
||||
else:
|
||||
providers = ["CPUExecutionProvider"]
|
||||
logger.info("Using providers: %s", providers)
|
||||
return InferenceSession(onnx_path, session_options, providers=providers)
|
||||
|
||||
|
||||
def create_session(
|
||||
onnx_path: str,
|
||||
session_options=None,
|
||||
provider="CUDAExecutionProvider",
|
||||
device: str | torch.device = "cuda",
|
||||
enable_cuda_graph=False,
|
||||
) -> CudaSession:
|
||||
ort_session = create_ort_session(
|
||||
onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True
|
||||
)
|
||||
cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph)
|
||||
return cuda_session
|
||||
|
||||
|
||||
class SAM2ImageOnnxPredictor(SAM2ImagePredictor):
|
||||
def __init__(
|
||||
self,
|
||||
sam_model: SAM2Base,
|
||||
image_encoder_onnx_path: str = "",
|
||||
image_decoder_onnx_path: str = "",
|
||||
image_decoder_multi_onnx_path: str = "",
|
||||
provider: str = "CUDAExecutionProvider",
|
||||
device: str | torch.device = "cuda",
|
||||
onnx_dtype: torch.dtype = torch.float32,
|
||||
mask_threshold=0.0,
|
||||
max_hole_area=0.0,
|
||||
max_sprinkle_area=0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Uses SAM-2 to compute the image embedding for an image, and then allows mask prediction given prompts.
|
||||
|
||||
Arguments:
|
||||
sam_model (SAM2Base): The model to use for mask prediction.
|
||||
image_encoder_onnx_path, image_decoder_onnx_path, image_decoder_multi_onnx_path (str): Paths of the exported encoder and decoder onnx models.
|
||||
onnx_dtype (torch.dtype): The data type to use for ONNX inputs.
|
||||
mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0.
|
||||
max_hole_area (float): If max_hole_area > 0, we fill small holes in up to
|
||||
the maximum area of max_hole_area in low_res_masks.
|
||||
max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to
|
||||
the maximum area of max_sprinkle_area in low_res_masks.
|
||||
"""
|
||||
super().__init__(
|
||||
sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area
|
||||
)
|
||||
|
||||
logger.debug("self.device=%s, device=%s", self.device, device)
|
||||
|
||||
# This model is exported by image_encoder.py.
|
||||
self.encoder_session = create_session(
|
||||
image_encoder_onnx_path,
|
||||
session_options=None,
|
||||
provider=provider,
|
||||
device=device,
|
||||
enable_cuda_graph=False,
|
||||
)
|
||||
self.onnx_dtype = onnx_dtype
|
||||
|
||||
# This model is exported by image_decoder.py. It outputs only one mask.
|
||||
self.decoder_session = create_session(
|
||||
image_decoder_onnx_path,
|
||||
session_options=None,
|
||||
provider=provider,
|
||||
device=device,
|
||||
enable_cuda_graph=False,
|
||||
)
|
||||
|
||||
# This model is exported by image_decoder.py. It outputs multiple (3) masks.
|
||||
self.decoder_session_multi_out = create_session(
|
||||
image_decoder_multi_onnx_path,
|
||||
session_options=None,
|
||||
provider=provider,
|
||||
device=device,
|
||||
enable_cuda_graph=False,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def set_image(self, image: np.ndarray | Image):
|
||||
"""
|
||||
Calculates the image embeddings for the provided image.
|
||||
|
||||
Arguments:
|
||||
image (np.ndarray or PIL Image): The input image to embed in RGB format.
|
||||
The image should be in HWC format if np.ndarray, or WHC format if PIL Image with pixel values in [0, 255].
|
||||
"""
|
||||
self.reset_predictor()
|
||||
# Transform the image to the form expected by the model
|
||||
if isinstance(image, np.ndarray):
|
||||
# For numpy array image, we assume (HxWxC) format.
|
||||
self._orig_hw = [image.shape[:2]]
|
||||
elif isinstance(image, Image):
|
||||
w, h = image.size
|
||||
self._orig_hw = [(h, w)]
|
||||
else:
|
||||
raise NotImplementedError("Image format not supported")
|
||||
|
||||
input_image = self._transforms(image)
|
||||
input_image = input_image[None, ...].to(self.device)
|
||||
|
||||
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
|
||||
f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
||||
)
|
||||
|
||||
# Computing image embeddings for the provided image
|
||||
io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
|
||||
self.encoder_session.allocate_buffers(io_shapes)
|
||||
|
||||
feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}
|
||||
|
||||
for key, value in feed_dict.items():
|
||||
logger.debug(f"{key}: {value.shape}, {value.dtype}")
|
||||
logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")
|
||||
|
||||
ort_outputs = self.encoder_session.infer(feed_dict)
|
||||
|
||||
self._features = {
|
||||
"image_embed": ort_outputs["image_embeddings"],
|
||||
"high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
|
||||
}
|
||||
self._is_image_set = True
|
||||
logging.info("Image embeddings computed.")
|
||||
|
||||
@torch.no_grad()
|
||||
def _predict(
|
||||
self,
|
||||
point_coords: torch.Tensor | None,
|
||||
point_labels: torch.Tensor | None,
|
||||
boxes: torch.Tensor | None = None,
|
||||
mask_input: torch.Tensor | None = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
img_idx: int = -1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks for the given input prompts, using the currently set image.
|
||||
Input prompts are batched torch tensors and are expected to already be
|
||||
transformed to the input frame using SAM2Transforms.
|
||||
|
||||
Arguments:
|
||||
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
|
||||
model. Each point is in (X,Y) in pixels.
|
||||
point_labels (torch.Tensor or None): A BxN array of labels for the
|
||||
point prompts. 1 indicates a foreground point and 0 indicates a
|
||||
background point.
|
||||
boxes (np.ndarray or None): A Bx4 array given a box prompt to the
|
||||
model, in XYXY format.
|
||||
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form Bx1xHxW, where
|
||||
for SAM, H=W=256. Masks returned by a previous iteration of the
|
||||
predict method do not need further transformation.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
return_logits (bool): If true, returns un-thresholded masks logits
|
||||
instead of a binary mask.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The output masks in BxCxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(torch.Tensor): An array of shape BxC containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(torch.Tensor): An array of shape BxCxHxW, where C is the number
|
||||
of masks and H=W=256. These low res logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
"""
|
||||
assert not return_logits # onnx model is exported for returning bool masks.
|
||||
|
||||
if not self._is_image_set:
|
||||
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
||||
|
||||
if point_coords is not None:
|
||||
concat_points = (point_coords, point_labels)
|
||||
else:
|
||||
concat_points = None
|
||||
|
||||
# Embed prompts
|
||||
if boxes is not None:
|
||||
box_coords = boxes.reshape(-1, 2, 2)
|
||||
box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
|
||||
box_labels = box_labels.repeat(boxes.size(0), 1)
|
||||
# we merge "boxes" and "points" into a single "concat_points" input (where
|
||||
# boxes are added at the beginning) to sam_prompt_encoder
|
||||
if concat_points is not None:
|
||||
concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
|
||||
concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
|
||||
concat_points = (concat_coords, concat_labels)
|
||||
else:
|
||||
concat_points = (box_coords, box_labels)
|
||||
|
||||
assert concat_points is not None
|
||||
num_labels = concat_points[0].shape[0]
|
||||
shape_dict = decoder_shape_dict(
|
||||
original_image_height=self._orig_hw[img_idx][0],
|
||||
original_image_width=self._orig_hw[img_idx][1],
|
||||
num_labels=num_labels,
|
||||
max_points=concat_points[0].shape[1],
|
||||
num_masks=3 if multimask_output else 1,
|
||||
)
|
||||
if multimask_output:
|
||||
decoder_session = self.decoder_session_multi_out
|
||||
else:
|
||||
decoder_session = self.decoder_session
|
||||
|
||||
decoder_session.allocate_buffers(shape_dict)
|
||||
|
||||
image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
|
||||
image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
|
||||
image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)
|
||||
|
||||
if mask_input is None:
|
||||
input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=self.onnx_dtype, device=self.device)
|
||||
has_input_masks = torch.zeros(num_labels, dtype=self.onnx_dtype, device=self.device)
|
||||
else:
|
||||
input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
|
||||
has_input_masks = torch.ones(num_labels, dtype=self.onnx_dtype, device=self.device)
|
||||
|
||||
feed_dict = {
|
||||
"image_embeddings": image_embeddings.contiguous().to(dtype=self.onnx_dtype).to(self.device),
|
||||
"image_features_0": image_features_0.contiguous().to(dtype=self.onnx_dtype).to(self.device),
|
||||
"image_features_1": image_features_1.contiguous().to(dtype=self.onnx_dtype).to(self.device),
|
||||
"point_coords": concat_points[0].to(dtype=self.onnx_dtype).to(self.device),
|
||||
"point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
|
||||
"input_masks": input_masks.to(dtype=self.onnx_dtype).to(self.device),
|
||||
"has_input_masks": has_input_masks.to(dtype=self.onnx_dtype).to(self.device),
|
||||
"original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
|
||||
}
|
||||
|
||||
for key, value in feed_dict.items():
|
||||
logger.debug(f"{key}: {value.shape}, {value.dtype}")
|
||||
logger.debug(f"decoder onnx: {self.decoder_session.ort_session._model_path}")
|
||||
|
||||
ort_outputs = decoder_session.infer(feed_dict)
|
||||
|
||||
masks = ort_outputs["masks"]
|
||||
iou_predictions = ort_outputs["iou_predictions"]
|
||||
low_res_masks = ort_outputs["low_res_masks"]
|
||||
|
||||
return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
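A minimal hedged sketch of driving SAM2ImageOnnxPredictor directly on CPU, assuming the encoder and decoder models were exported by image_encoder.py and image_decoder.py (onnx file names and paths below are placeholders):

# Hypothetical usage; a zero image is used as a stand-in for a real RGB photo.
import numpy as np
import torch

from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
from sam2_utils import load_sam2_model

sam2_model = load_sam2_model("/path/to/segment-anything-2", "sam2_hiera_large", device="cpu")
predictor = SAM2ImageOnnxPredictor(
    sam2_model,
    image_encoder_onnx_path="sam2_hiera_large_image_encoder.onnx",
    image_decoder_onnx_path="sam2_hiera_large_image_decoder.onnx",
    image_decoder_multi_onnx_path="sam2_hiera_large_image_decoder_multi.onnx",
    provider="CPUExecutionProvider",
    device="cpu",
    onnx_dtype=torch.float32,
)

image = np.zeros((1200, 1800, 3), dtype=np.uint8)  # stand-in image in HWC format
predictor.set_image(image)
masks, scores, low_res_masks = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores.shape, low_res_masks.shape)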
|
||||
+147
@@ -0,0 +1,147 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
|
||||
import torch
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.modeling.sam2_base import SAM2Base
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_model_cfg(model_type) -> str:
|
||||
assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
|
||||
if model_type == "sam2_hiera_tiny":
|
||||
model_cfg = "sam2_hiera_t.yaml"
|
||||
elif model_type == "sam2_hiera_small":
|
||||
model_cfg = "sam2_hiera_s.yaml"
|
||||
elif model_type == "sam2_hiera_base_plus":
|
||||
model_cfg = "sam2_hiera_b+.yaml"
|
||||
else:
|
||||
model_cfg = "sam2_hiera_l.yaml"
|
||||
return model_cfg
|
||||
|
||||
|
||||
def load_sam2_model(sam2_dir, model_type, device: str | torch.device = "cpu") -> SAM2Base:
|
||||
checkpoints_dir = os.path.join(sam2_dir, "checkpoints")
|
||||
sam2_config_dir = os.path.join(sam2_dir, "sam2_configs")
|
||||
if not os.path.exists(sam2_dir):
|
||||
raise FileNotFoundError(f"{sam2_dir} does not exist. Please specify --sam2_dir correctly.")
|
||||
|
||||
if not os.path.exists(checkpoints_dir):
|
||||
raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.")
|
||||
|
||||
if not os.path.exists(sam2_config_dir):
|
||||
raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.")
|
||||
|
||||
checkpoint_path = os.path.join(checkpoints_dir, f"{model_type}.pt")
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"{checkpoint_path} does not exist. Please download checkpoints under the directory.")
|
||||
|
||||
if sam2_dir not in sys.path:
|
||||
sys.path.append(sam2_dir)
|
||||
|
||||
model_cfg = _get_model_cfg(model_type)
|
||||
sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
|
||||
return sam2_model
|
||||
|
||||
|
||||
def sam2_onnx_path(output_dir, model_type, component, multimask_output=False, suffix=""):
|
||||
if component == "image_encoder":
|
||||
return os.path.join(output_dir, f"{model_type}_image_encoder{suffix}.onnx")
|
||||
elif component == "mask_decoder":
|
||||
return os.path.join(output_dir, f"{model_type}_mask_decoder{suffix}.onnx")
|
||||
elif component == "prompt_encoder":
|
||||
return os.path.join(output_dir, f"{model_type}_prompt_encoder{suffix}.onnx")
|
||||
else:
|
||||
assert component == "image_decoder"
|
||||
return os.path.join(
|
||||
output_dir, f"{model_type}_image_decoder" + ("_multi" if multimask_output else "") + f"{suffix}.onnx"
|
||||
)
|
||||
|
||||
|
||||
def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, list[int]]:
|
||||
assert height == 1024 and width == 1024, "Only 1024x1024 images are supported."
|
||||
return {
|
||||
"image": [batch_size, 3, height, width],
|
||||
"image_features_0": [batch_size, 32, height // 4, width // 4],
|
||||
"image_features_1": [batch_size, 64, height // 8, width // 8],
|
||||
"image_embeddings": [batch_size, 256, height // 16, width // 16],
|
||||
}
|
||||
|
||||
|
||||
def decoder_shape_dict(
|
||||
original_image_height: int,
|
||||
original_image_width: int,
|
||||
num_labels: int = 1,
|
||||
max_points: int = 16,
|
||||
num_masks: int = 1,
|
||||
) -> dict:
|
||||
height: int = 1024
|
||||
width: int = 1024
|
||||
return {
|
||||
"image_features_0": [1, 32, height // 4, width // 4],
|
||||
"image_features_1": [1, 64, height // 8, width // 8],
|
||||
"image_embeddings": [1, 256, height // 16, width // 16],
|
||||
"point_coords": [num_labels, max_points, 2],
|
||||
"point_labels": [num_labels, max_points],
|
||||
"input_masks": [num_labels, 1, height // 4, width // 4],
|
||||
"has_input_masks": [num_labels],
|
||||
"original_image_size": [2],
|
||||
"masks": [num_labels, num_masks, original_image_height, original_image_width],
|
||||
"iou_predictions": [num_labels, num_masks],
|
||||
"low_res_masks": [num_labels, num_masks, height // 4, width // 4],
|
||||
}
|
||||
|
||||
|
||||
def compare_tensors_with_tolerance(
|
||||
name: str,
|
||||
tensor1: torch.Tensor,
|
||||
tensor2: torch.Tensor,
|
||||
atol=5e-3,
|
||||
rtol=1e-4,
|
||||
mismatch_percentage_tolerance=0.1,
|
||||
) -> bool:
|
||||
assert tensor1.shape == tensor2.shape
|
||||
a = tensor1.clone().float()
|
||||
b = tensor2.clone().float()
|
||||
|
||||
differences = torch.abs(a - b)
|
||||
mismatch_count = (differences > (rtol * torch.max(torch.abs(a), torch.abs(b)) + atol)).sum().item()
|
||||
|
||||
total_elements = a.numel()
|
||||
mismatch_percentage = (mismatch_count / total_elements) * 100
|
||||
|
||||
passed = mismatch_percentage < mismatch_percentage_tolerance
|
||||
|
||||
log_func = logger.error if not passed else logger.info
|
||||
log_func(
|
||||
"%s: mismatched elements percentage %.2f (%d/%d). Verification %s (threshold=%.2f).",
|
||||
name,
|
||||
mismatch_percentage,
|
||||
mismatch_count,
|
||||
total_elements,
|
||||
"passed" if passed else "failed",
|
||||
mismatch_percentage_tolerance,
|
||||
)
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor:
|
||||
image = torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32).cpu()
|
||||
return image
|
||||
|
||||
|
||||
def setup_logger(verbose=True):
|
||||
if verbose:
|
||||
logging.basicConfig(format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
else:
|
||||
logging.basicConfig(format="[%(message)s")
|
||||
logging.getLogger().setLevel(logging.WARNING)
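As a short illustration of the verification criterion used by these scripts, compare_tensors_with_tolerance counts elements outside the atol/rtol band and passes only while the mismatch percentage stays below the threshold. A self-contained sketch:

# Hypothetical check of the tolerance helper defined above.
import torch

from sam2_utils import compare_tensors_with_tolerance, setup_logger

setup_logger(verbose=True)

a = torch.randn(1000)
b = a.clone()
b[:5] += 1.0  # perturb 0.5% of the elements well beyond atol/rtol

print(compare_tensors_with_tolerance("demo", a, b))                                     # False: 0.5% > default 0.1%
print(compare_tensors_with_tolerance("demo", a, b, mismatch_percentage_tolerance=1.0))  # True: 0.5% < 1%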
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.append(transformers_dir)
|
||||
+1522
File diff suppressed because it is too large
+426
@@ -0,0 +1,426 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import gc
|
||||
import importlib.util
|
||||
import time
|
||||
from statistics import mean
|
||||
|
||||
import torch
|
||||
from demo_utils import PipelineInfo
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
DiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from engine_builder import EngineType, get_engine_paths
|
||||
from pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
|
||||
"""
|
||||
Benchmark script for SDXL-Turbo with control net for engines like PyTorch or Stable Fast.
|
||||
|
||||
Setup for Stable Fast (see https://github.com/chengzeyi/stable-fast/blob/main/README.md for more info):
|
||||
git clone https://github.com/chengzeyi/stable-fast.git
|
||||
cd stable-fast
|
||||
git submodule update --init
|
||||
pip3 install torch torchvision torchaudio ninja
|
||||
pip3 install -e '.[dev,xformers,triton,transformers,diffusers]' -v
|
||||
sudo apt install libgoogle-perftools-dev
|
||||
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so
|
||||
"""
|
||||
|
||||
|
||||
def get_canny_image():
|
||||
import cv2 # noqa: PLC0415
|
||||
import numpy as np # noqa: PLC0415
|
||||
from PIL import Image # noqa: PLC0415
|
||||
|
||||
# Test Image can be downloaded from https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png
|
||||
image = Image.open("input_image_vermeer.png").convert("RGB")
|
||||
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
return Image.fromarray(image)
|
||||
|
||||
|
||||
def compile_stable_fast(pipeline, enable_cuda_graph=True):
|
||||
from sfast.compilers.stable_diffusion_pipeline_compiler import CompilationConfig, compile # noqa: PLC0415
|
||||
|
||||
config = CompilationConfig.Default()
|
||||
|
||||
if importlib.util.find_spec("xformers") is not None:
|
||||
config.enable_xformers = True
|
||||
|
||||
if importlib.util.find_spec("triton") is not None:
|
||||
config.enable_triton = True
|
||||
|
||||
config.enable_cuda_graph = enable_cuda_graph
|
||||
|
||||
pipeline = compile(pipeline, config)
|
||||
return pipeline
|
||||
|
||||
|
||||
def compile_torch(pipeline, use_nhwc=False):
|
||||
if use_nhwc:
|
||||
pipeline.unet.to(memory_format=torch.channels_last)
|
||||
|
||||
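# "reduce-overhead" mode captures CUDA graphs to cut per-step launch overhead;
# fullgraph=True requires the module to compile without graph breaks.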
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
if hasattr(pipeline, "controlnet"):
|
||||
if use_nhwc:
|
||||
pipeline.controlnet.to(memory_format=torch.channels_last)
|
||||
pipeline.controlnet = torch.compile(pipeline.controlnet, mode="reduce-overhead", fullgraph=True)
|
||||
return pipeline
|
||||
|
||||
|
||||
def load_pipeline(name, engine, use_control_net=False, use_nhwc=False, enable_cuda_graph=True):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
before_memory = torch.cuda.memory_allocated()
|
||||
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(name, subfolder="scheduler")
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
if use_control_net:
|
||||
assert "xl" in name
|
||||
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
|
||||
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
name,
|
||||
controlnet=controlnet,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
else:
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
name,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
pipeline.safety_checker = None
|
||||
|
||||
gc.collect()
|
||||
after_memory = torch.cuda.memory_allocated()
|
||||
print(f"Loaded model with {after_memory - before_memory} bytes allocated")
|
||||
|
||||
if engine == "stable_fast":
|
||||
pipeline = compile_stable_fast(pipeline, enable_cuda_graph=enable_cuda_graph)
|
||||
elif engine == "torch":
|
||||
pipeline = compile_torch(pipeline, use_nhwc=use_nhwc)
|
||||
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
return pipeline
|
||||
|
||||
|
||||
def get_prompt():
|
||||
return "little cute gremlin wearing a jacket, cinematic, vivid colors, intricate masterpiece, golden ratio, highly detailed"
|
||||
|
||||
|
||||
def load_ort_cuda_pipeline(name, engine, use_control_net=False, enable_cuda_graph=True, work_dir="."):
|
||||
version = PipelineInfo.supported_models()[name]
|
||||
guidance_scale = 0.0
|
||||
pipeline_info = PipelineInfo(
|
||||
version,
|
||||
use_vae=True,
|
||||
use_fp16_vae=True,
|
||||
do_classifier_free_guidance=(guidance_scale > 1.0),
|
||||
controlnet=["canny"] if use_control_net else [],
|
||||
)
|
||||
|
||||
engine_type = EngineType.ORT_CUDA if engine == "ort_cuda" else EngineType.ORT_TRT
|
||||
onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(
|
||||
work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
pipeline_info,
|
||||
scheduler="EulerA",
|
||||
max_batch_size=32,
|
||||
use_cuda_graph=enable_cuda_graph,
|
||||
framework_model_dir=framework_model_dir,
|
||||
output_dir=output_dir,
|
||||
engine_type=engine_type,
|
||||
)
|
||||
|
||||
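# Export ONNX models and build the ORT engines under the work directory before running inference.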
pipeline.backend.build_engines(
|
||||
engine_dir=engine_dir,
|
||||
framework_model_dir=framework_model_dir,
|
||||
onnx_dir=onnx_dir,
|
||||
device_id=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def test_ort_cuda(
|
||||
pipeline,
|
||||
batch_size=1,
|
||||
steps=4,
|
||||
control_image=None,
|
||||
warmup_runs=3,
|
||||
test_runs=10,
|
||||
seed=123,
|
||||
verbose=False,
|
||||
image_height=512,
|
||||
image_width=512,
|
||||
):
|
||||
if batch_size > 4 and pipeline.pipeline_info.version == "xl-1.0":
|
||||
pipeline.backend.enable_vae_slicing()
|
||||
|
||||
pipeline.load_resources(image_height, image_width, batch_size)
|
||||
|
||||
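# Warmup runs let the engines reach steady state (and capture CUDA graphs when enabled) before timing.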
warmup_prompt = "warm up"
|
||||
for _ in range(warmup_runs):
|
||||
images, _ = pipeline.run(
|
||||
[warmup_prompt] * batch_size,
|
||||
[""] * batch_size,
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
denoising_steps=steps,
|
||||
guidance=0.0,
|
||||
seed=seed,
|
||||
controlnet_images=[control_image],
|
||||
controlnet_scales=torch.FloatTensor([0.5]),
|
||||
output_type="image",
|
||||
)
|
||||
assert len(images) == batch_size
|
||||
|
||||
generator = torch.Generator(device="cuda")
|
||||
generator.manual_seed(seed)
|
||||
|
||||
prompt = get_prompt()
|
||||
|
||||
latency_list = []
|
||||
images = None
|
||||
for _ in range(test_runs):
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
images, _ = pipeline.run(
|
||||
[prompt] * batch_size,
|
||||
[""] * batch_size,
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
denoising_steps=steps,
|
||||
guidance=0.0,
|
||||
seed=seed,
|
||||
controlnet_images=[control_image],
|
||||
controlnet_scales=torch.FloatTensor([0.5]),
|
||||
output_type="pil",
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
seconds = time.perf_counter() - start_time
|
||||
latency_list.append(seconds)
|
||||
|
||||
if verbose:
|
||||
print(latency_list)
|
||||
|
||||
return images, latency_list
|
||||
|
||||
|
||||
def test(pipeline, batch_size=1, steps=4, control_image=None, warmup_runs=3, test_runs=10, seed=123, verbose=False):
|
||||
control_net_args = {}
|
||||
if hasattr(pipeline, "controlnet"):
|
||||
control_net_args = {
|
||||
"image": control_image,
|
||||
"controlnet_conditioning_scale": 0.5,
|
||||
}
|
||||
|
||||
warmup_prompt = "warm up"
|
||||
for _ in range(warmup_runs):
|
||||
images = pipeline(
|
||||
prompt=warmup_prompt,
|
||||
num_inference_steps=steps,
|
||||
num_images_per_prompt=batch_size,
|
||||
guidance_scale=0.0,
|
||||
**control_net_args,
|
||||
).images
|
||||
assert len(images) == batch_size
|
||||
|
||||
generator = torch.Generator(device="cuda")
|
||||
generator.manual_seed(seed)
|
||||
|
||||
prompt = get_prompt()
|
||||
|
||||
latency_list = []
|
||||
images = None
|
||||
for _ in range(test_runs):
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=steps,
|
||||
num_images_per_prompt=batch_size,
|
||||
guidance_scale=0.0,
|
||||
generator=generator,
|
||||
**control_net_args,
|
||||
).images
|
||||
torch.cuda.synchronize()
|
||||
seconds = time.perf_counter() - start_time
|
||||
latency_list.append(seconds)
|
||||
|
||||
if verbose:
|
||||
print(latency_list)
|
||||
|
||||
return images, latency_list
|
||||
|
||||
|
||||
def arguments():
|
||||
import argparse # noqa: PLC0415
|
||||
|
||||
parser = argparse.ArgumentParser(description="Benchmark Stable Diffusion pipeline (optional control net for SDXL)")
|
||||
parser.add_argument(
|
||||
"--engine",
|
||||
type=str,
|
||||
default="torch",
|
||||
choices=["torch", "stable_fast", "ort_cuda", "ort_trt"],
|
||||
help="Backend engine: torch, stable_fast or ort_cuda",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
choices=list(PipelineInfo.supported_models().keys()),
|
||||
default="stabilityai/sdxl-turbo",
|
||||
help="Stable diffusion model name. Default is stabilityai/sdxl-turbo",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--work-dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="working directory for ort_cuda or ort_trt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_control_net",
|
||||
action="store_true",
|
||||
help="Use control net diffusers/controlnet-canny-sdxl-1.0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Denoising steps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--warmup_runs",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of warmup runs before measurement",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_nhwc",
|
||||
action="store_true",
|
||||
help="use channel last format for torch compile",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--enable_cuda_graph",
|
||||
action="store_true",
|
||||
help="enable cuda graph for stable fast",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="print more information",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = arguments()
|
||||
|
||||
with torch.no_grad():
|
||||
if args.engine == "ort_cuda":
|
||||
pipeline = load_ort_cuda_pipeline(
|
||||
args.name,
|
||||
args.engine,
|
||||
use_control_net=args.use_control_net,
|
||||
enable_cuda_graph=args.enable_cuda_graph,
|
||||
work_dir=args.work_dir,
|
||||
)
|
||||
else:
|
||||
pipeline = load_pipeline(
|
||||
args.name,
|
||||
args.engine,
|
||||
use_control_net=args.use_control_net,
|
||||
use_nhwc=args.use_nhwc,
|
||||
enable_cuda_graph=args.enable_cuda_graph,
|
||||
)
|
||||
|
||||
canny_image = get_canny_image()
|
||||
|
||||
if args.engine in ("ort_cuda", "ort_trt"):
|
||||
images, latency_list = test_ort_cuda(
|
||||
pipeline,
|
||||
args.batch_size,
|
||||
args.steps,
|
||||
control_image=canny_image,
|
||||
warmup_runs=args.warmup_runs,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
elif args.engine == "stable_fast":
|
||||
from sfast.utils.compute_precision import low_compute_precision # noqa: PLC0415
|
||||
|
||||
with low_compute_precision():
|
||||
images, latency_list = test(
|
||||
pipeline,
|
||||
args.batch_size,
|
||||
args.steps,
|
||||
control_image=canny_image,
|
||||
warmup_runs=args.warmup_runs,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
else:
|
||||
images, latency_list = test(
|
||||
pipeline,
|
||||
args.batch_size,
|
||||
args.steps,
|
||||
control_image=canny_image,
|
||||
warmup_runs=args.warmup_runs,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
||||
# Save the first output image to inspect the result.
|
||||
if images:
|
||||
images[0].save(
|
||||
f"{args.engine}_{args.name.replace('/', '_')}_{args.batch_size}_{args.steps}_c{int(args.use_control_net)}.png"
|
||||
)
|
||||
|
||||
result = {
|
||||
"engine": args.engine,
|
||||
"batch_size": args.batch_size,
|
||||
"steps": args.steps,
|
||||
"control_net": args.use_control_net,
|
||||
"nhwc": args.use_nhwc,
|
||||
"enable_cuda_graph": args.enable_cuda_graph,
|
||||
"average_latency_in_ms": mean(latency_list) * 1000,
|
||||
}
|
||||
print(result)
|
||||
|
||||
|
||||
main()
|
||||
+102
@@ -0,0 +1,102 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
# Modified from TensorRT demo diffusion, which has the following license:
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import coloredlogs
|
||||
from cuda import cudart
|
||||
from demo_utils import (
|
||||
add_controlnet_arguments,
|
||||
arg_parser,
|
||||
get_metadata,
|
||||
load_pipelines,
|
||||
parse_arguments,
|
||||
process_controlnet_arguments,
|
||||
repeat_prompt,
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
controlnet_images, controlnet_scale = process_controlnet_arguments(args)
|
||||
|
||||
pipeline, refiner = load_pipelines(args)
|
||||
assert refiner is None
|
||||
|
||||
prompt, negative_prompt = repeat_prompt(args)
|
||||
batch_size = len(prompt)
|
||||
pipeline.load_resources(args.height, args.width, batch_size)
|
||||
|
||||
def run_inference(warmup=False):
|
||||
return pipeline.run(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
args.height,
|
||||
args.width,
|
||||
denoising_steps=args.denoising_steps,
|
||||
guidance=args.guidance,
|
||||
seed=args.seed,
|
||||
controlnet_images=controlnet_images,
|
||||
controlnet_scales=controlnet_scale,
|
||||
show_latency=not warmup,
|
||||
output_type="pil",
|
||||
deterministic=args.deterministic,
|
||||
)
|
||||
|
||||
if not args.disable_cuda_graph:
|
||||
# inference once to get cuda graph
|
||||
_, _ = run_inference(warmup=True)
|
||||
|
||||
print("[I] Warming up ..")
|
||||
for _ in range(args.num_warmup_runs):
|
||||
_, _ = run_inference(warmup=True)
|
||||
|
||||
print("[I] Running StableDiffusion pipeline")
|
||||
if args.nvtx_profile:
|
||||
cudart.cudaProfilerStart()
|
||||
images, perf_data = run_inference(warmup=False)
|
||||
if args.nvtx_profile:
|
||||
cudart.cudaProfilerStop()
|
||||
|
||||
metadata = get_metadata(args, False)
|
||||
metadata.update(pipeline.metadata())
|
||||
if perf_data:
|
||||
metadata.update(perf_data)
|
||||
metadata["images"] = len(images)
|
||||
print(metadata)
|
||||
pipeline.save_images(images, prompt, negative_prompt, metadata)
|
||||
|
||||
pipeline.teardown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
|
||||
|
||||
parser = arg_parser("Options for Stable Diffusion Demo")
|
||||
add_controlnet_arguments(parser)
|
||||
args = parse_arguments(is_xl=False, parser=parser)
|
||||
|
||||
if args.user_compute_stream:
|
||||
import torch
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
main(args)
|
||||
else:
|
||||
main(args)
|
||||
+268
@@ -0,0 +1,268 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
# Modified from TensorRT demo diffusion, which has the following license:
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import coloredlogs
|
||||
from cuda import cudart
|
||||
from demo_utils import (
|
||||
add_controlnet_arguments,
|
||||
arg_parser,
|
||||
get_metadata,
|
||||
load_pipelines,
|
||||
parse_arguments,
|
||||
process_controlnet_arguments,
|
||||
repeat_prompt,
|
||||
)
|
||||
|
||||
|
||||
def run_pipelines(
|
||||
args, base, refiner, prompt, negative_prompt, controlnet_image=None, controlnet_scale=None, is_warm_up=False
|
||||
):
|
||||
image_height = args.height
|
||||
image_width = args.width
|
||||
batch_size = len(prompt)
|
||||
base.load_resources(image_height, image_width, batch_size)
|
||||
if refiner:
|
||||
refiner.load_resources(image_height, image_width, batch_size)
|
||||
|
||||
def run_base_and_refiner(warmup=False):
|
||||
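# The base pipeline emits latents when a refiner follows; otherwise it decodes to PIL images directly.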
images, base_perf = base.run(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image_height,
|
||||
image_width,
|
||||
denoising_steps=args.denoising_steps,
|
||||
guidance=args.guidance,
|
||||
seed=args.seed,
|
||||
controlnet_images=controlnet_image,
|
||||
controlnet_scales=controlnet_scale,
|
||||
show_latency=not warmup,
|
||||
output_type="latent" if refiner else "pil",
|
||||
)
|
||||
if refiner is None:
|
||||
return images, base_perf
|
||||
|
||||
# Use same seed in base and refiner.
|
||||
seed = base.get_current_seed()
|
||||
|
||||
images, refiner_perf = refiner.run(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image_height,
|
||||
image_width,
|
||||
denoising_steps=args.refiner_denoising_steps,
|
||||
image=images,
|
||||
strength=args.strength,
|
||||
guidance=args.refiner_guidance,
|
||||
seed=seed,
|
||||
show_latency=not warmup,
|
||||
)
|
||||
|
||||
perf_data = None
|
||||
if base_perf and refiner_perf:
|
||||
perf_data = {"latency": base_perf["latency"] + refiner_perf["latency"]}
|
||||
perf_data.update({"base." + key: val for key, val in base_perf.items()})
|
||||
perf_data.update({"refiner." + key: val for key, val in refiner_perf.items()})
|
||||
|
||||
return images, perf_data
|
||||
|
||||
if not args.disable_cuda_graph:
|
||||
# inference once to get cuda graph
|
||||
_, _ = run_base_and_refiner(warmup=True)
|
||||
|
||||
if args.num_warmup_runs > 0:
|
||||
print("[I] Warming up ..")
|
||||
for _ in range(args.num_warmup_runs):
|
||||
_, _ = run_base_and_refiner(warmup=True)
|
||||
|
||||
if is_warm_up:
|
||||
return
|
||||
|
||||
print("[I] Running StableDiffusion XL pipeline")
|
||||
if args.nvtx_profile:
|
||||
cudart.cudaProfilerStart()
|
||||
images, perf_data = run_base_and_refiner(warmup=False)
|
||||
if args.nvtx_profile:
|
||||
cudart.cudaProfilerStop()
|
||||
|
||||
if refiner:
|
||||
print("|----------------|--------------|")
|
||||
print("| {:^14} | {:>9.2f} ms |".format("e2e", perf_data["latency"]))
|
||||
print("|----------------|--------------|")
|
||||
|
||||
metadata = get_metadata(args, True)
|
||||
metadata.update({"base." + key: val for key, val in base.metadata().items()})
|
||||
if refiner:
|
||||
metadata.update({"refiner." + key: val for key, val in refiner.metadata().items()})
|
||||
if perf_data:
|
||||
metadata.update(perf_data)
|
||||
metadata["images"] = len(images)
|
||||
print(metadata)
|
||||
(refiner or base).save_images(images, prompt, negative_prompt, metadata)
|
||||
|
||||
|
||||
def run_demo(args):
|
||||
"""Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""
|
||||
controlnet_image, controlnet_scale = process_controlnet_arguments(args)
|
||||
prompt, negative_prompt = repeat_prompt(args)
|
||||
batch_size = len(prompt)
|
||||
base, refiner = load_pipelines(args, batch_size)
|
||||
run_pipelines(args, base, refiner, prompt, negative_prompt, controlnet_image, controlnet_scale)
|
||||
base.teardown()
|
||||
if refiner:
|
||||
refiner.teardown()
|
||||
|
||||
|
||||
def run_dynamic_shape_demo(args):
|
||||
"""
|
||||
Run demo of generating images with different settings with ORT CUDA provider.
|
||||
Try "python demo_txt2img_xl.py --max-cuda-graphs 3 --user-compute-stream" to see the effect of multiple CUDA graphs.
|
||||
"""
|
||||
args.engine = "ORT_CUDA"
|
||||
base, refiner = load_pipelines(args, 1)
|
||||
|
||||
prompts = [
|
||||
"starry night over Golden Gate Bridge by van gogh",
|
||||
"beautiful photograph of Mt. Fuji during cherry blossom",
|
||||
"little cute gremlin sitting on a bed, cinematic",
|
||||
"cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
|
||||
"beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
|
||||
"blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
|
||||
"An astronaut riding a rainbow unicorn, cinematic, dramatic",
|
||||
"close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm",
|
||||
]
|
||||
|
||||
# batch size, height, width, scheduler, steps, prompt, seed, guidance, refiner scheduler, refiner steps, refiner strength
|
||||
configs = [
|
||||
(1, 832, 1216, "UniPC", 8, prompts[0], None, 5.0, "UniPC", 10, 0.3),
|
||||
(1, 1024, 1024, "DDIM", 24, prompts[1], None, 5.0, "DDIM", 30, 0.3),
|
||||
(1, 1216, 832, "EulerA", 16, prompts[2], 1716921396712843, 5.0, "EulerA", 10, 0.3),
|
||||
(1, 1344, 768, "EulerA", 24, prompts[3], 123698071912362, 5.0, "EulerA", 20, 0.3),
|
||||
(2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712, 5.0, "UniPC", 10, 0.3),
|
||||
(2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906, 5.0, "UniPC", 20, 0.3),
|
||||
]
|
||||
|
||||
# When testing LCM, the refiner is disabled, so its settings are not used.
|
||||
if args.lcm:
|
||||
configs = [
|
||||
(1, 1024, 1024, "LCM", 8, prompts[6], None, 1.0, "UniPC", 20, 0.3),
|
||||
(1, 1216, 832, "LCM", 6, prompts[7], 1337, 1.0, "UniPC", 20, 0.3),
|
||||
]
|
||||
|
||||
# Warm up each combination of (batch size, height, width) once before running the demo configurations.
|
||||
args.prompt = ["warm up"]
|
||||
args.num_warmup_runs = 1
|
||||
for batch_size, height, width, _, _, _, _, _, _, _, _ in configs:
|
||||
args.batch_size = batch_size
|
||||
args.height = height
|
||||
args.width = width
|
||||
print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}")
|
||||
prompt, negative_prompt = repeat_prompt(args)
|
||||
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True)
|
||||
|
||||
# Run pipeline on a list of prompts.
|
||||
args.num_warmup_runs = 0
|
||||
for (
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
scheduler,
|
||||
steps,
|
||||
example_prompt,
|
||||
seed,
|
||||
guidance,
|
||||
refiner_scheduler,
|
||||
refiner_denoising_steps,
|
||||
strength,
|
||||
) in configs:
|
||||
args.prompt = [example_prompt]
|
||||
args.batch_size = batch_size
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.scheduler = scheduler
|
||||
args.denoising_steps = steps
|
||||
args.seed = seed
|
||||
args.guidance = guidance
|
||||
args.refiner_scheduler = refiner_scheduler
|
||||
args.refiner_denoising_steps = refiner_denoising_steps
|
||||
args.strength = strength
|
||||
base.set_scheduler(scheduler)
|
||||
if refiner:
|
||||
refiner.set_scheduler(refiner_scheduler)
|
||||
prompt, negative_prompt = repeat_prompt(args)
|
||||
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
|
||||
|
||||
base.teardown()
|
||||
if refiner:
|
||||
refiner.teardown()
|
||||
|
||||
|
||||
def run_turbo_demo(args):
|
||||
"""Run demo of generating images with test prompts with ORT CUDA provider."""
|
||||
args.engine = "ORT_CUDA"
|
||||
base, refiner = load_pipelines(args, 1)
|
||||
|
||||
from datasets import load_dataset # noqa: PLC0415
|
||||
|
||||
dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts")
|
||||
num_rows = dataset["test"].num_rows
|
||||
batch_size = args.batch_size
|
||||
num_batch = int(num_rows / batch_size)
|
||||
args.batch_size = 1
|
||||
for i in range(num_batch):
|
||||
args.prompt = [dataset["test"][j]["Prompt"] for j in range(i * batch_size, (i + 1) * batch_size)]
|
||||
base.set_scheduler(args.scheduler)
|
||||
if refiner:
|
||||
refiner.set_scheduler(args.refiner_scheduler)
|
||||
prompt, negative_prompt = repeat_prompt(args)
|
||||
run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
|
||||
|
||||
base.teardown()
|
||||
if refiner:
|
||||
refiner.teardown()
|
||||
|
||||
|
||||
def main(args):
|
||||
no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
|
||||
if no_prompt:
|
||||
if args.version == "xl-turbo":
|
||||
run_turbo_demo(args)
|
||||
else:
|
||||
run_dynamic_shape_demo(args)
|
||||
else:
|
||||
run_demo(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
|
||||
|
||||
parser = arg_parser("Options for Stable Diffusion XL Demo")
|
||||
add_controlnet_arguments(parser)
|
||||
args = parse_arguments(is_xl=True, parser=parser)
|
||||
|
||||
if args.user_compute_stream:
|
||||
import torch
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
main(args)
|
||||
else:
|
||||
main(args)
|
||||
+778
@@ -0,0 +1,778 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
# Modified from TensorRT demo diffusion, which has the following license:
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from typing import Any
|
||||
|
||||
import controlnet_aux
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from cuda import cudart
|
||||
from diffusion_models import PipelineInfo
|
||||
from engine_builder import EngineType, get_engine_paths, get_engine_type
|
||||
from PIL import Image
|
||||
from pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
|
||||
|
||||
class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
|
||||
pass
|
||||
|
||||
|
||||
def arg_parser(description: str):
|
||||
return argparse.ArgumentParser(
|
||||
description=description,
|
||||
formatter_class=RawTextArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
|
||||
def set_default_arguments(args):
|
||||
# set default value for some arguments if not provided
|
||||
if args.height is None:
|
||||
args.height = PipelineInfo.default_resolution(args.version)
|
||||
|
||||
if args.width is None:
|
||||
args.width = PipelineInfo.default_resolution(args.version)
|
||||
|
||||
is_lcm = (args.version == "xl-1.0" and args.lcm) or "lcm" in args.lora_weights
|
||||
is_turbo = args.version in ["sd-turbo", "xl-turbo"]
|
||||
if args.denoising_steps is None:
|
||||
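# Default steps: 4 for turbo models, 8 for LCM, 30 for SDXL 1.0, 50 otherwise.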
args.denoising_steps = 4 if is_turbo else 8 if is_lcm else (30 if args.version == "xl-1.0" else 50)
|
||||
|
||||
if args.scheduler is None:
|
||||
args.scheduler = "LCM" if (is_lcm or is_turbo) else ("EulerA" if args.version == "xl-1.0" else "DDIM")
|
||||
|
||||
if args.guidance is None:
|
||||
args.guidance = 0.0 if (is_lcm or is_turbo) else (5.0 if args.version == "xl-1.0" else 7.5)
|
||||
|
||||
|
||||
def parse_arguments(is_xl: bool, parser):
|
||||
engines = ["ORT_CUDA", "ORT_TRT", "TRT", "TORCH"]
|
||||
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--engine",
|
||||
type=str,
|
||||
default=engines[0],
|
||||
choices=engines,
|
||||
help="Backend engine in {engines}. "
|
||||
"ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT",
|
||||
)
|
||||
|
||||
supported_versions = PipelineInfo.supported_versions(is_xl)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--version",
|
||||
type=str,
|
||||
default="xl-1.0" if is_xl else "1.5",
|
||||
choices=supported_versions,
|
||||
help="Version of Stable Diffusion" + (" XL." if is_xl else "."),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-y",
|
||||
"--height",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Height of image to generate (must be multiple of 8).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x", "--width", type=int, default=None, help="Height of image to generate (must be multiple of 8)."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["DDIM", "EulerA", "UniPC", "LCM"],
|
||||
help="Scheduler for diffusion process" + " of base" if is_xl else "",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-wd",
|
||||
"--work-dir",
|
||||
default=".",
|
||||
help="Root Directory to store torch or ONNX models, built engines and output images etc.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--engine-dir",
|
||||
default=None,
|
||||
help="Root Directory to store built engines or optimized ONNX models etc.",
|
||||
)
|
||||
|
||||
parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.")
|
||||
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--negative-prompt",
|
||||
nargs="*",
|
||||
default=[""],
|
||||
help="Optional negative prompt(s) to guide the image generation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2, 4, 8, 16],
|
||||
help="Number of times to repeat the prompt (batch size multiplier).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--denoising-steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of denoising steps" + (" in base." if is_xl else "."),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--guidance",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-ls", "--lora-scale", type=float, default=1, help="Scale of LoRA weights, default 1 (must between 0 and 1)"
|
||||
)
|
||||
parser.add_argument("-lw", "--lora-weights", type=str, default="", help="LoRA weights to apply in the base model")
|
||||
|
||||
if is_xl:
|
||||
parser.add_argument(
|
||||
"--lcm",
|
||||
action="store_true",
|
||||
help="Use fine-tuned latent consistency model to replace the UNet in base.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-rs",
|
||||
"--refiner-scheduler",
|
||||
type=str,
|
||||
default="EulerA",
|
||||
choices=["DDIM", "EulerA", "UniPC"],
|
||||
help="Scheduler for diffusion process of refiner.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-rg",
|
||||
"--refiner-guidance",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="Guidance scale used in refiner.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-rd",
|
||||
"--refiner-denoising-steps",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Number of denoising steps in refiner. Note that actual steps is refiner_denoising_steps * strength.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="A value between 0 and 1. The higher the value less the final image similar to the seed image.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--enable-refiner",
|
||||
action="store_true",
|
||||
help="Enable SDXL refiner to refine image from base pipeline.",
|
||||
)
|
||||
|
||||
# ONNX export
|
||||
parser.add_argument(
|
||||
"--onnx-opset",
|
||||
type=int,
|
||||
default=None,
|
||||
choices=range(14, 18),
|
||||
help="Select ONNX opset version to target for exported models.",
|
||||
)
|
||||
|
||||
# Engine build options.
|
||||
parser.add_argument(
|
||||
"-db",
|
||||
"--build-dynamic-batch",
|
||||
action="store_true",
|
||||
help="Build TensorRT engines to support dynamic batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ds",
|
||||
"--build-dynamic-shape",
|
||||
action="store_true",
|
||||
help="Build TensorRT engines to support dynamic image sizes.",
|
||||
)
|
||||
parser.add_argument("--max-batch-size", type=int, default=None, choices=[1, 2, 4, 8, 16, 32], help="Max batch size")
|
||||
|
||||
# Inference related options
|
||||
parser.add_argument(
|
||||
"-nw", "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance."
|
||||
)
|
||||
parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.")
|
||||
parser.add_argument("--deterministic", action="store_true", help="use deterministic algorithms.")
|
||||
parser.add_argument("-dc", "--disable-cuda-graph", action="store_true", help="Disable cuda graph.")
|
||||
|
||||
parser.add_argument("--framework-model-dir", default=None, help="framework model directory")
|
||||
|
||||
group = parser.add_argument_group("Options for ORT_CUDA engine only")
|
||||
group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")
|
||||
group.add_argument("--max-cuda-graphs", type=int, default=1, help="Max number of cuda graphs to use. Default 1.")
|
||||
group.add_argument("--user-compute-stream", action="store_true", help="Use user compute stream.")
|
||||
|
||||
# TensorRT only options
|
||||
group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only")
|
||||
group.add_argument(
|
||||
"--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
set_default_arguments(args)
|
||||
|
||||
# Validate image dimensions
|
||||
if args.height % 64 != 0 or args.width % 64 != 0:
|
||||
raise ValueError(
|
||||
f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}."
|
||||
)
|
||||
|
||||
if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
|
||||
print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
|
||||
args.disable_cuda_graph = True
|
||||
|
||||
if args.onnx_opset is None:
|
||||
args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17
|
||||
|
||||
if is_xl:
|
||||
if args.version == "xl-turbo":
|
||||
if args.lcm:
|
||||
print("[I] sdxl-turbo cannot use with LCM.")
|
||||
args.lcm = False
|
||||
|
||||
assert args.strength > 0.0 and args.strength < 1.0
|
||||
|
||||
assert not (args.lcm and args.lora_weights), "It is not supported to use both the LCM UNet and LoRA together."
|
||||
|
||||
if args.scheduler == "LCM":
|
||||
if args.guidance > 2.0:
|
||||
print("[I] Use --guidance=0.0 (no more than 2.0) when LCM scheduler is used.")
|
||||
args.guidance = 0.0
|
||||
if args.denoising_steps > 16:
|
||||
print("[I] Use --denoising_steps=8 (no more than 16) when LCM scheduler is used.")
|
||||
args.denoising_steps = 8
|
||||
|
||||
print(args)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def max_batch(args):
|
||||
if args.max_batch_size:
|
||||
max_batch_size = args.max_batch_size
|
||||
else:
|
||||
do_classifier_free_guidance = args.guidance > 1.0
|
||||
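# Classifier-free guidance runs conditional and unconditional passes together,
# doubling the effective batch, so the allowed batch size is halved.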
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
max_batch_size = 32 // batch_multiplier
|
||||
if args.engine != "ORT_CUDA" and (args.build_dynamic_shape or args.height > 512 or args.width > 512):
|
||||
max_batch_size = 8 // batch_multiplier
|
||||
return max_batch_size
|
||||
|
||||
|
||||
def get_metadata(args, is_xl: bool = False) -> dict[str, Any]:
|
||||
metadata = {
|
||||
"command": " ".join(['"' + x + '"' if " " in x else x for x in sys.argv]),
|
||||
"args.prompt": args.prompt,
|
||||
"args.negative_prompt": args.negative_prompt,
|
||||
"args.batch_size": args.batch_size,
|
||||
"height": args.height,
|
||||
"width": args.width,
|
||||
"cuda_graph": not args.disable_cuda_graph,
|
||||
"vae_slicing": args.enable_vae_slicing,
|
||||
"engine": args.engine,
|
||||
}
|
||||
|
||||
if args.lora_weights:
|
||||
metadata["lora_weights"] = args.lora_weights
|
||||
metadata["lora_scale"] = args.lora_scale
|
||||
|
||||
if args.controlnet_type:
|
||||
metadata["controlnet_type"] = args.controlnet_type
|
||||
metadata["controlnet_scale"] = args.controlnet_scale
|
||||
|
||||
if is_xl and args.enable_refiner:
|
||||
metadata["base.scheduler"] = args.scheduler
|
||||
metadata["base.denoising_steps"] = args.denoising_steps
|
||||
metadata["base.guidance"] = args.guidance
|
||||
metadata["refiner.strength"] = args.strength
|
||||
metadata["refiner.scheduler"] = args.refiner_scheduler
|
||||
metadata["refiner.denoising_steps"] = args.refiner_denoising_steps
|
||||
metadata["refiner.guidance"] = args.refiner_guidance
|
||||
else:
|
||||
metadata["scheduler"] = args.scheduler
|
||||
metadata["denoising_steps"] = args.denoising_steps
|
||||
metadata["guidance"] = args.guidance
|
||||
|
||||
# Version of installed python packages
|
||||
packages = ""
|
||||
for name in [
|
||||
"onnxruntime-gpu",
|
||||
"torch",
|
||||
"tensorrt",
|
||||
"transformers",
|
||||
"diffusers",
|
||||
"onnx",
|
||||
"onnx-graphsurgeon",
|
||||
"polygraphy",
|
||||
"controlnet_aux",
|
||||
]:
|
||||
try:
|
||||
packages += (" " if packages else "") + f"{name}=={version(name)}"
|
||||
except PackageNotFoundError:
|
||||
continue
|
||||
metadata["packages"] = packages
|
||||
metadata["device"] = torch.cuda.get_device_name()
|
||||
metadata["torch.version.cuda"] = torch.version.cuda
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def repeat_prompt(args):
|
||||
if not isinstance(args.prompt, list):
|
||||
raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")
|
||||
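# Repeat the prompt list --batch-size times; a single negative prompt is later broadcast to match its length.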
prompt = args.prompt * args.batch_size
|
||||
|
||||
if not isinstance(args.negative_prompt, list):
|
||||
raise ValueError(
|
||||
f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}"
|
||||
)
|
||||
|
||||
if len(args.negative_prompt) == 1:
|
||||
negative_prompt = args.negative_prompt * len(prompt)
|
||||
else:
|
||||
negative_prompt = args.negative_prompt
|
||||
|
||||
return prompt, negative_prompt
|
||||
|
||||
|
||||
def initialize_pipeline(
|
||||
version="xl-turbo",
|
||||
is_refiner: bool = False,
|
||||
is_inpaint: bool = False,
|
||||
engine_type=EngineType.ORT_CUDA,
|
||||
work_dir: str = ".",
|
||||
engine_dir=None,
|
||||
onnx_opset: int = 17,
|
||||
scheduler="EulerA",
|
||||
height=512,
|
||||
width=512,
|
||||
nvtx_profile=False,
|
||||
use_cuda_graph=True,
|
||||
build_dynamic_batch=False,
|
||||
build_dynamic_shape=False,
|
||||
min_image_size: int = 512,
|
||||
max_image_size: int = 1024,
|
||||
max_batch_size: int = 16,
|
||||
opt_batch_size: int = 1,
|
||||
build_all_tactics: bool = False,
|
||||
do_classifier_free_guidance: bool = False,
|
||||
lcm: bool = False,
|
||||
controlnet=None,
|
||||
lora_weights=None,
|
||||
lora_scale: float = 1.0,
|
||||
use_fp16_vae: bool = True,
|
||||
use_vae: bool = True,
|
||||
framework_model_dir: str | None = None,
|
||||
max_cuda_graphs: int = 1,
|
||||
):
|
||||
pipeline_info = PipelineInfo(
|
||||
version,
|
||||
is_refiner=is_refiner,
|
||||
is_inpaint=is_inpaint,
|
||||
use_vae=use_vae,
|
||||
min_image_size=min_image_size,
|
||||
max_image_size=max_image_size,
|
||||
use_fp16_vae=use_fp16_vae,
|
||||
use_lcm=lcm,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
controlnet=controlnet,
|
||||
lora_weights=lora_weights,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
input_engine_dir = engine_dir
|
||||
|
||||
onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
|
||||
work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type, framework_model_dir=framework_model_dir
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
pipeline_info,
|
||||
scheduler=scheduler,
|
||||
output_dir=output_dir,
|
||||
verbose=False,
|
||||
nvtx_profile=nvtx_profile,
|
||||
max_batch_size=max_batch_size,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
framework_model_dir=framework_model_dir,
|
||||
engine_type=engine_type,
|
||||
)
|
||||
|
||||
import_engine_dir = None
|
||||
if input_engine_dir:
|
||||
if not os.path.exists(input_engine_dir):
|
||||
raise RuntimeError(f"--engine_dir directory does not exist: {input_engine_dir}")
|
||||
|
||||
# Support importing from optimized diffusers onnx pipeline
|
||||
if engine_type == EngineType.ORT_CUDA and os.path.exists(os.path.join(input_engine_dir, "model_index.json")):
|
||||
import_engine_dir = input_engine_dir
|
||||
else:
|
||||
engine_dir = input_engine_dir
|
||||
|
||||
opt_image_height = pipeline_info.default_image_size() if build_dynamic_shape else height
|
||||
opt_image_width = pipeline_info.default_image_size() if build_dynamic_shape else width
|
||||
|
||||
if engine_type == EngineType.ORT_CUDA:
|
||||
pipeline.backend.build_engines(
|
||||
engine_dir=engine_dir,
|
||||
framework_model_dir=framework_model_dir,
|
||||
onnx_dir=onnx_dir,
|
||||
tmp_dir=os.path.join(work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"),
|
||||
device_id=torch.cuda.current_device(),
|
||||
import_engine_dir=import_engine_dir,
|
||||
max_cuda_graphs=max_cuda_graphs,
|
||||
)
|
||||
elif engine_type == EngineType.ORT_TRT:
|
||||
pipeline.backend.build_engines(
|
||||
engine_dir,
|
||||
framework_model_dir,
|
||||
onnx_dir,
|
||||
onnx_opset,
|
||||
opt_image_height=opt_image_height,
|
||||
opt_image_width=opt_image_width,
|
||||
opt_batch_size=opt_batch_size,
|
||||
static_batch=not build_dynamic_batch,
|
||||
static_image_shape=not build_dynamic_shape,
|
||||
max_workspace_size=0,
|
||||
device_id=torch.cuda.current_device(),
|
||||
timing_cache=timing_cache,
|
||||
)
|
||||
elif engine_type == EngineType.TRT:
|
||||
pipeline.backend.load_engines(
|
||||
engine_dir,
|
||||
framework_model_dir,
|
||||
onnx_dir,
|
||||
onnx_opset,
|
||||
opt_batch_size=opt_batch_size,
|
||||
opt_image_height=opt_image_height,
|
||||
opt_image_width=opt_image_width,
|
||||
static_batch=not build_dynamic_batch,
|
||||
static_shape=not build_dynamic_shape,
|
||||
enable_all_tactics=build_all_tactics,
|
||||
timing_cache=timing_cache,
|
||||
)
|
||||
elif engine_type == EngineType.TORCH:
|
||||
pipeline.backend.build_engines(framework_model_dir)
|
||||
else:
|
||||
raise RuntimeError("invalid engine type")
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def load_pipelines(args, batch_size=None):
|
||||
engine_type = get_engine_type(args.engine)
|
||||
|
||||
# Register TensorRT plugins
|
||||
if engine_type == EngineType.TRT:
|
||||
from trt_utilities import init_trt_plugins # noqa: PLC0415
|
||||
|
||||
init_trt_plugins()
|
||||
|
||||
max_batch_size = max_batch(args)
|
||||
|
||||
if batch_size is None:
|
||||
assert isinstance(args.prompt, list)
|
||||
batch_size = len(args.prompt) * args.batch_size
|
||||
|
||||
if batch_size > max_batch_size:
|
||||
raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")
|
||||
|
||||
# For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size.
|
||||
# Here, we reduce the range of image sizes for TensorRT to trade off flexibility and performance.
|
||||
# This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024).
|
||||
if args.version == "xl-turbo":
|
||||
min_image_size = 512
|
||||
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
|
||||
elif args.version == "xl-1.0":
|
||||
min_image_size = 832 if args.engine != "ORT_CUDA" else 512
|
||||
max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048
|
||||
else:
|
||||
# This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768.
|
||||
min_image_size = 512 if args.engine != "ORT_CUDA" else 256
|
||||
max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
|
||||
|
||||
params = {
|
||||
"version": args.version,
|
||||
"is_refiner": False,
|
||||
"is_inpaint": False,
|
||||
"engine_type": engine_type,
|
||||
"work_dir": args.work_dir,
|
||||
"engine_dir": args.engine_dir,
|
||||
"onnx_opset": args.onnx_opset,
|
||||
"scheduler": args.scheduler,
|
||||
"height": args.height,
|
||||
"width": args.width,
|
||||
"nvtx_profile": args.nvtx_profile,
|
||||
"use_cuda_graph": not args.disable_cuda_graph,
|
||||
"build_dynamic_batch": args.build_dynamic_batch,
|
||||
"build_dynamic_shape": args.build_dynamic_shape,
|
||||
"min_image_size": min_image_size,
|
||||
"max_image_size": max_image_size,
|
||||
"max_batch_size": max_batch_size,
|
||||
"opt_batch_size": 1 if args.build_dynamic_batch else batch_size,
|
||||
"build_all_tactics": args.build_all_tactics,
|
||||
"do_classifier_free_guidance": args.guidance > 1.0,
|
||||
"controlnet": args.controlnet_type,
|
||||
"lora_weights": args.lora_weights,
|
||||
"lora_scale": args.lora_scale,
|
||||
"use_fp16_vae": "xl" in args.version,
|
||||
"use_vae": True,
|
||||
"framework_model_dir": args.framework_model_dir,
|
||||
"max_cuda_graphs": args.max_cuda_graphs,
|
||||
}
|
||||
|
||||
if "xl" in args.version:
|
||||
params["lcm"] = args.lcm
|
||||
params["use_vae"] = not args.enable_refiner
|
||||
base = initialize_pipeline(**params)
|
||||
|
||||
refiner = None
|
||||
if "xl" in args.version and args.enable_refiner:
|
||||
params["version"] = "xl-1.0" # Allow SDXL Turbo to use refiner.
|
||||
params["is_refiner"] = True
|
||||
params["scheduler"] = args.refiner_scheduler
|
||||
params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0
|
||||
params["lcm"] = False
|
||||
params["controlnet"] = None
|
||||
params["lora_weights"] = None
|
||||
params["use_vae"] = True
|
||||
params["use_fp16_vae"] = True
|
||||
refiner = initialize_pipeline(**params)
|
||||
|
||||
if engine_type == EngineType.TRT:
|
||||
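# Base and refiner TRT engines share one device memory arena sized for the larger requirement.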
max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
|
||||
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
|
||||
base.backend.activate_engines(shared_device_memory)
|
||||
if refiner:
|
||||
refiner.backend.activate_engines(shared_device_memory)
|
||||
|
||||
if engine_type == EngineType.ORT_CUDA:
|
||||
enable_vae_slicing = args.enable_vae_slicing
|
||||
if batch_size > 4 and not enable_vae_slicing and (args.height >= 1024 and args.width >= 1024):
|
||||
print(
|
||||
"Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4 and resolution >= 1024."
|
||||
)
|
||||
enable_vae_slicing = True
|
||||
if enable_vae_slicing:
|
||||
(refiner or base).backend.enable_vae_slicing()
|
||||
return base, refiner
|
||||
|
||||
|
||||
def get_depth_image(image):
|
||||
"""
|
||||
Create depth map for SDXL depth control net.
|
||||
"""
|
||||
from transformers import DPTFeatureExtractor, DPTForDepthEstimation # noqa: PLC0415
|
||||
|
||||
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
|
||||
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
|
||||
|
||||
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
|
||||
with torch.no_grad(), torch.autocast("cuda"):
|
||||
depth_map = depth_estimator(image).predicted_depth
|
||||
|
||||
# The depth map is 384x384 by default, here we interpolate to the default output size.
|
||||
# Note that it will be resized to the output image size later. The size here could be changed to avoid interpolating twice.
|
||||
depth_map = torch.nn.functional.interpolate(
|
||||
depth_map.unsqueeze(1),
|
||||
size=(1024, 1024),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
|
||||
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
|
||||
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
|
||||
image = torch.cat([depth_map] * 3, dim=1)
|
||||
|
||||
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
|
||||
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
|
||||
return image
|
||||
|
||||
|
||||
def get_canny_image(image) -> Image.Image:
|
||||
"""
|
||||
Create canny image for SDXL control net.
|
||||
"""
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
image = Image.fromarray(image)
|
||||
return image
|
||||
|
||||
|
||||
def process_controlnet_images_xl(args) -> list[Image.Image]:
|
||||
"""
|
||||
Process control image for SDXL control net.
|
||||
"""
|
||||
assert len(args.controlnet_image) == 1
|
||||
image = Image.open(args.controlnet_image[0]).convert("RGB")
|
||||
|
||||
controlnet_images = []
|
||||
if args.controlnet_type[0] == "canny":
|
||||
controlnet_images.append(get_canny_image(image))
|
||||
elif args.controlnet_type[0] == "depth":
|
||||
controlnet_images.append(get_depth_image(image))
|
||||
else:
|
||||
raise ValueError(f"This controlnet type is not supported for SDXL or Turbo: {args.controlnet_type}.")
|
||||
|
||||
return controlnet_images
|
||||
|
||||
|
||||
def add_controlnet_arguments(parser, is_xl: bool = False):
|
||||
"""
|
||||
Add control net related arguments.
|
||||
"""
|
||||
group = parser.add_argument_group("Options for ControlNet (supports 1.5, sd-turbo, xl-turbo, xl-1.0).")
|
||||
|
||||
group.add_argument(
|
||||
"-ci",
|
||||
"--controlnet-image",
|
||||
nargs="*",
|
||||
type=str,
|
||||
default=[],
|
||||
help="Path to the input regular RGB image/images for controlnet",
|
||||
)
|
||||
group.add_argument(
|
||||
"-ct",
|
||||
"--controlnet-type",
|
||||
nargs="*",
|
||||
type=str,
|
||||
default=[],
|
||||
choices=list(PipelineInfo.supported_controlnet("xl-1.0" if is_xl else "1.5").keys()),
|
||||
help="A list of controlnet type",
|
||||
)
|
||||
group.add_argument(
|
||||
"-cs",
|
||||
"--controlnet-scale",
|
||||
nargs="*",
|
||||
type=float,
|
||||
default=[],
|
||||
help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original unet. Default is 0.5 for SDXL, or 1.0 for SD 1.5",
|
||||
)
|
||||
|
||||
|
||||
def process_controlnet_image(controlnet_type: str, image: Image.Image, height, width):
|
||||
"""
|
||||
Process control images of control net v1.1 for Stable Diffusion 1.5.
|
||||
"""
|
||||
control_image = None
|
||||
shape = (height, width)
|
||||
image = image.convert("RGB")
|
||||
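# Each branch runs the matching controlnet_aux annotator and resizes the result to the requested shape.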
if controlnet_type == "canny":
|
||||
canny_image = controlnet_aux.CannyDetector()(image)
|
||||
control_image = canny_image.resize(shape)
|
||||
elif controlnet_type == "normalbae":
|
||||
normal_image = controlnet_aux.NormalBaeDetector.from_pretrained("lllyasviel/Annotators")(image)
|
||||
control_image = normal_image.resize(shape)
|
||||
elif controlnet_type == "depth":
|
||||
depth_image = controlnet_aux.LeresDetector.from_pretrained("lllyasviel/Annotators")(image)
|
||||
control_image = depth_image.resize(shape)
|
||||
elif controlnet_type == "mlsd":
|
||||
mlsd_image = controlnet_aux.MLSDdetector.from_pretrained("lllyasviel/Annotators")(image)
|
||||
control_image = mlsd_image.resize(shape)
|
||||
elif controlnet_type == "openpose":
|
||||
openpose_image = controlnet_aux.OpenposeDetector.from_pretrained("lllyasviel/Annotators")(image)
|
||||
control_image = openpose_image.resize(shape)
|
||||
elif controlnet_type == "scribble":
|
||||
scribble_image = controlnet_aux.HEDdetector.from_pretrained("lllyasviel/Annotators")(image, scribble=True)
|
||||
control_image = scribble_image.resize(shape)
|
||||
elif controlnet_type == "seg":
|
||||
seg_image = controlnet_aux.SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")(
|
||||
image
|
||||
)
|
||||
control_image = seg_image.resize(shape)
|
||||
else:
|
||||
raise ValueError(f"There is no demo image of this controlnet_type: {controlnet_type}")
|
||||
return control_image
|
||||
|
||||
|
||||
def process_controlnet_arguments(args):
|
||||
"""
|
||||
Process control net arguments, and returns a list of control images and a tensor of control net scales.
|
||||
"""
|
||||
assert isinstance(args.controlnet_type, list)
|
||||
assert isinstance(args.controlnet_scale, list)
|
||||
assert isinstance(args.controlnet_image, list)
|
||||
|
||||
if len(args.controlnet_image) != len(args.controlnet_type):
|
||||
raise ValueError(
|
||||
f"Numbers of controlnet_image {len(args.controlnet_image)} should be equal to number of controlnet_type {len(args.controlnet_type)}."
|
||||
)
|
||||
|
||||
if len(args.controlnet_type) == 0:
|
||||
return None, None
|
||||
|
||||
if args.version not in ["1.5", "xl-1.0", "xl-turbo", "sd-turbo"]:
|
||||
raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.")
|
||||
|
||||
is_xl = "xl" in args.version
|
||||
if is_xl and len(args.controlnet_type) > 1:
|
||||
raise ValueError("This demo only support one ControlNet for Stable Diffusion XL or Turbo.")
|
||||
|
||||
if len(args.controlnet_scale) == 0:
|
||||
args.controlnet_scale = [0.5 if is_xl else 1.0] * len(args.controlnet_type)
|
||||
elif len(args.controlnet_type) != len(args.controlnet_scale):
|
||||
raise ValueError(
|
||||
f"Numbers of controlnet_type {len(args.controlnet_type)} should be equal to number of controlnet_scale {len(args.controlnet_scale)}."
|
||||
)
|
||||
|
||||
# Convert controlnet scales to tensor
|
||||
controlnet_scale = torch.FloatTensor(args.controlnet_scale)
|
||||
|
||||
if is_xl:
|
||||
images = process_controlnet_images_xl(args)
|
||||
else:
|
||||
images = []
|
||||
for i, image in enumerate(args.controlnet_image):
|
||||
images.append(process_controlnet_image(args.controlnet_type[i], Image.open(image), args.height, args.width))
|
||||
|
||||
return images, controlnet_scale
|
||||
+1318
File diff suppressed because it is too large
+1179
File diff suppressed because it is too large
+295
@@ -0,0 +1,295 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import hashlib
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL
|
||||
|
||||
|
||||
class EngineType(Enum):
|
||||
ORT_CUDA = 0 # ONNX Runtime CUDA Execution Provider
|
||||
ORT_TRT = 1 # ONNX Runtime TensorRT Execution Provider
|
||||
TRT = 2 # TensorRT
|
||||
TORCH = 3 # PyTorch
|
||||
|
||||
|
||||
def get_engine_type(name: str) -> EngineType:
|
||||
name_to_type = {
|
||||
"ORT_CUDA": EngineType.ORT_CUDA,
|
||||
"ORT_TRT": EngineType.ORT_TRT,
|
||||
"TRT": EngineType.TRT,
|
||||
"TORCH": EngineType.TORCH,
|
||||
}
|
||||
return name_to_type[name]
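# Example (follows directly from the mapping above): get_engine_type("ORT_CUDA") returns
# EngineType.ORT_CUDA; an unknown name raises KeyError since a plain dict lookup is used.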
|
||||
|
||||
|
||||
class EngineBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
engine_type: EngineType,
|
||||
pipeline_info: PipelineInfo,
|
||||
device="cuda",
|
||||
max_batch_size=16,
|
||||
use_cuda_graph=False,
|
||||
):
|
||||
"""
|
||||
Initializes the Engine Builder.
|
||||
|
||||
Args:
|
||||
pipeline_info (PipelineInfo):
|
||||
Version and Type of pipeline.
|
||||
device (str | torch.device):
|
||||
device to run engine
|
||||
max_batch_size (int):
|
||||
Maximum batch size for dynamic batch engine.
|
||||
use_cuda_graph (bool):
|
||||
Use CUDA graph to capture engine execution and then launch inference
|
||||
"""
|
||||
self.engine_type = engine_type
|
||||
self.pipeline_info = pipeline_info
|
||||
self.max_batch_size = max_batch_size
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.device = torch.device(device)
|
||||
self.torch_device = torch.device(device, torch.cuda.current_device())
|
||||
self.stages = pipeline_info.stages()
|
||||
|
||||
self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() and self.engine_type != EngineType.TORCH
|
||||
self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae()
|
||||
|
||||
self.models = {}
|
||||
self.engines = {}
|
||||
self.torch_models = {}
|
||||
self.use_vae_slicing = False
|
||||
|
||||
self.torch_sdpa = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
self.use_vae_slicing = True
|
||||
|
||||
def disable_torch_spda(self):
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
delattr(torch.nn.functional, "scaled_dot_product_attention")
|
||||
|
||||
def enable_torch_spda(self):
|
||||
if (not hasattr(torch.nn.functional, "scaled_dot_product_attention")) and self.torch_sdpa:
|
||||
torch.nn.functional.scaled_dot_product_attention = self.torch_sdpa
|
||||
|
||||
def teardown(self):
|
||||
for engine in self.engines.values():
|
||||
del engine
|
||||
self.engines = {}
|
||||
|
||||
def get_diffusers_module_name(self, model_name):
|
||||
name_mapping = {
|
||||
"clip": "text_encoder",
|
||||
"clip2": "text_encoder_2",
|
||||
"unet": "unet",
|
||||
"unetxl": "unet",
|
||||
"vae": "vae_decoder",
|
||||
}
|
||||
return name_mapping.get(model_name, model_name)
|
||||
|
||||
def get_cached_model_name(self, model_name):
|
||||
model_name = self.get_diffusers_module_name(model_name)
|
||||
is_unet = model_name == "unet"
|
||||
hash_source = []
|
||||
if model_name in ["text_encoder", "text_encoder_2", "unet"] and self.pipeline_info.lora_weights:
|
||||
if self.pipeline_info.lora_weights in [
|
||||
"latent-consistency/lcm-lora-sdxl",
|
||||
"latent-consistency/lcm-lora-sdv1-5",
|
||||
]:
|
||||
if is_unet:
|
||||
model_name = "unet_lcm-lora"
|
||||
else:
|
||||
model_name = model_name + "_lora"
|
||||
hash_source.append(self.pipeline_info.lora_weights)
|
||||
|
||||
# TODO(tianleiwu): save custom model to a directory named by its original model.
|
||||
if is_unet and self.pipeline_info.custom_unet():
|
||||
model_name = model_name + "_lcm"
|
||||
|
||||
if model_name in ["unet"] and self.pipeline_info.controlnet:
|
||||
model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet)
|
||||
|
||||
if hash_source:
|
||||
model_name += "_" + hashlib.sha256("\t".join(hash_source).encode("utf-8")).hexdigest()[:8]
|
||||
|
||||
# TODO: When we support original VAE, we shall save custom VAE to another directory.
|
||||
|
||||
if self.pipeline_info.is_inpaint():
|
||||
model_name += "_inpaint"
|
||||
return model_name
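# Illustrative example (made-up inputs): for an inpaint pipeline with
# lora_weights="latent-consistency/lcm-lora-sdxl", the cached unet name follows the pattern
# "unet_lcm-lora_<8-char-sha256>_inpaint" based on the branches above.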
|
||||
|
||||
def get_model_dir(self, model_name, root_dir, opt=True, suffix="", create=True):
|
||||
engine_name = self.engine_type.name.lower()
|
||||
if engine_name != "ort_cuda" and not suffix:
|
||||
suffix = f".{engine_name}" if opt else ""
|
||||
directory_name = self.get_cached_model_name(model_name) + suffix
|
||||
onnx_model_dir = os.path.join(root_dir, directory_name)
|
||||
if create:
|
||||
os.makedirs(onnx_model_dir, exist_ok=True)
|
||||
return onnx_model_dir
|
||||
|
||||
def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""):
|
||||
onnx_model_dir = self.get_model_dir(model_name, onnx_dir, opt=opt, suffix=suffix)
|
||||
return os.path.join(onnx_model_dir, "model.onnx")
|
||||
|
||||
def get_engine_path(self, engine_dir, model_name, profile_id):
|
||||
return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id)
|
||||
|
||||
def load_pipeline_with_lora(self):
|
||||
"""Load text encoders and UNet with diffusers pipeline"""
|
||||
from diffusers import DiffusionPipeline # noqa: PLC0415
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
self.pipeline_info.name(),
|
||||
variant="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipeline.load_lora_weights(self.pipeline_info.lora_weights)
|
||||
pipeline.fuse_lora(lora_scale=self.pipeline_info.lora_scale)
|
||||
|
||||
del pipeline.vae
|
||||
pipeline.vae = None
|
||||
return pipeline
|
||||
|
||||
def get_or_load_model(self, pipeline, model_name, model_obj, framework_model_dir):
|
||||
if model_name in ["clip", "clip2", "unet", "unetxl"] and pipeline:
|
||||
if model_name == "clip":
|
||||
model = pipeline.text_encoder
|
||||
pipeline.text_encoder = None
|
||||
elif model_name == "clip2":
|
||||
model = pipeline.text_encoder_2
|
||||
pipeline.text_encoder_2 = None
|
||||
else:
|
||||
model = pipeline.unet
|
||||
pipeline.unet = None
|
||||
else:
|
||||
model = model_obj.load_model(framework_model_dir)
|
||||
|
||||
return model.to(self.torch_device)
|
||||
|
||||
def load_models(self, framework_model_dir: str):
|
||||
# For TRT or ORT_TRT, we will export fp16 torch model for UNet and VAE
|
||||
# For ORT_CUDA, we export fp32 model first, then optimize to fp16.
|
||||
export_fp16 = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT]
|
||||
|
||||
if "clip" in self.stages:
|
||||
self.models["clip"] = CLIP(
|
||||
self.pipeline_info,
|
||||
None, # not loaded yet
|
||||
device=self.torch_device,
|
||||
max_batch_size=self.max_batch_size,
|
||||
clip_skip=0,
|
||||
)
|
||||
|
||||
if "clip2" in self.stages:
|
||||
self.models["clip2"] = CLIPWithProj(
|
||||
self.pipeline_info,
|
||||
None, # not loaded yet
|
||||
device=self.torch_device,
|
||||
max_batch_size=self.max_batch_size,
|
||||
clip_skip=0,
|
||||
)
|
||||
|
||||
if "unet" in self.stages:
|
||||
self.models["unet"] = UNet(
|
||||
self.pipeline_info,
|
||||
None, # not loaded yet
|
||||
device=self.torch_device,
|
||||
fp16=export_fp16,
|
||||
max_batch_size=self.max_batch_size,
|
||||
unet_dim=(9 if self.pipeline_info.is_inpaint() else 4),
|
||||
)
|
||||
|
||||
if "unetxl" in self.stages:
|
||||
self.models["unetxl"] = UNetXL(
|
||||
self.pipeline_info,
|
||||
None, # not loaded yet
|
||||
device=self.torch_device,
|
||||
fp16=export_fp16,
|
||||
max_batch_size=self.max_batch_size,
|
||||
unet_dim=4,
|
||||
time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6),
|
||||
)
|
||||
|
||||
# VAE Decoder
|
||||
if "vae" in self.stages:
|
||||
self.models["vae"] = VAE(
|
||||
self.pipeline_info,
|
||||
None, # not loaded yet
|
||||
device=self.torch_device,
|
||||
max_batch_size=self.max_batch_size,
|
||||
fp16=export_fp16,
|
||||
custom_fp16_vae=self.custom_fp16_vae,
|
||||
)
|
||||
|
||||
if self.vae_torch_fallback:
|
||||
self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir)
|
||||
|
||||
def load_resources(self, image_height, image_width, batch_size):
|
||||
if self.engine_type == EngineType.TORCH:
|
||||
return
|
||||
|
||||
# Allocate buffers for I/O bindings
|
||||
for model_name, obj in self.models.items():
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size
|
||||
self.engines[model_name].allocate_buffers(
|
||||
shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device
|
||||
)
|
||||
|
||||
def _vae_decode(self, latents):
|
||||
if self.engine_type == EngineType.TORCH:
|
||||
if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast
|
||||
latents = latents.to(dtype=torch.float32)
|
||||
images = self.engines["vae"](latents)["sample"]
|
||||
else:
|
||||
images = self.engines["vae"](latents)["sample"]
|
||||
elif self.vae_torch_fallback:
|
||||
if not self.custom_fp16_vae:
|
||||
latents = latents.to(dtype=torch.float32)
|
||||
self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32)
|
||||
images = self.torch_models["vae"](latents)["sample"]
|
||||
else:
|
||||
if self.pipeline_info.is_xl() and not self.custom_fp16_vae: # need upcast
|
||||
images = self.run_engine("vae", {"latent": latents.to(dtype=torch.float32)})["images"]
|
||||
else:
|
||||
images = self.run_engine("vae", {"latent": latents})["images"]
|
||||
|
||||
return images
|
||||
|
||||
def vae_decode(self, latents):
|
||||
if self.use_vae_slicing:
|
||||
# The output tensors point to the same buffer, so clone each decoded slice to avoid it being overwritten.
|
||||
decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)]
|
||||
return torch.cat(decoded_slices)
|
||||
|
||||
return self._vae_decode(latents)
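# Sketch of the slicing behavior (inferred from the code above): with use_vae_slicing enabled,
# a latent batch of shape (N, C, H, W) is decoded one sample at a time and the results are
# concatenated, trading throughput for lower peak memory.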
|
||||
|
||||
|
||||
def get_engine_paths(
|
||||
work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: str | None = None
|
||||
):
|
||||
root_dir = work_dir or "."
|
||||
short_name = pipeline_info.short_name()
|
||||
|
||||
# When both ORT_CUDA and ORT_TRT/TRT are used, we create a subdirectory for each engine type since
# ORT_CUDA needs an fp32 torch model, while ORT_TRT/TRT use an fp16 torch model.
|
||||
onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx")
|
||||
engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine")
|
||||
output_dir = os.path.join(root_dir, engine_type.name, short_name, "output")
|
||||
|
||||
timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache")
|
||||
|
||||
# Shared among ORT_CUDA, ORT_TRT and TRT engines; use load_model(..., always_download_fp16=True)
# so that the shared model is always fp16.
|
||||
if framework_model_dir is None:
|
||||
framework_model_dir = os.path.join(root_dir, "torch_model")
|
||||
|
||||
return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache
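# Illustrative layout of the returned paths (directory names follow the joins above):
#   {work_dir}/{engine_type}/{short_name}/onnx     - exported ONNX models
#   {work_dir}/{engine_type}/{short_name}/engine   - optimized models or compiled engines
#   {work_dir}/{engine_type}/{short_name}/output   - generated images
#   {work_dir}/{engine_type}/timing_cache          - timing cache shared by pipelines of this engine type
#   {work_dir}/torch_model                         - default framework (PyTorch) model directory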
|
||||
+387
@@ -0,0 +1,387 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from diffusion_models import PipelineInfo
|
||||
from engine_builder import EngineBuilder, EngineType
|
||||
from packaging import version
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
|
||||
from onnxruntime.transformers.onnx_model import OnnxModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrtCudaEngine:
|
||||
def __init__(
|
||||
self,
|
||||
onnx_path,
|
||||
device_id: int = 0,
|
||||
enable_cuda_graph: bool = False,
|
||||
disable_optimization: bool = False,
|
||||
max_cuda_graphs: int = 1,
|
||||
):
|
||||
self.onnx_path = onnx_path
|
||||
self.provider = "CUDAExecutionProvider"
|
||||
self.stream = torch.cuda.current_stream().cuda_stream
|
||||
self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph, self.stream)
|
||||
session_options = ort.SessionOptions()
|
||||
|
||||
# When the model has been optimized by onnxruntime, we can disable optimization to save session creation time.
|
||||
if disable_optimization:
|
||||
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
|
||||
logger.info("creating CUDA EP session for %s", onnx_path)
|
||||
ort_session = ort.InferenceSession(
|
||||
onnx_path,
|
||||
session_options,
|
||||
providers=[
|
||||
(self.provider, self.provider_options),
|
||||
"CPUExecutionProvider",
|
||||
],
|
||||
)
|
||||
logger.info("created CUDA EP session for %s", onnx_path)
|
||||
|
||||
device = torch.device("cuda", device_id)
|
||||
self.enable_cuda_graph = enable_cuda_graph
|
||||
|
||||
# Support multiple CUDA graphs for different input shapes.
|
||||
# For models like clip2 that have cuda graph disabled, max_cuda_graphs is set to 0 here.
|
||||
self.gpu_binding_manager = GpuBindingManager(
|
||||
ort_session=ort_session,
|
||||
device=device,
|
||||
stream=self.stream,
|
||||
max_cuda_graphs=max_cuda_graphs if enable_cuda_graph else 0,
|
||||
)
|
||||
|
||||
self.current_gpu_binding = None
|
||||
|
||||
def metadata(self, name: str):
|
||||
data = {}
|
||||
if self.current_gpu_binding is not None:
|
||||
if self.current_gpu_binding.last_run_gpu_graph_id >= 0:
|
||||
data[f"{name}.gpu_graph_id"] = self.current_gpu_binding.last_run_gpu_graph_id
|
||||
return data
|
||||
|
||||
def infer(self, feed_dict: dict[str, torch.Tensor]):
|
||||
return self.current_gpu_binding.infer(feed_dict=feed_dict, disable_cuda_graph_in_run=not self.enable_cuda_graph)
|
||||
|
||||
def allocate_buffers(self, shape_dict, device):
|
||||
self.current_gpu_binding = self.gpu_binding_manager.get_binding(
|
||||
shape_dict=shape_dict, use_cuda_graph=self.enable_cuda_graph
|
||||
)
|
||||
|
||||
|
||||
class _ModelConfig:
|
||||
"""
|
||||
Configuration of one model (like Clip, UNet etc) on ONNX export and optimization for CUDA provider.
|
||||
For example, if you want to use fp32 in layer normalization, set the following:
|
||||
force_fp32_ops=["SkipLayerNormalization", "LayerNormalization"]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
onnx_opset_version: int,
|
||||
use_cuda_graph: bool,
|
||||
fp16: bool = True,
|
||||
force_fp32_ops: list[str] | None = None,
|
||||
optimize_by_ort: bool = True,
|
||||
):
|
||||
self.onnx_opset_version = onnx_opset_version
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.fp16 = fp16
|
||||
self.force_fp32_ops = force_fp32_ops
|
||||
self.optimize_by_ort = optimize_by_ort
|
||||
|
||||
|
||||
class OrtCudaEngineBuilder(EngineBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_info: PipelineInfo,
|
||||
max_batch_size=16,
|
||||
device="cuda",
|
||||
use_cuda_graph=False,
|
||||
):
|
||||
"""
|
||||
Initializes the ONNX Runtime CUDA ExecutionProvider Engine Builder.
|
||||
|
||||
Args:
|
||||
pipeline_info (PipelineInfo):
|
||||
Version and Type of pipeline.
|
||||
max_batch_size (int):
|
||||
Maximum batch size for dynamic batch engine.
|
||||
device (str):
|
||||
device to run.
|
||||
use_cuda_graph (bool):
|
||||
Use CUDA graph to capture engine execution and then launch inference
|
||||
"""
|
||||
super().__init__(
|
||||
EngineType.ORT_CUDA,
|
||||
pipeline_info,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
)
|
||||
|
||||
self.model_config = {}
|
||||
|
||||
def _configure(
|
||||
self,
|
||||
model_name: str,
|
||||
onnx_opset_version: int,
|
||||
use_cuda_graph: bool,
|
||||
fp16: bool = True,
|
||||
force_fp32_ops: list[str] | None = None,
|
||||
optimize_by_ort: bool = True,
|
||||
):
|
||||
self.model_config[model_name] = _ModelConfig(
|
||||
onnx_opset_version,
|
||||
use_cuda_graph,
|
||||
fp16=fp16,
|
||||
force_fp32_ops=force_fp32_ops,
|
||||
optimize_by_ort=optimize_by_ort,
|
||||
)
|
||||
|
||||
def configure_xl(self, onnx_opset_version: int):
|
||||
self._configure(
|
||||
"clip",
|
||||
onnx_opset_version=onnx_opset_version,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
self._configure(
|
||||
"clip2",
|
||||
onnx_opset_version=onnx_opset_version, # TODO: ArgMax-12 is not implemented in CUDA
|
||||
use_cuda_graph=False, # TODO: fix Runtime Error with cuda graph
|
||||
)
|
||||
self._configure(
|
||||
"unetxl",
|
||||
onnx_opset_version=onnx_opset_version,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
|
||||
self._configure(
|
||||
"vae",
|
||||
onnx_opset_version=onnx_opset_version,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
|
||||
def optimized_onnx_path(self, engine_dir, model_name):
|
||||
suffix = "" if self.model_config[model_name].fp16 else ".fp32"
|
||||
return self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix)
|
||||
|
||||
def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str):
|
||||
"""Import optimized onnx models for diffusers from Olive or optimize_pipeline tools.
|
||||
|
||||
Args:
|
||||
diffusers_onnx_dir (str): optimized onnx directory of Olive
|
||||
engine_dir (str): the directory to store imported onnx
|
||||
"""
|
||||
if version.parse(ort.__version__) < version.parse("1.17.0"):
|
||||
print("Skip importing since onnxruntime-gpu version < 1.17.0.")
|
||||
return
|
||||
|
||||
for model_name, model_obj in self.models.items():
|
||||
onnx_import_path = self.optimized_onnx_path(diffusers_onnx_dir, model_name)
|
||||
if not os.path.exists(onnx_import_path):
|
||||
print(f"{onnx_import_path} not existed. Skip importing.")
|
||||
continue
|
||||
|
||||
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
|
||||
if os.path.exists(onnx_opt_path):
|
||||
print(f"{onnx_opt_path} existed. Skip importing.")
|
||||
continue
|
||||
|
||||
if model_name == "vae" and self.pipeline_info.is_xl():
|
||||
print(f"Skip importing VAE since it is not fully compatible with float16: {onnx_import_path}.")
|
||||
continue
|
||||
|
||||
model = OnnxModel(onnx.load(onnx_import_path, load_external_data=True))
|
||||
|
||||
if model_name in ["clip", "clip2"]:
|
||||
hidden_states_per_layer = []
|
||||
for output in model.graph().output:
|
||||
if output.name.startswith("hidden_states."):
|
||||
hidden_states_per_layer.append(output.name)
|
||||
if hidden_states_per_layer:
|
||||
kept_hidden_states = hidden_states_per_layer[-2 - model_obj.clip_skip]
|
||||
model.rename_graph_output(kept_hidden_states, "hidden_states")
|
||||
|
||||
model.rename_graph_output(
|
||||
"last_hidden_state" if model_name == "clip" else "text_embeds", "text_embeddings"
|
||||
)
|
||||
model.prune_graph(
|
||||
["text_embeddings", "hidden_states"] if hidden_states_per_layer else ["text_embeddings"]
|
||||
)
|
||||
|
||||
if model_name == "clip2":
|
||||
model.change_graph_input_type(model.find_graph_input("input_ids"), onnx.TensorProto.INT32)
|
||||
|
||||
model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2"))
|
||||
elif model_name in ["unet", "unetxl"]:
|
||||
model.rename_graph_output("out_sample", "latent")
|
||||
model.save_model_to_file(onnx_opt_path, use_external_data_format=True)
|
||||
|
||||
del model
|
||||
continue
|
||||
|
||||
def build_engines(
|
||||
self,
|
||||
engine_dir: str,
|
||||
framework_model_dir: str,
|
||||
onnx_dir: str,
|
||||
tmp_dir: str | None = None,
|
||||
onnx_opset_version: int = 17,
|
||||
device_id: int = 0,
|
||||
save_fp32_intermediate_model: bool = False,
|
||||
import_engine_dir: str | None = None,
|
||||
max_cuda_graphs: int = 1,
|
||||
):
|
||||
self.torch_device = torch.device("cuda", device_id)
|
||||
self.load_models(framework_model_dir)
|
||||
|
||||
if not os.path.isdir(engine_dir):
|
||||
os.makedirs(engine_dir)
|
||||
|
||||
if not os.path.isdir(onnx_dir):
|
||||
os.makedirs(onnx_dir)
|
||||
|
||||
# Add default configuration if missing
|
||||
if self.pipeline_info.is_xl():
|
||||
self.configure_xl(onnx_opset_version)
|
||||
for model_name in self.models:
|
||||
if model_name not in self.model_config:
|
||||
self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph)
|
||||
|
||||
# Import Engine
|
||||
if import_engine_dir:
|
||||
if self.pipeline_info.is_xl():
|
||||
self.import_diffusers_engine(import_engine_dir, engine_dir)
|
||||
else:
|
||||
print(f"Only support importing SDXL onnx. Ignore --engine-dir {import_engine_dir}")
|
||||
|
||||
# Load lora only when we need export text encoder or UNet to ONNX.
|
||||
load_lora = False
|
||||
if self.pipeline_info.lora_weights:
|
||||
for model_name in self.models:
|
||||
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
|
||||
continue
|
||||
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
if not os.path.exists(onnx_path):
|
||||
load_lora = True
|
||||
break
|
||||
|
||||
# Export models to ONNX
|
||||
self.disable_torch_spda()
|
||||
pipe = self.load_pipeline_with_lora() if load_lora else None
|
||||
|
||||
for model_name, model_obj in self.models.items():
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
|
||||
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
if not os.path.exists(onnx_path):
|
||||
print("----")
|
||||
logger.info("Exporting model: %s", onnx_path)
|
||||
|
||||
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
|
||||
model = model.to(torch.float32)
|
||||
|
||||
with torch.inference_mode():
|
||||
# For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern.
|
||||
# Export model with sample of batch size 1, image size 512 x 512
|
||||
inputs = model_obj.get_sample_input(1, 512, 512)
|
||||
|
||||
torch.onnx.export(
|
||||
model,
|
||||
inputs,
|
||||
onnx_path,
|
||||
export_params=True,
|
||||
opset_version=self.model_config[model_name].onnx_opset_version,
|
||||
do_constant_folding=True,
|
||||
input_names=model_obj.get_input_names(),
|
||||
output_names=model_obj.get_output_names(),
|
||||
dynamic_axes=model_obj.get_dynamic_axes(),
|
||||
)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
else:
|
||||
logger.info("Found cached model: %s", onnx_path)
|
||||
|
||||
# Generate fp32 optimized model.
|
||||
# If the final target is an fp16 model, we save the fp32 optimized model so that it is easy to tune
# fp16 conversion. That can save a lot of time during development.
|
||||
use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16
|
||||
onnx_fp32_path = onnx_path
|
||||
if use_fp32_intermediate:
|
||||
onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32")
|
||||
if not os.path.exists(onnx_fp32_path):
|
||||
print("------")
|
||||
logger.info("Generating optimized model: %s", onnx_fp32_path)
|
||||
model_obj.optimize_ort(
|
||||
onnx_path,
|
||||
onnx_fp32_path,
|
||||
to_fp16=False,
|
||||
fp32_op_list=self.model_config[model_name].force_fp32_ops,
|
||||
optimize_by_ort=self.model_config[model_name].optimize_by_ort,
|
||||
tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp32", create=False),
|
||||
)
|
||||
else:
|
||||
logger.info("Found cached optimized model: %s", onnx_fp32_path)
|
||||
|
||||
# Generate the final optimized model.
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
print("------")
|
||||
logger.info("Generating optimized model: %s", onnx_opt_path)
|
||||
|
||||
# When there is an fp32 intermediate optimized model, this will just convert the model from fp32 to fp16.
|
||||
optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort
|
||||
|
||||
model_obj.optimize_ort(
|
||||
onnx_fp32_path,
|
||||
onnx_opt_path,
|
||||
to_fp16=self.model_config[model_name].fp16,
|
||||
fp32_op_list=self.model_config[model_name].force_fp32_ops,
|
||||
optimize_by_ort=optimize_by_ort,
|
||||
optimize_by_fusion=not use_fp32_intermediate,
|
||||
tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".ort", create=False),
|
||||
)
|
||||
else:
|
||||
logger.info("Found cached optimized model: %s", onnx_opt_path)
|
||||
self.enable_torch_spda()
|
||||
|
||||
built_engines = {}
|
||||
for model_name in self.models:
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
|
||||
onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
|
||||
use_cuda_graph = self.model_config[model_name].use_cuda_graph
|
||||
|
||||
engine = OrtCudaEngine(
|
||||
onnx_opt_path,
|
||||
device_id=device_id,
|
||||
enable_cuda_graph=use_cuda_graph,
|
||||
disable_optimization=False,
|
||||
max_cuda_graphs=max_cuda_graphs,
|
||||
)
|
||||
|
||||
logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options)
|
||||
built_engines[model_name] = engine
|
||||
|
||||
self.engines = built_engines
|
||||
|
||||
def run_engine(self, model_name, feed_dict):
|
||||
return self.engines[model_name].infer(feed_dict)
|
||||
+288
@@ -0,0 +1,288 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from cuda import cudart
|
||||
from diffusion_models import PipelineInfo
|
||||
from engine_builder import EngineBuilder, EngineType
|
||||
from packaging import version
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrtTensorrtEngine(CudaSession):
|
||||
def __init__(
|
||||
self,
|
||||
engine_path,
|
||||
device_id,
|
||||
onnx_path,
|
||||
fp16,
|
||||
input_profile,
|
||||
workspace_size,
|
||||
enable_cuda_graph,
|
||||
timing_cache_path=None,
|
||||
):
|
||||
self.engine_path = engine_path
|
||||
self.ort_trt_provider_options = self.get_tensorrt_provider_options(
|
||||
input_profile,
|
||||
workspace_size,
|
||||
fp16,
|
||||
device_id,
|
||||
enable_cuda_graph,
|
||||
timing_cache_path=timing_cache_path,
|
||||
)
|
||||
|
||||
session_options = ort.SessionOptions()
|
||||
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
logger.info("creating TRT EP session for %s", onnx_path)
|
||||
ort_session = ort.InferenceSession(
|
||||
onnx_path,
|
||||
session_options,
|
||||
providers=[
|
||||
("TensorrtExecutionProvider", self.ort_trt_provider_options),
|
||||
],
|
||||
)
|
||||
logger.info("created TRT EP session for %s", onnx_path)
|
||||
|
||||
device = torch.device("cuda", device_id)
|
||||
super().__init__(ort_session, device, enable_cuda_graph)
|
||||
|
||||
def get_tensorrt_provider_options(
|
||||
self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph, timing_cache_path=None
|
||||
):
|
||||
trt_ep_options = {
|
||||
"device_id": device_id,
|
||||
"trt_fp16_enable": fp16,
|
||||
"trt_engine_cache_enable": True,
|
||||
"trt_timing_cache_enable": True,
|
||||
"trt_detailed_build_log": True,
|
||||
"trt_engine_cache_path": self.engine_path,
|
||||
}
|
||||
|
||||
if version.parse(ort.__version__) > version.parse("1.16.2") and timing_cache_path is not None:
|
||||
trt_ep_options["trt_timing_cache_path"] = timing_cache_path
|
||||
|
||||
if enable_cuda_graph:
|
||||
trt_ep_options["trt_cuda_graph_enable"] = True
|
||||
|
||||
if workspace_size > 0:
|
||||
trt_ep_options["trt_max_workspace_size"] = workspace_size
|
||||
|
||||
if input_profile:
|
||||
min_shapes = []
|
||||
max_shapes = []
|
||||
opt_shapes = []
|
||||
for name, profile in input_profile.items():
|
||||
assert isinstance(profile, list) and len(profile) == 3
|
||||
min_shape = profile[0]
|
||||
opt_shape = profile[1]
|
||||
max_shape = profile[2]
|
||||
assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape)
|
||||
|
||||
min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape]))
|
||||
opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape]))
|
||||
max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape]))
|
||||
|
||||
trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes)
|
||||
trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes)
|
||||
trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes)
|
||||
|
||||
logger.info("trt_ep_options=%s", trt_ep_options)
|
||||
|
||||
return trt_ep_options
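# Illustrative example of the profile strings built above (values are made up):
#   input_profile = {"sample": [[2, 4, 64, 64], [2, 4, 64, 64], [8, 4, 96, 96]]}
# produces
#   trt_profile_min_shapes = "sample:2x4x64x64"
#   trt_profile_opt_shapes = "sample:2x4x64x64"
#   trt_profile_max_shapes = "sample:8x4x96x96"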
|
||||
|
||||
def allocate_buffers(self, shape_dict, device):
|
||||
super().allocate_buffers(shape_dict)
|
||||
|
||||
|
||||
class OrtTensorrtEngineBuilder(EngineBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_info: PipelineInfo,
|
||||
max_batch_size=16,
|
||||
device="cuda",
|
||||
use_cuda_graph=False,
|
||||
):
|
||||
"""
|
||||
Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
|
||||
|
||||
Args:
|
||||
pipeline_info (PipelineInfo):
|
||||
Version and Type of pipeline.
|
||||
max_batch_size (int):
|
||||
Maximum batch size for dynamic batch engine.
|
||||
device (str):
|
||||
device to run.
|
||||
use_cuda_graph (bool):
|
||||
Use CUDA graph to capture engine execution and then launch inference
|
||||
"""
|
||||
super().__init__(
|
||||
EngineType.ORT_TRT,
|
||||
pipeline_info,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
)
|
||||
|
||||
def has_engine_file(self, engine_path):
|
||||
if os.path.isdir(engine_path):
|
||||
children = os.scandir(engine_path)
|
||||
for entry in children:
|
||||
if entry.is_file() and entry.name.endswith(".engine"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_work_space_size(self, model_name, max_workspace_size):
|
||||
gibibyte = 2**30
|
||||
workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size
|
||||
if workspace_size == 0:
|
||||
_, free_mem, _ = cudart.cudaMemGetInfo()
|
||||
# The following logic is adapted from the TensorRT demo diffusion.
|
||||
if free_mem > 6 * gibibyte:
|
||||
workspace_size = free_mem - 4 * gibibyte
|
||||
return workspace_size
|
||||
|
||||
def build_engines(
|
||||
self,
|
||||
engine_dir,
|
||||
framework_model_dir,
|
||||
onnx_dir,
|
||||
onnx_opset,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
opt_batch_size=1,
|
||||
static_batch=False,
|
||||
static_image_shape=True,
|
||||
max_workspace_size=0,
|
||||
device_id=0,
|
||||
timing_cache=None,
|
||||
):
|
||||
self.torch_device = torch.device("cuda", device_id)
|
||||
self.load_models(framework_model_dir)
|
||||
|
||||
if not os.path.isdir(engine_dir):
|
||||
os.makedirs(engine_dir)
|
||||
|
||||
if not os.path.isdir(onnx_dir):
|
||||
os.makedirs(onnx_dir)
|
||||
|
||||
# Load lora only when we need export text encoder or UNet to ONNX.
|
||||
load_lora = False
|
||||
if self.pipeline_info.lora_weights:
|
||||
for model_name, model_obj in self.models.items():
|
||||
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
|
||||
continue
|
||||
profile_id = model_obj.get_profile_id(
|
||||
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
|
||||
)
|
||||
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
|
||||
if not self.has_engine_file(engine_path):
|
||||
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
if not os.path.exists(onnx_path):
|
||||
load_lora = True
|
||||
break
|
||||
|
||||
# Export models to ONNX
|
||||
self.disable_torch_spda()
|
||||
pipe = self.load_pipeline_with_lora() if load_lora else None
|
||||
|
||||
for model_name, model_obj in self.models.items():
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
|
||||
profile_id = model_obj.get_profile_id(
|
||||
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
|
||||
)
|
||||
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
|
||||
if not self.has_engine_file(engine_path):
|
||||
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
if not os.path.exists(onnx_path):
|
||||
logger.info(f"Exporting model: {onnx_path}")
|
||||
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
|
||||
|
||||
with torch.inference_mode(), torch.autocast("cuda"):
|
||||
inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
inputs,
|
||||
onnx_path,
|
||||
export_params=True,
|
||||
opset_version=onnx_opset,
|
||||
do_constant_folding=True,
|
||||
input_names=model_obj.get_input_names(),
|
||||
output_names=model_obj.get_output_names(),
|
||||
dynamic_axes=model_obj.get_dynamic_axes(),
|
||||
)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
else:
|
||||
logger.info("Found cached model: %s", onnx_path)
|
||||
|
||||
# Optimize onnx
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
logger.info("Generating optimizing model: %s", onnx_opt_path)
|
||||
model_obj.optimize_trt(onnx_path, onnx_opt_path)
|
||||
else:
|
||||
logger.info("Found cached optimized model: %s", onnx_opt_path)
|
||||
self.enable_torch_spda()
|
||||
|
||||
built_engines = {}
|
||||
for model_name, model_obj in self.models.items():
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
|
||||
profile_id = model_obj.get_profile_id(
|
||||
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
|
||||
)
|
||||
|
||||
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
|
||||
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
|
||||
if not self.has_engine_file(engine_path):
|
||||
logger.info(
|
||||
"Building TensorRT engine for %s from %s to %s. It can take a while to complete...",
|
||||
model_name,
|
||||
onnx_opt_path,
|
||||
engine_path,
|
||||
)
|
||||
else:
|
||||
logger.info("Reuse cached TensorRT engine in directory %s", engine_path)
|
||||
|
||||
input_profile = model_obj.get_input_profile(
|
||||
opt_batch_size,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
static_batch=static_batch,
|
||||
static_image_shape=static_image_shape,
|
||||
)
|
||||
|
||||
engine = OrtTensorrtEngine(
|
||||
engine_path,
|
||||
device_id,
|
||||
onnx_opt_path,
|
||||
fp16=True,
|
||||
input_profile=input_profile,
|
||||
workspace_size=self.get_work_space_size(model_name, max_workspace_size),
|
||||
enable_cuda_graph=self.use_cuda_graph,
|
||||
timing_cache_path=timing_cache,
|
||||
)
|
||||
|
||||
built_engines[model_name] = engine
|
||||
|
||||
self.engines = built_engines
|
||||
|
||||
def run_engine(self, model_name, feed_dict):
|
||||
return self.engines[model_name].infer(feed_dict)
|
||||
+395
@@ -0,0 +1,395 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
# Modified from TensorRT demo diffusion, which has the following license:
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import gc
|
||||
import os
|
||||
import pathlib
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from cuda import cudart
|
||||
from diffusion_models import PipelineInfo
|
||||
from engine_builder import EngineBuilder, EngineType
|
||||
from polygraphy.backend.common import bytes_from_path
|
||||
from polygraphy.backend.trt import (
|
||||
CreateConfig,
|
||||
ModifyNetworkOutputs,
|
||||
Profile,
|
||||
engine_from_bytes,
|
||||
engine_from_network,
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
|
||||
# Map of numpy dtype -> torch dtype
|
||||
numpy_to_torch_dtype_dict = {
|
||||
np.int32: torch.int32,
|
||||
np.int64: torch.int64,
|
||||
np.float16: torch.float16,
|
||||
np.float32: torch.float32,
|
||||
}
|
||||
|
||||
|
||||
def _cuda_assert(cuda_ret):
|
||||
err = cuda_ret[0]
|
||||
if err != cudart.cudaError_t.cudaSuccess:
|
||||
raise RuntimeError(
|
||||
f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
|
||||
)
|
||||
if len(cuda_ret) > 1:
|
||||
return cuda_ret[1]
|
||||
return None
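# Typical usage (as in load_resources below): wrap cudart calls that return an (error, value)
# tuple, e.g. stream = _cuda_assert(cudart.cudaStreamCreate()).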
|
||||
|
||||
|
||||
class TensorrtEngine:
|
||||
def __init__(
|
||||
self,
|
||||
engine_path,
|
||||
):
|
||||
self.engine_path = engine_path
|
||||
self.engine = None
|
||||
self.context = None
|
||||
self.buffers = OrderedDict()
|
||||
self.tensors = OrderedDict()
|
||||
self.cuda_graph_instance = None
|
||||
|
||||
def __del__(self):
|
||||
del self.engine
|
||||
del self.context
|
||||
del self.buffers
|
||||
del self.tensors
|
||||
|
||||
def build(
|
||||
self,
|
||||
onnx_path,
|
||||
fp16,
|
||||
input_profile=None,
|
||||
enable_all_tactics=False,
|
||||
timing_cache=None,
|
||||
update_output_names=None,
|
||||
):
|
||||
print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
|
||||
p = Profile()
|
||||
if input_profile:
|
||||
for name, dims in input_profile.items():
|
||||
assert len(dims) == 3
|
||||
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
|
||||
|
||||
config_kwargs = {}
|
||||
if not enable_all_tactics:
|
||||
config_kwargs["tactic_sources"] = []
|
||||
|
||||
network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
|
||||
if update_output_names:
|
||||
print(f"Updating network outputs to {update_output_names}")
|
||||
network = ModifyNetworkOutputs(network, update_output_names)
|
||||
engine = engine_from_network(
|
||||
network,
|
||||
config=CreateConfig(
|
||||
fp16=fp16, refittable=False, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
|
||||
),
|
||||
save_timing_cache=timing_cache,
|
||||
)
|
||||
save_engine(engine, path=self.engine_path)
|
||||
|
||||
def load(self):
|
||||
print(f"Loading TensorRT engine: {self.engine_path}")
|
||||
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
|
||||
|
||||
def activate(self, reuse_device_memory=None):
|
||||
if reuse_device_memory:
|
||||
self.context = self.engine.create_execution_context_without_device_memory()
|
||||
self.context.device_memory = reuse_device_memory
|
||||
else:
|
||||
self.context = self.engine.create_execution_context()
|
||||
|
||||
def allocate_buffers(self, shape_dict=None, device="cuda"):
|
||||
for idx in range(self.engine.num_io_tensors):
|
||||
binding = self.engine[idx]
|
||||
if shape_dict and binding in shape_dict:
|
||||
shape = shape_dict[binding]
|
||||
else:
|
||||
shape = self.engine.get_binding_shape(binding)
|
||||
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
|
||||
if self.engine.binding_is_input(binding):
|
||||
self.context.set_binding_shape(idx, shape)
|
||||
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
|
||||
self.tensors[binding] = tensor
|
||||
|
||||
def infer(self, feed_dict, stream, use_cuda_graph=False):
|
||||
for name, buf in feed_dict.items():
|
||||
self.tensors[name].copy_(buf)
|
||||
|
||||
for name, tensor in self.tensors.items():
|
||||
self.context.set_tensor_address(name, tensor.data_ptr())
|
||||
|
||||
if use_cuda_graph:
|
||||
if self.cuda_graph_instance is not None:
|
||||
_cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
|
||||
_cuda_assert(cudart.cudaStreamSynchronize(stream))
|
||||
else:
|
||||
# do inference before CUDA graph capture
|
||||
noerror = self.context.execute_async_v3(stream)
|
||||
if not noerror:
|
||||
raise ValueError("ERROR: inference failed.")
|
||||
# capture cuda graph
|
||||
_cuda_assert(
|
||||
cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
|
||||
)
|
||||
self.context.execute_async_v3(stream)
|
||||
self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream))
|
||||
|
||||
from cuda import nvrtc # noqa: PLC0415
|
||||
|
||||
result, major, minor = nvrtc.nvrtcVersion()
|
||||
assert result == nvrtc.nvrtcResult(0)
|
||||
if major < 12:
|
||||
self.cuda_graph_instance = _cuda_assert(
|
||||
cudart.cudaGraphInstantiate(self.graph, b"", 0)
|
||||
) # cuda < 12
|
||||
else:
|
||||
self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0)) # cuda >= 12
|
||||
else:
|
||||
noerror = self.context.execute_async_v3(stream)
|
||||
if not noerror:
|
||||
raise ValueError("ERROR: inference failed.")
|
||||
|
||||
return self.tensors
|
||||
|
||||
|
||||
class TensorrtEngineBuilder(EngineBuilder):
|
||||
"""
|
||||
Helper class to hide the details of the TensorRT engine from the pipeline.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_info: PipelineInfo,
|
||||
max_batch_size=16,
|
||||
device="cuda",
|
||||
use_cuda_graph=False,
|
||||
):
|
||||
"""
|
||||
Initializes the TensorRT Engine Builder.
|
||||
|
||||
Args:
|
||||
pipeline_info (PipelineInfo):
|
||||
Version and Type of pipeline.
|
||||
max_batch_size (int):
|
||||
Maximum batch size for dynamic batch engine.
|
||||
device (str):
|
||||
device to run.
|
||||
use_cuda_graph (bool):
|
||||
Use CUDA graph to capture engine execution and then launch inference
|
||||
"""
|
||||
super().__init__(
|
||||
EngineType.TRT,
|
||||
pipeline_info,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
)
|
||||
|
||||
self.stream = None
|
||||
self.shared_device_memory = None
|
||||
|
||||
def load_resources(self, image_height, image_width, batch_size):
|
||||
super().load_resources(image_height, image_width, batch_size)
|
||||
|
||||
self.stream = _cuda_assert(cudart.cudaStreamCreate())
|
||||
|
||||
def teardown(self):
|
||||
super().teardown()
|
||||
|
||||
if self.shared_device_memory:
|
||||
cudart.cudaFree(self.shared_device_memory)
|
||||
|
||||
cudart.cudaStreamDestroy(self.stream)
|
||||
del self.stream
|
||||
|
||||
def load_engines(
|
||||
self,
|
||||
engine_dir,
|
||||
framework_model_dir,
|
||||
onnx_dir,
|
||||
onnx_opset,
|
||||
opt_batch_size,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
static_batch=False,
|
||||
static_shape=True,
|
||||
enable_all_tactics=False,
|
||||
timing_cache=None,
|
||||
):
|
||||
"""
|
||||
Build and load engines for TensorRT accelerated inference.
|
||||
Export ONNX models first, if applicable.
|
||||
|
||||
Args:
|
||||
engine_dir (str):
|
||||
Directory to write the TensorRT engines.
|
||||
framework_model_dir (str):
|
||||
Directory to write the framework model ckpt.
|
||||
onnx_dir (str):
|
||||
Directory to write the ONNX models.
|
||||
onnx_opset (int):
|
||||
ONNX opset version to export the models.
|
||||
opt_batch_size (int):
|
||||
Batch size to optimize for during engine building.
|
||||
opt_image_height (int):
|
||||
Image height to optimize for during engine building. Must be a multiple of 8.
|
||||
opt_image_width (int):
|
||||
Image width to optimize for during engine building. Must be a multiple of 8.
|
||||
static_batch (bool):
|
||||
Build engine only for specified opt_batch_size.
|
||||
static_shape (bool):
|
||||
Build engine only for specified opt_image_height & opt_image_width. Default = True.
|
||||
enable_all_tactics (bool):
|
||||
Enable all tactic sources during TensorRT engine builds.
|
||||
timing_cache (str):
|
||||
Path to the timing cache to accelerate build or None
|
||||
"""
|
||||
# Create directory
|
||||
for directory in [engine_dir, onnx_dir]:
|
||||
if not os.path.exists(directory):
|
||||
print(f"[I] Create directory: {directory}")
|
||||
pathlib.Path(directory).mkdir(parents=True)
|
||||
|
||||
self.load_models(framework_model_dir)
|
||||
|
||||
# Load lora only when we need export text encoder or UNet to ONNX.
|
||||
load_lora = False
|
||||
if self.pipeline_info.lora_weights:
|
||||
for model_name, model_obj in self.models.items():
|
||||
if model_name not in ["clip", "clip2", "unet", "unetxl"]:
|
||||
continue
|
||||
profile_id = model_obj.get_profile_id(
|
||||
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
|
||||
)
|
||||
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
|
||||
if not os.path.exists(engine_path):
|
||||
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
if not os.path.exists(onnx_path):
|
||||
load_lora = True
|
||||
break
|
||||
|
||||
# Export models to ONNX
|
||||
self.disable_torch_spda()
|
||||
pipe = self.load_pipeline_with_lora() if load_lora else None
|
||||
|
||||
for model_name, model_obj in self.models.items():
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
profile_id = model_obj.get_profile_id(
|
||||
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
|
||||
)
|
||||
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
|
||||
if not os.path.exists(engine_path):
|
||||
onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
if not os.path.exists(onnx_path):
|
||||
print(f"Exporting model: {onnx_path}")
|
||||
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
|
||||
|
||||
with torch.inference_mode(), torch.autocast("cuda"):
|
||||
inputs = model_obj.get_sample_input(1, opt_image_height, opt_image_width)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
inputs,
|
||||
onnx_path,
|
||||
export_params=True,
|
||||
opset_version=onnx_opset,
|
||||
do_constant_folding=True,
|
||||
input_names=model_obj.get_input_names(),
|
||||
output_names=model_obj.get_output_names(),
|
||||
dynamic_axes=model_obj.get_dynamic_axes(),
|
||||
)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
else:
|
||||
print(f"Found cached model: {onnx_path}")
|
||||
|
||||
# Optimize onnx
|
||||
if not os.path.exists(onnx_opt_path):
|
||||
print(f"Generating optimizing model: {onnx_opt_path}")
|
||||
model_obj.optimize_trt(onnx_path, onnx_opt_path)
|
||||
else:
|
||||
print(f"Found cached optimized model: {onnx_opt_path} ")
|
||||
self.enable_torch_spda()
|
||||
|
||||
# Build TensorRT engines
|
||||
for model_name, model_obj in self.models.items():
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
profile_id = model_obj.get_profile_id(
|
||||
opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
|
||||
)
|
||||
engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
|
||||
engine = TensorrtEngine(engine_path)
|
||||
onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
|
||||
|
||||
if not os.path.exists(engine.engine_path):
|
||||
engine.build(
|
||||
onnx_opt_path,
|
||||
fp16=True,
|
||||
input_profile=model_obj.get_input_profile(
|
||||
opt_batch_size,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
static_batch,
|
||||
static_shape,
|
||||
),
|
||||
enable_all_tactics=enable_all_tactics,
|
||||
timing_cache=timing_cache,
|
||||
update_output_names=None,
|
||||
)
|
||||
self.engines[model_name] = engine
|
||||
|
||||
# Load TensorRT engines
|
||||
for model_name in self.models:
|
||||
if model_name == "vae" and self.vae_torch_fallback:
|
||||
continue
|
||||
self.engines[model_name].load()
|
||||
|
||||
def max_device_memory(self):
|
||||
max_device_memory = 0
|
||||
for engine in self.engines.values():
|
||||
max_device_memory = max(max_device_memory, engine.engine.device_memory_size)
|
||||
return max_device_memory
|
||||
|
||||
def activate_engines(self, shared_device_memory=None):
|
||||
if shared_device_memory is None:
|
||||
max_device_memory = self.max_device_memory()
|
||||
_, shared_device_memory = cudart.cudaMalloc(max_device_memory)
|
||||
self.shared_device_memory = shared_device_memory
|
||||
# Load and activate TensorRT engines
|
||||
for engine in self.engines.values():
|
||||
engine.activate(reuse_device_memory=self.shared_device_memory)
|
||||
|
||||
def run_engine(self, model_name, feed_dict):
|
||||
return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph)
|
||||
+108
@@ -0,0 +1,108 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import logging
|
||||
|
||||
from diffusion_models import PipelineInfo
|
||||
from engine_builder import EngineBuilder, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TorchEngineBuilder(EngineBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_info: PipelineInfo,
|
||||
max_batch_size=16,
|
||||
device="cuda",
|
||||
use_cuda_graph=False,
|
||||
):
|
||||
"""
|
||||
Initializes the PyTorch Engine Builder.
|
||||
|
||||
Args:
|
||||
pipeline_info (PipelineInfo):
|
||||
Version and Type of pipeline.
|
||||
max_batch_size (int):
|
||||
Maximum batch size for dynamic batch engine.
|
||||
device (str):
|
||||
device to run.
|
||||
use_cuda_graph (bool):
|
||||
Use CUDA graph to capture engine execution and then launch inference
|
||||
"""
|
||||
super().__init__(
|
||||
EngineType.TORCH,
|
||||
pipeline_info,
|
||||
max_batch_size=max_batch_size,
|
||||
device=device,
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
)
|
||||
|
||||
self.compile_config = {}
|
||||
if use_cuda_graph:
|
||||
self.compile_config = {
|
||||
"clip": {"mode": "reduce-overhead", "dynamic": False},
|
||||
"clip2": {"mode": "reduce-overhead", "dynamic": False},
|
||||
"unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
|
||||
"unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
|
||||
"vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
|
||||
}
|
||||
|
||||
def build_engines(
|
||||
self,
|
||||
framework_model_dir: str,
|
||||
):
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
self.torch_device = torch.device("cuda", torch.cuda.current_device())
|
||||
self.load_models(framework_model_dir)
|
||||
|
||||
pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
|
||||
|
||||
built_engines = {}
|
||||
for model_name, model_obj in self.models.items():
|
||||
model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
|
||||
if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
|
||||
model = model.to(device=self.torch_device, dtype=torch.float32)
|
||||
else:
|
||||
model = model.to(device=self.torch_device, dtype=torch.float16)
|
||||
|
||||
if model_name in self.compile_config:
|
||||
compile_config = self.compile_config[model_name]
|
||||
if model_name in ["unet", "unetxl"]:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
engine = torch.compile(model, **compile_config)
|
||||
built_engines[model_name] = engine
|
||||
else: # eager mode
|
||||
built_engines[model_name] = model
|
||||
|
||||
self.engines = built_engines
|
||||
|
||||
def run_engine(self, model_name, feed_dict):
|
||||
if model_name in ["unet", "unetxl"]:
|
||||
if "controlnet_images" in feed_dict:
|
||||
return {"latent": self.engines[model_name](**feed_dict)}
|
||||
|
||||
if model_name == "unetxl":
|
||||
added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
|
||||
return {
|
||||
"latent": self.engines[model_name](
|
||||
feed_dict["sample"],
|
||||
feed_dict["timestep"],
|
||||
feed_dict["encoder_hidden_states"],
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
}
|
||||
|
||||
return {
|
||||
"latent": self.engines[model_name](
|
||||
feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
|
||||
)[0]
|
||||
}
|
||||
|
||||
if model_name in ["vae_encoder"]:
|
||||
return {"latent": self.engines[model_name](feed_dict["images"])}
|
||||
|
||||
raise RuntimeError(f"Shall not reach here: {model_name}")
|
||||
+584
@@ -0,0 +1,584 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
#
|
||||
# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
|
||||
#
|
||||
# Before running this script, follow README.md to setup python environment and convert stable diffusion checkpoint
|
||||
# to float32 onnx models.
|
||||
#
|
||||
# For example, if the float32 ONNX pipeline is saved to the ./sd-v1-5 directory, you can optimize and convert it to float16
|
||||
# like the following:
|
||||
# python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
|
||||
#
|
||||
# Note that the optimizations target the CUDA Execution Provider first; other EPs may not support
# the fused operators. Users can disable operator fusion manually to work around this.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import coloredlogs
|
||||
import onnx
|
||||
from fusion_options import FusionOptions
|
||||
from onnx_model_clip import ClipOnnxModel
|
||||
from onnx_model_mmdit import MmditOnnxModel
|
||||
from onnx_model_t5 import T5OnnxModel
|
||||
from onnx_model_unet import UnetOnnxModel
|
||||
from onnx_model_vae import VaeOnnxModel
|
||||
from optimizer import optimize_by_onnxruntime, optimize_model
|
||||
from packaging import version
|
||||
|
||||
import onnxruntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def has_external_data(onnx_model_path):
|
||||
original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
|
||||
for initializer in original_model.graph.initializer:
|
||||
if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_sd_3(source_dir: Path):
|
||||
return (source_dir / "text_encoder_3").exists()
|
||||
|
||||
|
||||
def is_sdxl(source_dir: Path):
|
||||
return (
|
||||
(source_dir / "text_encoder_2").exists()
|
||||
and not (source_dir / "text_encoder_3").exists()
|
||||
and not (source_dir / "transformer").exists()
|
||||
)
|
||||
|
||||
|
||||
def is_flux(source_dir: Path):
|
||||
return (
|
||||
(source_dir / "text_encoder_2").exists()
|
||||
and not (source_dir / "text_encoder_3").exists()
|
||||
and (source_dir / "transformer").exists()
|
||||
)
|
||||
|
||||
|
||||
def _classify_pipeline_type(source_dir: Path):
|
||||
# We may also check _class_name in model_index.json (e.g. `StableDiffusion3Pipeline` or `FluxPipeline`) to classify.
|
||||
if is_sd_3(source_dir):
|
||||
return "sd3"
|
||||
|
||||
if is_flux(source_dir):
|
||||
return "flux"
|
||||
|
||||
if is_sdxl(source_dir):
|
||||
return "sdxl"
|
||||
|
||||
# sd 1.x and 2.x
|
||||
return "sd"
|
||||
|
||||
|
||||
def _get_model_list(pipeline_type: str):
|
||||
if pipeline_type == "sd3":
|
||||
return ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"]
|
||||
|
||||
if pipeline_type == "flux":
|
||||
return ["text_encoder", "text_encoder_2", "transformer", "vae_encoder", "vae_decoder"]
|
||||
|
||||
if pipeline_type == "sdxl":
|
||||
return ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"]
|
||||
|
||||
assert pipeline_type == "sd"
|
||||
return ["text_encoder", "unet", "vae_encoder", "vae_decoder"]
|
||||
|
||||
|
||||
def _optimize_sd_pipeline(
|
||||
source_dir: Path,
|
||||
target_dir: Path,
|
||||
pipeline_type: str,
|
||||
model_list: list[str],
|
||||
use_external_data_format: bool | None,
|
||||
float16: bool,
|
||||
bfloat16: bool,
|
||||
force_fp32_ops: list[str],
|
||||
enable_runtime_optimization: bool,
|
||||
args,
|
||||
):
|
||||
"""Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16.
|
||||
|
||||
Args:
|
||||
source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
|
||||
target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
|
||||
model_list (List[str]): list of directory names with onnx model.
|
||||
use_external_data_format (Optional[bool]): use external data format.
|
||||
float16 (bool): use half precision
|
||||
bfloat16 (bool): use bfloat16 as fallback if float16 is also provided.
|
||||
force_fp32_ops(List[str]): operators that are forced to run in float32.
|
||||
enable_runtime_optimization(bool): run graph optimization using Onnx Runtime.
|
||||
|
||||
Raises:
|
||||
RuntimeError: input onnx model does not exist
|
||||
RuntimeError: output onnx model path already exists
|
||||
"""
|
||||
is_flux_pipeline = pipeline_type == "flux"
|
||||
model_type_mapping = {
|
||||
"transformer": "mmdit",
|
||||
"unet": "unet",
|
||||
"vae_encoder": "vae",
|
||||
"vae_decoder": "vae",
|
||||
"text_encoder": "clip",
|
||||
"text_encoder_2": "t5" if is_flux_pipeline else "clip",
|
||||
"text_encoder_3": "t5", # t5-v1_1-xxl is used in SD 3.x text_encoder_3 and Flux text_encoder_2.
|
||||
"safety_checker": "unet",
|
||||
}
|
||||
|
||||
model_type_class_mapping = {
|
||||
"unet": UnetOnnxModel,
|
||||
"vae": VaeOnnxModel,
|
||||
"clip": ClipOnnxModel,
|
||||
"t5": T5OnnxModel,
|
||||
"mmdit": MmditOnnxModel,
|
||||
}
|
||||
|
||||
force_fp32_operators = {
|
||||
"unet": [],
|
||||
"vae_encoder": [],
|
||||
"vae_decoder": [],
|
||||
"text_encoder": [],
|
||||
"text_encoder_2": [],
|
||||
"safety_checker": [],
|
||||
"text_encoder_3": [],
|
||||
"transformer": [],
|
||||
}
|
||||
|
||||
# The node block list is generated by running the fp32 model and collecting statistics of node inputs and outputs.
|
||||
# Nodes with any float or double input or output whose value is out of the float16 range are candidates.
|
||||
# python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp32_opt
|
||||
# export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1
|
||||
# export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1
|
||||
# export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1
|
||||
# python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >stdout.txt 2>stderr.txt
|
||||
# Warning: The node name might change in different export settings. See benchmark_flux.sh for the settings.
|
||||
flux_node_block_list = {
|
||||
"text_encoder_2": [
|
||||
"/encoder/block.10/layer.1/DenseReluDense/wo/MatMul",
|
||||
"SkipLayerNorm_20",
|
||||
"SkipLayerNorm_21",
|
||||
"SkipLayerNorm_22",
|
||||
"SkipLayerNorm_23",
|
||||
"SkipLayerNorm_24",
|
||||
"SkipLayerNorm_25",
|
||||
"SkipLayerNorm_26",
|
||||
"SkipLayerNorm_27",
|
||||
"SkipLayerNorm_28",
|
||||
"SkipLayerNorm_29",
|
||||
"SkipLayerNorm_30",
|
||||
"SkipLayerNorm_31",
|
||||
"SkipLayerNorm_32",
|
||||
"SkipLayerNorm_33",
|
||||
"SkipLayerNorm_34",
|
||||
"SkipLayerNorm_35",
|
||||
"SkipLayerNorm_36",
|
||||
"SkipLayerNorm_37",
|
||||
"SkipLayerNorm_38",
|
||||
"SkipLayerNorm_39",
|
||||
"SkipLayerNorm_40",
|
||||
"SkipLayerNorm_41",
|
||||
"SkipLayerNorm_42",
|
||||
"SkipLayerNorm_43",
|
||||
"SkipLayerNorm_44",
|
||||
"SkipLayerNorm_45",
|
||||
"/encoder/block.23/layer.1/DenseReluDense/wo/MatMul",
|
||||
"SkipLayerNorm_46",
|
||||
],
|
||||
"vae_decoder": [
|
||||
"/decoder/mid_block/attentions.0/MatMul",
|
||||
"/decoder/mid_block/attentions.0/Softmax",
|
||||
],
|
||||
"transformer": [
|
||||
"/transformer_blocks.18/Mul_5",
|
||||
"/transformer_blocks.18/Add_7",
|
||||
"/Concat_1",
|
||||
"LayerNorm_76",
|
||||
"/single_transformer_blocks.0/Add",
|
||||
"LayerNorm_77",
|
||||
"/single_transformer_blocks.1/Add",
|
||||
"LayerNorm_78",
|
||||
"/single_transformer_blocks.2/Add",
|
||||
"LayerNorm_79",
|
||||
"/single_transformer_blocks.3/Add",
|
||||
"LayerNorm_80",
|
||||
"/single_transformer_blocks.4/Add",
|
||||
"LayerNorm_81",
|
||||
"/single_transformer_blocks.5/Add",
|
||||
"LayerNorm_82",
|
||||
"/single_transformer_blocks.6/Add",
|
||||
"LayerNorm_83",
|
||||
"/single_transformer_blocks.7/Add",
|
||||
"LayerNorm_84",
|
||||
"/single_transformer_blocks.8/Add",
|
||||
"LayerNorm_85",
|
||||
"/single_transformer_blocks.9/Add",
|
||||
"LayerNorm_86",
|
||||
"/single_transformer_blocks.10/Add",
|
||||
"LayerNorm_87",
|
||||
"/single_transformer_blocks.11/Add",
|
||||
"LayerNorm_88",
|
||||
"/single_transformer_blocks.12/Add",
|
||||
"LayerNorm_89",
|
||||
"/single_transformer_blocks.13/Add",
|
||||
"LayerNorm_90",
|
||||
"/single_transformer_blocks.14/Add",
|
||||
"LayerNorm_91",
|
||||
"/single_transformer_blocks.15/Add",
|
||||
"LayerNorm_92",
|
||||
"/single_transformer_blocks.16/Add",
|
||||
"LayerNorm_93",
|
||||
"/single_transformer_blocks.17/Add",
|
||||
"LayerNorm_94",
|
||||
"/single_transformer_blocks.18/Add",
|
||||
"LayerNorm_95",
|
||||
"/single_transformer_blocks.19/Add",
|
||||
"LayerNorm_96",
|
||||
"/single_transformer_blocks.20/Add",
|
||||
"LayerNorm_97",
|
||||
"/single_transformer_blocks.21/Add",
|
||||
"LayerNorm_98",
|
||||
"/single_transformer_blocks.22/Add",
|
||||
"LayerNorm_99",
|
||||
"/single_transformer_blocks.23/Add",
|
||||
"LayerNorm_100",
|
||||
"/single_transformer_blocks.24/Add",
|
||||
"LayerNorm_101",
|
||||
"/single_transformer_blocks.25/Add",
|
||||
"LayerNorm_102",
|
||||
"/single_transformer_blocks.26/Add",
|
||||
"LayerNorm_103",
|
||||
"/single_transformer_blocks.27/Add",
|
||||
"LayerNorm_104",
|
||||
"/single_transformer_blocks.28/Add",
|
||||
"LayerNorm_105",
|
||||
"/single_transformer_blocks.29/Add",
|
||||
"LayerNorm_106",
|
||||
"/single_transformer_blocks.30/Add",
|
||||
"LayerNorm_107",
|
||||
"/single_transformer_blocks.31/Add",
|
||||
"LayerNorm_108",
|
||||
"/single_transformer_blocks.32/Add",
|
||||
"LayerNorm_109",
|
||||
"/single_transformer_blocks.33/Add",
|
||||
"LayerNorm_110",
|
||||
"/single_transformer_blocks.34/Add",
|
||||
"LayerNorm_111",
|
||||
"/single_transformer_blocks.35/Add",
|
||||
"LayerNorm_112",
|
||||
"/single_transformer_blocks.36/Add",
|
||||
"LayerNorm_113",
|
||||
"/single_transformer_blocks.37/Add",
|
||||
"/Shape",
|
||||
"/Slice",
|
||||
],
|
||||
}
|
||||
|
||||
sd3_node_block_list = {"text_encoder_3": flux_node_block_list["text_encoder_2"]}
|
||||
|
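||||
# Illustrative example: --force_fp32_ops unet:Attention keeps the Attention operator of the unet model in float32.
|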
||||
if force_fp32_ops:
|
||||
for fp32_operator in force_fp32_ops:
|
||||
parts = fp32_operator.split(":")
|
||||
if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
|
||||
force_fp32_operators[parts[0]].append(parts[1])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}"
|
||||
)
|
||||
|
||||
op_counters = {}
|
||||
for name, model_type in model_type_mapping.items():
|
||||
onnx_model_path = source_dir / name / "model.onnx"
|
||||
if not os.path.exists(onnx_model_path):
|
||||
if name != "safety_checker" and name in model_list:
|
||||
logger.warning("input onnx model does not exist: %s", onnx_model_path)
|
||||
# Some models are optional, so we do not raise an error here.
|
||||
continue
|
||||
|
||||
# Prepare output directory
|
||||
optimized_model_path = target_dir / name / "model.onnx"
|
||||
if os.path.exists(optimized_model_path):
|
||||
if not args.overwrite:
|
||||
logger.warning("Skipped optimization since the target file existed: %s", optimized_model_path)
|
||||
continue
|
||||
output_dir = optimized_model_path.parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
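||||
# When not specified on the command line, keep the same external-data layout as the input model.
|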
||||
if use_external_data_format is None:
|
||||
use_external_data_format = has_external_data(onnx_model_path)
|
||||
|
||||
# Run graph fusion before fp16 conversion; otherwise, some nodes cannot be fused later.
|
||||
logger.info("Optimize %s ...", onnx_model_path)
|
||||
|
||||
args.model_type = model_type
|
||||
fusion_options = FusionOptions.parse(args)
|
||||
|
||||
if model_type in ["unet"]:
|
||||
# Some optimizations are not available in v1.14 or older versions: packed QKV and BiasAdd
|
||||
has_all_optimizations = version.parse(onnxruntime.__version__) >= version.parse("1.15.0")
|
||||
fusion_options.enable_packed_kv = float16 and fusion_options.enable_packed_kv
|
||||
fusion_options.enable_packed_qkv = float16 and has_all_optimizations and fusion_options.enable_packed_qkv
|
||||
fusion_options.enable_bias_add = has_all_optimizations and fusion_options.enable_bias_add
|
||||
|
||||
m = optimize_model(
|
||||
str(onnx_model_path),
|
||||
model_type=model_type,
|
||||
num_heads=0, # will be deduced from graph
|
||||
hidden_size=0, # will be deduced from graph
|
||||
opt_level=0,
|
||||
optimization_options=fusion_options,
|
||||
use_gpu=True,
|
||||
provider=args.provider,
|
||||
)
|
||||
|
||||
if float16:
|
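||||
# Flux and SD3 pipelines need a node block list so that overflow-prone nodes stay in float32 (or bfloat16).
|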
||||
model_node_block_list = (
|
||||
flux_node_block_list if is_flux_pipeline else sd3_node_block_list if pipeline_type == "sd3" else {}
|
||||
)
|
||||
if name in model_node_block_list:
|
||||
# Opset 12 does not support bfloat16.
|
||||
# By default, Optimum exports the T5 model with opset 12, so we need to check the opset version.
|
||||
use_bfloat16 = bfloat16
|
||||
if use_bfloat16:
|
||||
for opset in m.model.opset_import:
|
||||
if opset.domain in ["", "ai.onnx"] and opset.version < 13:
|
||||
logger.warning(
|
||||
"onnx model requires opset 13 or higher to use bfloat16. Fall back to float32."
|
||||
)
|
||||
use_bfloat16 = False
|
||||
|
||||
m.convert_float_to_float16(
|
||||
keep_io_types=False,
|
||||
node_block_list=model_node_block_list[name],
|
||||
use_bfloat16_as_blocked_nodes_dtype=use_bfloat16,
|
||||
)
|
||||
# For SD-XL, using FP16 in the VAE decoder causes NaN and black images, so we keep it in FP32.
|
||||
elif pipeline_type in ["sdxl"] and name in ["vae_decoder"]:
|
||||
logger.info("Skip converting %s to float16 to avoid NaN", name)
|
||||
else:
|
||||
logger.info("Convert %s to float16 ...", name)
|
||||
m.convert_float_to_float16(
|
||||
keep_io_types=False,
|
||||
op_block_list=force_fp32_operators[name],
|
||||
)
|
||||
|
||||
if enable_runtime_optimization:
|
||||
# Use this step to see the final graph that is executed by ONNX Runtime.
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save to a temporary file so that we can load it with Onnx Runtime.
|
||||
logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
|
||||
tmp_model_path = Path(tmp_dir) / "model.onnx"
|
||||
m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
|
||||
ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
|
||||
optimize_by_onnxruntime(
|
||||
str(tmp_model_path),
|
||||
use_gpu=True,
|
||||
provider=args.provider,
|
||||
optimized_model_path=str(ort_optimized_model_path),
|
||||
save_as_external_data=use_external_data_format,
|
||||
)
|
||||
model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
|
||||
m = model_type_class_mapping[model_type](model)
|
||||
|
||||
m.get_operator_statistics()
|
||||
op_counters[name] = m.get_fused_operator_statistics()
|
||||
m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
|
||||
logger.info("%s is optimized", name)
|
||||
logger.info("*" * 20)
|
||||
|
||||
return op_counters
|
||||
|
||||
|
||||
def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: list[str]):
|
||||
"""Copy extra directory that does not have onnx model
|
||||
|
||||
Args:
|
||||
source_dir (Path): source directory
|
||||
target_dir (Path): target directory
|
||||
model_list (List[str]): list of directory names with onnx model.
|
||||
|
||||
Raises:
|
||||
RuntimeError: source path does not exist
|
||||
"""
|
||||
extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"]
|
||||
|
||||
for name in extra_dirs:
|
||||
source_path = source_dir / name
|
||||
if not os.path.exists(source_path):
|
||||
continue
|
||||
|
||||
target_path = target_dir / name
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path)
|
||||
shutil.copytree(source_path, target_path)
|
||||
logger.info("%s => %s", source_path, target_path)
|
||||
|
||||
extra_files = ["model_index.json"]
|
||||
for name in extra_files:
|
||||
source_path = source_dir / name
|
||||
if not os.path.exists(source_path):
|
||||
raise RuntimeError(f"source path does not exist: {source_path}")
|
||||
|
||||
target_path = target_dir / name
|
||||
shutil.copyfile(source_path, target_path)
|
||||
logger.info("%s => %s", source_path, target_path)
|
||||
|
||||
# Some directories are optional
|
||||
for onnx_model_dir in model_list:
|
||||
source_path = source_dir / onnx_model_dir / "config.json"
|
||||
target_path = target_dir / onnx_model_dir / "config.json"
|
||||
if source_path.exists():
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(source_path, target_path)
|
||||
logger.info("%s => %s", source_path, target_path)
|
||||
|
||||
|
||||
def optimize_stable_diffusion_pipeline(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
overwrite: bool,
|
||||
use_external_data_format: bool | None,
|
||||
float16: bool,
|
||||
enable_runtime_optimization: bool,
|
||||
args,
|
||||
):
|
||||
if os.path.exists(output_dir):
|
||||
if overwrite:
|
||||
shutil.rmtree(output_dir, ignore_errors=True)
|
||||
|
||||
source_dir = Path(input_dir)
|
||||
target_dir = Path(output_dir)
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pipeline_type = _classify_pipeline_type(source_dir)
|
||||
model_list = _get_model_list(pipeline_type)
|
||||
|
||||
_copy_extra_directory(source_dir, target_dir, model_list)
|
||||
|
||||
return _optimize_sd_pipeline(
|
||||
source_dir,
|
||||
target_dir,
|
||||
pipeline_type,
|
||||
model_list,
|
||||
use_external_data_format,
|
||||
float16,
|
||||
args.bfloat16,
|
||||
args.force_fp32_ops,
|
||||
enable_runtime_optimization,
|
||||
args,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments(argv: list[str] | None = None):
|
||||
"""Parse arguments
|
||||
|
||||
Returns:
|
||||
Namespace: arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--input",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Root of output directory of stable diffusion onnx pipeline with optimized models.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--float16",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Output models of float16, except some nodes falls back to float32 or bfloat16 to avoid overflow.",
|
||||
)
|
||||
parser.set_defaults(float16=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--bfloat16",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Allow bfloat16 as fallback if --float16 is also provided.",
|
||||
)
|
||||
parser.set_defaults(bfloat16=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--force_fp32_ops",
|
||||
required=False,
|
||||
nargs="+",
|
||||
type=str,
|
||||
help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--inspect",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Save the optimized graph from Onnx Runtime. "
|
||||
"This option has no impact on inference performance except it might reduce session creation time.",
|
||||
)
|
||||
parser.set_defaults(inspect=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--overwrite",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Overwrite exists files.",
|
||||
)
|
||||
parser.set_defaults(overwrite=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--use_external_data_format",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Onnx model larger than 2GB need to use external data format. "
|
||||
"If specified, save each onnx model to two files: one for onnx graph, another for weights. "
|
||||
"If not specified, use same format as original model by default. ",
|
||||
)
|
||||
parser.set_defaults(use_external_data_format=None)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="Execution provider to use.",
|
||||
)
|
||||
|
||||
FusionOptions.add_arguments(parser)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
return args
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None):
|
||||
args = parse_arguments(argv)
|
||||
|
||||
logger.info("Arguments: %s", str(args))
|
||||
|
||||
# Return op counters for testing purpose.
|
||||
return optimize_stable_diffusion_pipeline(
|
||||
args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
coloredlogs.install(fmt="%(funcName)20s: %(message)s")
|
||||
main()
|
||||
+136
@@ -0,0 +1,136 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
ONNX Model Optimizer for Stable Diffusion
|
||||
"""
|
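||||
# A minimal usage sketch (paths are illustrative): create an optimizer for one of the model types defined below
|
||||
# and call optimize() on a float32 ONNX file, e.g.
|
||||
#   OrtStableDiffusionOptimizer("unet").optimize("unet_fp32/model.onnx", "unet_fp16/model.onnx", float16=True)
|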
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
from packaging import version
|
||||
|
||||
from onnxruntime.transformers.fusion_options import FusionOptions
|
||||
from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
|
||||
from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
|
||||
from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
|
||||
from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrtStableDiffusionOptimizer:
|
||||
def __init__(self, model_type: str):
|
||||
assert model_type in ["vae", "unet", "clip"]
|
||||
self.model_type = model_type
|
||||
self.model_type_class_mapping = {
|
||||
"unet": UnetOnnxModel,
|
||||
"vae": VaeOnnxModel,
|
||||
"clip": ClipOnnxModel,
|
||||
}
|
||||
|
||||
def _optimize_by_ort(self, onnx_model, use_external_data_format, tmp_dir):
|
||||
# Save to a temporary file so that we can load it with Onnx Runtime.
|
||||
logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
|
||||
tmp_model_path = Path(tmp_dir) / "model.onnx"
|
||||
onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
|
||||
|
||||
del onnx_model
|
||||
gc.collect()
|
||||
|
||||
ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
|
||||
optimize_by_onnxruntime(
|
||||
str(tmp_model_path),
|
||||
use_gpu=True,
|
||||
optimized_model_path=str(ort_optimized_model_path),
|
||||
save_as_external_data=use_external_data_format,
|
||||
external_data_filename="optimized.onnx_data",
|
||||
)
|
||||
model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
|
||||
return self.model_type_class_mapping[self.model_type](model)
|
||||
|
||||
def optimize_by_ort(self, onnx_model, use_external_data_format=False, tmp_dir=None):
|
||||
# Use this step to see the final graph that is executed by ONNX Runtime.
|
||||
if tmp_dir is None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._optimize_by_ort(onnx_model, use_external_data_format, temp_dir)
|
||||
else:
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
model = self._optimize_by_ort(onnx_model, use_external_data_format, tmp_dir)
|
||||
shutil.rmtree(tmp_dir)
|
||||
return model
|
||||
|
||||
def optimize(
|
||||
self,
|
||||
input_fp32_onnx_path,
|
||||
optimized_onnx_path,
|
||||
float16=True,
|
||||
keep_io_types=False,
|
||||
fp32_op_list=None,
|
||||
keep_outputs=None,
|
||||
optimize_by_ort=True,
|
||||
optimize_by_fusion=True,
|
||||
final_target_float16=True,
|
||||
tmp_dir=None,
|
||||
):
|
||||
"""Optimize onnx model using ONNX Runtime transformers optimizer"""
|
||||
logger.info(f"Optimize {input_fp32_onnx_path}...")
|
||||
|
||||
if optimize_by_fusion:
|
||||
fusion_options = FusionOptions(self.model_type)
|
||||
|
||||
# It is allowed to pass float16=False with final_target_float16=True to use fp32 as an intermediate optimization step.
|
||||
# For the rare fp32 use case, we disable packed kv/qkv since there is no fp32 TRT fused attention kernel.
|
||||
if self.model_type in ["unet"] and not final_target_float16:
|
||||
fusion_options.enable_packed_kv = False
|
||||
fusion_options.enable_packed_qkv = False
|
||||
|
||||
m = optimize_model(
|
||||
input_fp32_onnx_path,
|
||||
model_type=self.model_type,
|
||||
num_heads=0, # will be deduced from graph
|
||||
hidden_size=0, # will be deduced from graph
|
||||
opt_level=0,
|
||||
optimization_options=fusion_options,
|
||||
use_gpu=True,
|
||||
)
|
||||
else:
|
||||
model = onnx.load_model(input_fp32_onnx_path, load_external_data=True)
|
||||
m = self.model_type_class_mapping[self.model_type](model)
|
||||
|
||||
if keep_outputs:
|
||||
m.prune_graph(outputs=keep_outputs)
|
||||
|
||||
model_size = m.model.ByteSize()
|
||||
|
||||
# Model size might be negative (overflow?) on Windows.
|
||||
use_external_data_format = model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF
|
||||
|
||||
# Note that ORT < 1.16 could not save model larger than 2GB.
|
||||
# This step is is optional since it has no impact on inference latency.
|
||||
# The optimized model is not portable. It could only run in the same execution provider (CUDA EP in this case).
|
||||
# When the model has been optimized by onnxruntime, we can disable optimization in SessionOption
|
||||
# to save session creation time. Another benefit is to inspect the final graph for developing purpose.
|
||||
from onnxruntime import __version__ as ort_version # noqa: PLC0415
|
||||
|
||||
if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format):
|
||||
m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format, tmp_dir=tmp_dir)
|
||||
|
||||
if float16:
|
||||
logger.info("Convert to float16 ...")
|
||||
m.convert_float_to_float16(
|
||||
keep_io_types=keep_io_types,
|
||||
op_block_list=fp32_op_list,
|
||||
)
|
||||
|
||||
m.get_operator_statistics()
|
||||
m.get_fused_operator_statistics()
|
||||
m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
|
||||
logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)
|
||||
+831
@@ -0,0 +1,831 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
# Modified from TensorRT demo diffusion, which has the following license:
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import nvtx
|
||||
import torch
|
||||
from cuda import cudart
|
||||
from diffusion_models import PipelineInfo, get_tokenizer
|
||||
from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler
|
||||
from engine_builder import EngineType
|
||||
from engine_builder_ort_cuda import OrtCudaEngineBuilder
|
||||
from engine_builder_ort_trt import OrtTensorrtEngineBuilder
|
||||
from engine_builder_tensorrt import TensorrtEngineBuilder
|
||||
from engine_builder_torch import TorchEngineBuilder
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class StableDiffusionPipeline:
|
||||
"""
|
||||
Stable Diffusion pipeline using TensorRT or ONNX Runtime.
|
||||
"""
|
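||||
# A minimal usage sketch (illustrative; in the demo scripts the PipelineInfo is constructed and the backend
|
||||
# engines are built/loaded before running):
|
||||
#   pipeline = StableDiffusionPipeline(pipeline_info, scheduler="EulerA", engine_type=EngineType.ORT_CUDA)
|
||||
#   pipeline.load_resources(image_height=512, image_width=512, batch_size=1)
|
||||
#   images, perf_data = pipeline.run(["a photo of a cat"], [""], 512, 512, show_latency=True)
|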
||||
|
||||
def __init__(
|
||||
self,
|
||||
pipeline_info: PipelineInfo,
|
||||
max_batch_size=16,
|
||||
scheduler="DDIM",
|
||||
device="cuda",
|
||||
output_dir=".",
|
||||
verbose=False,
|
||||
nvtx_profile=False,
|
||||
use_cuda_graph=False,
|
||||
framework_model_dir="pytorch_model",
|
||||
engine_type: EngineType = EngineType.ORT_CUDA,
|
||||
):
|
||||
"""
|
||||
Initializes the Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
pipeline_info (PipelineInfo):
|
||||
Version and Type of pipeline.
|
||||
max_batch_size (int):
|
||||
Maximum batch size for dynamic batch engine.
|
||||
scheduler (str):
|
||||
The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM].
|
||||
device (str):
|
||||
PyTorch device to run inference. Default: 'cuda'
|
||||
output_dir (str):
|
||||
Output directory for log files and image artifacts
|
||||
verbose (bool):
|
||||
Enable verbose logging.
|
||||
nvtx_profile (bool):
|
||||
Insert NVTX profiling markers.
|
||||
use_cuda_graph (bool):
|
||||
Use CUDA graph to capture engine execution and then launch inference
|
||||
framework_model_dir (str):
|
||||
cache directory for framework checkpoints
|
||||
engine_type (EngineType):
|
||||
backend engine type like ORT_TRT or TRT
|
||||
"""
|
||||
|
||||
self.pipeline_info = pipeline_info
|
||||
self.version = pipeline_info.version
|
||||
|
||||
self.vae_scaling_factor = pipeline_info.vae_scaling_factor()
|
||||
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
self.framework_model_dir = framework_model_dir
|
||||
self.output_dir = output_dir
|
||||
for directory in [self.framework_model_dir, self.output_dir]:
|
||||
if not os.path.exists(directory):
|
||||
print(f"[I] Create directory: {directory}")
|
||||
pathlib.Path(directory).mkdir(parents=True)
|
||||
|
||||
self.device = device
|
||||
self.torch_device = torch.device(device, torch.cuda.current_device())
|
||||
self.verbose = verbose
|
||||
self.nvtx_profile = nvtx_profile
|
||||
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
|
||||
self.tokenizer = None
|
||||
self.tokenizer2 = None
|
||||
|
||||
self.generator = torch.Generator(device="cuda")
|
||||
self.actual_steps = None
|
||||
|
||||
self.current_scheduler = None
|
||||
self.set_scheduler(scheduler)
|
||||
|
||||
# backend engine
|
||||
self.engine_type = engine_type
|
||||
if engine_type == EngineType.TRT:
|
||||
self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
|
||||
elif engine_type == EngineType.ORT_TRT:
|
||||
self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
|
||||
elif engine_type == EngineType.ORT_CUDA:
|
||||
self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
|
||||
elif engine_type == EngineType.TORCH:
|
||||
self.backend = TorchEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
|
||||
else:
|
||||
raise RuntimeError(f"Backend engine type {engine_type.name} is not supported")
|
||||
|
||||
# Load text tokenizer
|
||||
if not self.pipeline_info.is_xl_refiner():
|
||||
self.tokenizer = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer")
|
||||
|
||||
if self.pipeline_info.is_xl():
|
||||
self.tokenizer2 = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer_2")
|
||||
|
||||
self.control_image_processor = None
|
||||
if self.pipeline_info.is_xl() and self.pipeline_info.controlnet:
|
||||
from diffusers.image_processor import VaeImageProcessor # noqa: PLC0415
|
||||
|
||||
self.control_image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=8, do_convert_rgb=True, do_normalize=False
|
||||
)
|
||||
|
||||
# Create CUDA events
|
||||
self.events = {}
|
||||
for stage in ["clip", "denoise", "vae", "vae_encoder", "pil"]:
|
||||
for marker in ["start", "stop"]:
|
||||
self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
|
||||
self.markers = {}
|
||||
|
||||
def is_backend_tensorrt(self):
|
||||
return self.engine_type == EngineType.TRT
|
||||
|
||||
def set_scheduler(self, scheduler: str):
|
||||
if scheduler == self.current_scheduler:
|
||||
return
|
||||
|
||||
# Scheduler options
|
||||
sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012}
|
||||
if self.version in ("2.0", "2.1"):
|
||||
sched_opts["prediction_type"] = "v_prediction"
|
||||
else:
|
||||
sched_opts["prediction_type"] = "epsilon"
|
||||
|
||||
if scheduler == "DDIM":
|
||||
self.scheduler = DDIMScheduler(device=self.device, **sched_opts)
|
||||
elif scheduler == "EulerA":
|
||||
self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts)
|
||||
elif scheduler == "UniPC":
|
||||
self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts)
|
||||
elif scheduler == "LCM":
|
||||
self.scheduler = LCMScheduler(device=self.device, **sched_opts)
|
||||
else:
|
||||
raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM")
|
||||
|
||||
self.current_scheduler = scheduler
|
||||
self.denoising_steps = None
|
||||
|
||||
def set_denoising_steps(self, denoising_steps: int):
|
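||||
# Skip reconfiguration only when a DDIM scheduler is already set up with the same number of steps.
|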
||||
if not (self.denoising_steps == denoising_steps and isinstance(self.scheduler, DDIMScheduler)):
|
||||
self.scheduler.set_timesteps(denoising_steps)
|
||||
self.scheduler.configure()
|
||||
self.denoising_steps = denoising_steps
|
||||
|
||||
def load_resources(self, image_height, image_width, batch_size):
|
||||
# If the engine is built with static input shapes, call this only once after the engine is built.
|
||||
# Otherwise, it needs to be called before every inference run.
|
||||
self.backend.load_resources(image_height, image_width, batch_size)
|
||||
|
||||
def set_random_seed(self, seed):
|
||||
if isinstance(seed, int):
|
||||
self.generator.manual_seed(seed)
|
||||
else:
|
||||
self.generator.seed()
|
||||
|
||||
def get_current_seed(self):
|
||||
return self.generator.initial_seed()
|
||||
|
||||
def teardown(self):
|
||||
for e in self.events.values():
|
||||
cudart.cudaEventDestroy(e)
|
||||
|
||||
if self.backend:
|
||||
self.backend.teardown()
|
||||
|
||||
def run_engine(self, model_name, feed_dict):
|
||||
return self.backend.run_engine(model_name, feed_dict)
|
||||
|
||||
def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width):
|
||||
latents_dtype = torch.float16
|
||||
latents_shape = (batch_size, unet_channels, latent_height, latent_width)
|
||||
latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator)
|
||||
# Scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def initialize_timesteps(self, timesteps, strength):
|
||||
"""Initialize timesteps for refiner."""
|
||||
self.scheduler.set_timesteps(timesteps)
|
||||
offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
|
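||||
# Run only the last `strength` fraction of the schedule; the noised reference image provides the starting latents.
|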
||||
init_timestep = int(timesteps * strength) + offset
|
||||
init_timestep = min(init_timestep, timesteps)
|
||||
t_start = max(timesteps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
|
||||
return timesteps, t_start
|
||||
|
||||
def initialize_refiner(self, batch_size, image, strength):
|
||||
"""Add noise to a reference image."""
|
||||
# Initialize timesteps
|
||||
timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength)
|
||||
|
||||
latent_timestep = timesteps[:1].repeat(batch_size)
|
||||
|
||||
# Pre-process input image
|
||||
image = self.preprocess_images(batch_size, (image,))[0]
|
||||
|
||||
# VAE encode init image
|
||||
if image.shape[1] == 4:
|
||||
init_latents = image
|
||||
else:
|
||||
init_latents = self.encode_image(image)
|
||||
|
||||
# Add noise to latents using timesteps
|
||||
noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float16, generator=self.generator)
|
||||
|
||||
latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep)
|
||||
|
||||
return timesteps, t_start, latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self,
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
dtype,
|
||||
requires_aesthetics_score,
|
||||
):
|
||||
if requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
|
||||
else:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
|
||||
|
||||
return add_time_ids, add_neg_time_ids
|
||||
|
||||
def start_profile(self, name, color="blue"):
|
||||
if self.nvtx_profile:
|
||||
self.markers[name] = nvtx.start_range(message=name, color=color)
|
||||
event_name = name + "-start"
|
||||
if event_name in self.events:
|
||||
cudart.cudaEventRecord(self.events[event_name], 0)
|
||||
|
||||
def stop_profile(self, name):
|
||||
event_name = name + "-stop"
|
||||
if event_name in self.events:
|
||||
cudart.cudaEventRecord(self.events[event_name], 0)
|
||||
if self.nvtx_profile:
|
||||
nvtx.end_range(self.markers[name])
|
||||
|
||||
def preprocess_images(self, batch_size, images=()):
|
||||
self.start_profile("preprocess", color="pink")
|
||||
init_images = []
|
||||
for i in images:
|
||||
image = i.to(self.device)
|
||||
if image.shape[0] != batch_size:
|
||||
image = image.repeat(batch_size, 1, 1, 1)
|
||||
init_images.append(image)
|
||||
self.stop_profile("preprocess")
|
||||
return tuple(init_images)
|
||||
|
||||
def preprocess_controlnet_images(
|
||||
self, batch_size, images=None, do_classifier_free_guidance=True, height=1024, width=1024
|
||||
):
|
||||
"""
|
||||
Process a list of PIL.Image.Image as control images, and return a torch tensor.
|
||||
"""
|
||||
if images is None:
|
||||
return None
|
||||
self.start_profile("preprocess", color="pink")
|
||||
|
||||
if not self.pipeline_info.is_xl():
|
||||
images = [
|
||||
torch.from_numpy(
|
||||
(np.array(image.convert("RGB")).astype(np.float32) / 255.0)[..., None].transpose(3, 2, 0, 1)
|
||||
)
|
||||
.to(device=self.device, dtype=torch.float16)
|
||||
.repeat_interleave(batch_size, dim=0)
|
||||
for image in images
|
||||
]
|
||||
else:
|
||||
images = [
|
||||
self.control_image_processor.preprocess(image, height=height, width=width)
|
||||
.to(device=self.device, dtype=torch.float16)
|
||||
.repeat_interleave(batch_size, dim=0)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
images = [torch.cat([i] * 2) for i in images]
|
||||
images = torch.cat([image[None, ...] for image in images], dim=0)
|
||||
|
||||
self.stop_profile("preprocess")
|
||||
return images
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
encoder="clip",
|
||||
tokenizer=None,
|
||||
pooled_outputs=False,
|
||||
output_hidden_states=False,
|
||||
force_zeros_for_empty_prompt=False,
|
||||
do_classifier_free_guidance=True,
|
||||
dtype=torch.float16,
|
||||
):
|
||||
if tokenizer is None:
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
self.start_profile("clip", color="green")
|
||||
|
||||
def tokenize(prompt, output_hidden_states):
|
||||
text_input_ids = (
|
||||
tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
.input_ids.type(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
|
||||
hidden_states = None
|
||||
if self.engine_type == EngineType.TORCH:
|
||||
outputs = self.backend.engines[encoder](text_input_ids)
|
||||
text_embeddings = outputs[0]
|
||||
if output_hidden_states:
|
||||
hidden_states = outputs["last_hidden_state"]
|
||||
else:
|
||||
outputs = self.run_engine(encoder, {"input_ids": text_input_ids})
|
||||
text_embeddings = outputs["text_embeddings"]
|
||||
if output_hidden_states:
|
||||
hidden_states = outputs["hidden_states"]
|
||||
return text_embeddings, hidden_states
|
||||
|
||||
# Tokenize prompt
|
||||
text_embeddings, hidden_states = tokenize(prompt, output_hidden_states)
|
||||
|
||||
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
|
||||
text_embeddings = text_embeddings.clone()
|
||||
if hidden_states is not None:
|
||||
hidden_states = hidden_states.clone()
|
||||
|
||||
# Note: negative prompt embedding is not needed for SD XL when guidance <= 1
|
||||
if do_classifier_free_guidance:
|
||||
# For SD XL base, handle force_zeros_for_empty_prompt
|
||||
is_empty_negative_prompt = all(not i for i in negative_prompt)
|
||||
if force_zeros_for_empty_prompt and is_empty_negative_prompt:
|
||||
uncond_embeddings = torch.zeros_like(text_embeddings)
|
||||
if output_hidden_states:
|
||||
uncond_hidden_states = torch.zeros_like(hidden_states)
|
||||
else:
|
||||
# Tokenize negative prompt
|
||||
uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states)
|
||||
|
||||
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
if output_hidden_states:
|
||||
hidden_states = torch.cat([uncond_hidden_states, hidden_states])
|
||||
|
||||
self.stop_profile("clip")
|
||||
|
||||
if pooled_outputs:
|
||||
# For text encoder in sdxl base
|
||||
return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)
|
||||
|
||||
if output_hidden_states:
|
||||
# For text encoder 2 in sdxl base or refiner
|
||||
return hidden_states.to(dtype=dtype)
|
||||
|
||||
# For text encoder in sd 1.5
|
||||
return text_embeddings.to(dtype=dtype)
|
||||
|
||||
def denoise_latent(
|
||||
self,
|
||||
latents,
|
||||
text_embeddings,
|
||||
denoiser="unet",
|
||||
timesteps=None,
|
||||
step_offset=0,
|
||||
guidance=7.5,
|
||||
add_kwargs=None,
|
||||
):
|
||||
do_classifier_free_guidance = guidance > 1.0
|
||||
|
||||
self.start_profile("denoise", color="blue")
|
||||
|
||||
if not isinstance(timesteps, torch.Tensor):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
for step_index, timestep in enumerate(timesteps):
|
||||
# Expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, step_offset + step_index, timestep
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
if self.nvtx_profile:
|
||||
nvtx_unet = nvtx.start_range(message="unet", color="blue")
|
||||
|
||||
params = {
|
||||
"sample": latent_model_input,
|
||||
"timestep": timestep.to(latents.dtype),
|
||||
"encoder_hidden_states": text_embeddings,
|
||||
}
|
||||
|
||||
if add_kwargs:
|
||||
params.update(add_kwargs)
|
||||
|
||||
noise_pred = self.run_engine(denoiser, params)["latent"]
|
||||
|
||||
if self.nvtx_profile:
|
||||
nvtx.end_range(nvtx_unet)
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if type(self.scheduler) is UniPCMultistepScheduler:
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
|
||||
elif type(self.scheduler) is LCMScheduler:
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0]
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep)
|
||||
|
||||
# The actual number of steps. It might be different from denoising_steps.
|
||||
self.actual_steps = len(timesteps)
|
||||
|
||||
self.stop_profile("denoise")
|
||||
return latents
|
||||
|
||||
def encode_image(self, image):
|
||||
self.start_profile("vae_encoder", color="red")
|
||||
init_latents = self.run_engine("vae_encoder", {"images": image})["latent"]
|
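||||
# Scale latents by the VAE scaling factor so they match the distribution expected by the UNet.
|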
||||
init_latents = self.vae_scaling_factor * init_latents
|
||||
self.stop_profile("vae_encoder")
|
||||
return init_latents
|
||||
|
||||
def decode_latent(self, latents):
|
||||
self.start_profile("vae", color="red")
|
||||
images = self.backend.vae_decode(latents)
|
||||
self.stop_profile("vae")
|
||||
return images
|
||||
|
||||
def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> dict[str, Any]:
|
||||
throughput = batch_size / (toc - tic)
|
||||
latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1]
|
||||
latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1]
|
||||
latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1]
|
||||
latency_vae_encoder = (
|
||||
cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1]
|
||||
if vae_enc
|
||||
else None
|
||||
)
|
||||
latency_pil = cudart.cudaEventElapsedTime(self.events["pil-start"], self.events["pil-stop"])[1] if pil else None
|
||||
|
||||
latency = (toc - tic) * 1000.0
|
||||
|
||||
print("|----------------|--------------|")
|
||||
print("| {:^14} | {:^12} |".format("Module", "Latency"))
|
||||
print("|----------------|--------------|")
|
||||
if vae_enc:
|
||||
print("| {:^14} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder))
|
||||
print("| {:^14} | {:>9.2f} ms |".format("CLIP", latency_clip))
|
||||
print(
|
||||
"| {:^14} | {:>9.2f} ms |".format(
|
||||
"UNet" + ("+CNet" if self.pipeline_info.controlnet else "") + " x " + str(self.actual_steps),
|
||||
latency_unet,
|
||||
)
|
||||
)
|
||||
print("| {:^14} | {:>9.2f} ms |".format("VAE-Dec", latency_vae))
|
||||
pipeline = "Refiner" if self.pipeline_info.is_xl_refiner() else "Pipeline"
|
||||
if pil:
|
||||
print("| {:^14} | {:>9.2f} ms |".format("PIL", latency_pil))
|
||||
print("|----------------|--------------|")
|
||||
print(f"| {pipeline:^14} | {latency:>9.2f} ms |")
|
||||
print("|----------------|--------------|")
|
||||
print(f"Throughput: {throughput:.2f} image/s")
|
||||
|
||||
perf_data = {
|
||||
"latency_clip": latency_clip,
|
||||
"latency_unet": latency_unet,
|
||||
"latency_vae": latency_vae,
|
||||
"latency_pil": latency_pil,
|
||||
"latency": latency,
|
||||
"throughput": throughput,
|
||||
}
|
||||
if vae_enc:
|
||||
perf_data["latency_vae_encoder"] = latency_vae_encoder
|
||||
return perf_data
|
||||
|
||||
@staticmethod
|
||||
def pt_to_pil(images):
|
||||
images = (
|
||||
((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
|
||||
)
|
||||
return [Image.fromarray(images[i]) for i in range(images.shape[0])]
|
||||
|
||||
@staticmethod
|
||||
def pt_to_numpy(images: torch.FloatTensor):
|
||||
"""
|
||||
Convert a PyTorch tensor to a NumPy image.
|
||||
"""
|
||||
return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy()
|
||||
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
data = {
|
||||
"actual_steps": self.actual_steps,
|
||||
"seed": self.get_current_seed(),
|
||||
"name": self.pipeline_info.name(),
|
||||
"custom_vae": self.pipeline_info.custom_fp16_vae(),
|
||||
"custom_unet": self.pipeline_info.custom_unet(),
|
||||
}
|
||||
|
||||
if self.engine_type == EngineType.ORT_CUDA:
|
||||
for engine_name, engine in self.backend.engines.items():
|
||||
data.update(engine.metadata(engine_name))
|
||||
|
||||
return data
|
||||
|
||||
def save_images(self, images: list, prompt: list[str], negative_prompt: list[str], metadata: dict[str, Any]):
|
||||
session_id = str(random.randint(1000, 9999))
|
||||
for i, image in enumerate(images):
|
||||
seed = str(self.get_current_seed())
|
||||
prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20]
|
||||
parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)]
|
||||
image_path = os.path.join(self.output_dir, "-".join(parts) + ".png")
|
||||
print(f"Saving image {i + 1} / {len(images)} to: {image_path}")
|
||||
|
||||
from PIL import PngImagePlugin # noqa: PLC0415
|
||||
|
||||
info = PngImagePlugin.PngInfo()
|
||||
for k, v in metadata.items():
|
||||
info.add_text(k, str(v))
|
||||
info.add_text("prompt", prompt[i])
|
||||
info.add_text("negative_prompt", negative_prompt[i])
|
||||
|
||||
image.save(image_path, "PNG", pnginfo=info)
|
||||
|
||||
def _infer(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image_height,
|
||||
image_width,
|
||||
denoising_steps=30,
|
||||
guidance=5.0,
|
||||
seed=None,
|
||||
image=None,
|
||||
strength=0.3,
|
||||
controlnet_images=None,
|
||||
controlnet_scales=None,
|
||||
show_latency=False,
|
||||
output_type="pil",
|
||||
):
|
||||
if show_latency:
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
batch_size = len(prompt)
|
||||
|
||||
self.set_denoising_steps(denoising_steps)
|
||||
self.set_random_seed(seed)
|
||||
|
||||
timesteps = None
|
||||
step_offset = 0
|
||||
with torch.inference_mode(), torch.autocast("cuda"):
|
||||
if image is not None:
|
||||
timesteps, step_offset, latents = self.initialize_refiner(
|
||||
batch_size=batch_size,
|
||||
image=image,
|
||||
strength=strength,
|
||||
)
|
||||
else:
|
||||
# Pre-initialize latents
|
||||
latents = self.initialize_latents(
|
||||
batch_size=batch_size,
|
||||
unet_channels=4,
|
||||
latent_height=(image_height // 8),
|
||||
latent_width=(image_width // 8),
|
||||
)
|
||||
|
||||
do_classifier_free_guidance = guidance > 1.0
|
||||
if not self.pipeline_info.is_xl():
|
||||
denoiser = "unet"
|
||||
text_embeddings = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
dtype=latents.dtype,
|
||||
)
|
||||
add_kwargs = {}
|
||||
else:
|
||||
denoiser = "unetxl"
|
||||
|
||||
# Time embeddings
|
||||
original_size = (image_height, image_width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
target_size = (image_height, image_width)
|
||||
aesthetic_score = 6.0
|
||||
negative_aesthetic_score = 2.5
|
||||
add_time_ids, add_negative_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
dtype=latents.dtype,
|
||||
requires_aesthetics_score=self.pipeline_info.is_xl_refiner(),
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
add_time_ids = torch.cat([add_negative_time_ids, add_time_ids], dim=0)
|
||||
add_time_ids = add_time_ids.to(device=self.device).repeat(batch_size, 1)
|
||||
|
||||
if self.pipeline_info.is_xl_refiner():
|
||||
# CLIP text encoder 2
|
||||
text_embeddings, pooled_embeddings2 = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
encoder="clip2",
|
||||
tokenizer=self.tokenizer2,
|
||||
pooled_outputs=True,
|
||||
output_hidden_states=True,
|
||||
dtype=latents.dtype,
|
||||
)
|
||||
add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
|
||||
else: # XL Base
|
||||
# CLIP text encoder
|
||||
text_embeddings = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
encoder="clip",
|
||||
tokenizer=self.tokenizer,
|
||||
output_hidden_states=True,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
dtype=latents.dtype,
|
||||
)
|
||||
# CLIP text encoder 2
|
||||
text_embeddings2, pooled_embeddings2 = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
encoder="clip2",
|
||||
tokenizer=self.tokenizer2,
|
||||
pooled_outputs=True,
|
||||
output_hidden_states=True,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
dtype=latents.dtype,
|
||||
)
|
||||
|
||||
# Merged text embeddings
|
||||
text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1)
|
||||
|
||||
add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
|
||||
|
||||
if self.pipeline_info.controlnet:
|
||||
controlnet_images = self.preprocess_controlnet_images(
|
||||
latents.shape[0],
|
||||
controlnet_images,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
height=image_height,
|
||||
width=image_width,
|
||||
)
|
||||
add_kwargs.update(
|
||||
{
|
||||
"controlnet_images": controlnet_images,
|
||||
"controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device),
|
||||
}
|
||||
)
|
||||
|
||||
# UNet denoiser
|
||||
latents = self.denoise_latent(
|
||||
latents,
|
||||
text_embeddings,
|
||||
timesteps=timesteps,
|
||||
step_offset=step_offset,
|
||||
denoiser=denoiser,
|
||||
guidance=guidance,
|
||||
add_kwargs=add_kwargs,
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
# VAE decode latent
|
||||
if output_type == "latent":
|
||||
images = latents
|
||||
else:
|
||||
images = self.decode_latent(latents / self.vae_scaling_factor)
|
||||
if output_type == "pil":
|
||||
self.start_profile("pil", color="green")
|
||||
images = self.pt_to_pil(images)
|
||||
self.stop_profile("pil")
|
||||
|
||||
perf_data = None
|
||||
if show_latency:
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.perf_counter()
|
||||
perf_data = self.print_summary(
|
||||
start_time, end_time, batch_size, vae_enc=self.pipeline_info.is_xl_refiner(), pil=(output_type == "pil")
|
||||
)
|
||||
|
||||
return images, perf_data
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt: list[str],
|
||||
negative_prompt: list[str],
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
denoising_steps: int = 30,
|
||||
guidance: float = 5.0,
|
||||
seed: int | None = None,
|
||||
image: torch.Tensor | None = None,
|
||||
strength: float = 0.3,
|
||||
controlnet_images: torch.Tensor | None = None,
|
||||
controlnet_scales: torch.Tensor | None = None,
|
||||
show_latency: bool = False,
|
||||
output_type: str = "pil",
|
||||
deterministic: bool = False,
|
||||
):
|
||||
"""
|
||||
Run the diffusion pipeline.
|
||||
|
||||
Args:
|
||||
prompt (List[str]):
|
||||
The text prompt to guide image generation.
|
||||
negative_prompt (List[str]):
|
||||
The prompt not to guide the image generation.
|
||||
image_height (int):
|
||||
Height (in pixels) of the image to be generated. Must be a multiple of 8.
|
||||
image_width (int):
|
||||
Width (in pixels) of the image to be generated. Must be a multiple of 8.
|
||||
denoising_steps (int):
|
||||
Number of denoising steps. More steps usually lead to a higher quality image at the expense of slower inference.
|
||||
guidance (float):
|
||||
A higher guidance scale encourages generating images that are closely linked to the text prompt.
|
||||
seed (int):
|
||||
Seed for the random generator
|
||||
image (tuple[torch.Tensor]):
|
||||
Reference image.
|
||||
strength (float):
|
||||
Indicates the extent to transform the reference image, which is used as a starting point;
|
||||
more noise is added the higher the strength.
|
||||
show_latency (bool):
|
||||
Whether to return latency data.
|
||||
output_type (str):
|
||||
It can be "latent", "pt" or "pil".
|
||||
"""
|
||||
if deterministic:
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
if self.is_backend_tensorrt():
|
||||
import tensorrt as trt # noqa: PLC0415
|
||||
from trt_utilities import TRT_LOGGER # noqa: PLC0415
|
||||
|
||||
with trt.Runtime(TRT_LOGGER):
|
||||
return self._infer(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image_height,
|
||||
image_width,
|
||||
denoising_steps=denoising_steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
strength=strength,
|
||||
controlnet_images=controlnet_images,
|
||||
controlnet_scales=controlnet_scales,
|
||||
show_latency=show_latency,
|
||||
output_type=output_type,
|
||||
)
|
||||
else:
|
||||
return self._infer(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image_height,
|
||||
image_width,
|
||||
denoising_steps=denoising_steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
strength=strength,
|
||||
controlnet_images=controlnet_images,
|
||||
controlnet_scales=controlnet_scales,
|
||||
show_latency=show_latency,
|
||||
output_type=output_type,
|
||||
)
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import tensorrt as trt
|
||||
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
|
||||
|
||||
|
||||
def init_trt_plugins():
|
||||
# Register TensorRT plugins
|
||||
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.append(transformers_dir)
|
||||
+318
@@ -0,0 +1,318 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from benchmark_helper import (
|
||||
Precision,
|
||||
create_onnxruntime_session,
|
||||
prepare_environment,
|
||||
setup_logger,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
from t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS, T5Helper
|
||||
from transformers import MT5Config, T5Config
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
pretrained_models = PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_name_or_path",
|
||||
required=False,
|
||||
default=PRETRAINED_T5_MODELS[0],
|
||||
type=str,
|
||||
help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
required=False,
|
||||
type=str,
|
||||
default="t5",
|
||||
choices=["t5", "mt5"],
|
||||
help="Model type: either t5 (default) or mt5",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "cache_models"),
|
||||
help="Directory to cache pre-trained models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "onnx_models"),
|
||||
help="Output directory",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--optimize_onnx",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use optimizer.py to optimize onnx model",
|
||||
)
|
||||
parser.set_defaults(optimize_onnx=False)
|
||||
|
||||
parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
|
||||
parser.set_defaults(use_gpu=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
required=False,
|
||||
type=str,
|
||||
default=Precision.FLOAT32.value,
|
||||
choices=[Precision.FLOAT32.value, Precision.FLOAT16.value],
|
||||
help="Precision of model to run. fp32 for full precision, fp16 for half precision",
|
||||
)
|
||||
|
||||
parser.add_argument("--verbose", required=False, action="store_true")
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
|
||||
parser.set_defaults(use_external_data_format=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--use_decoder_start_token",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.",
|
||||
)
|
||||
parser.set_defaults(use_decoder_start_token=False)
|
||||
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--overwrite",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="overwrite existing ONNX model",
|
||||
)
|
||||
parser.set_defaults(overwrite=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable_auto_mixed_precision",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="do not use auto mixed precision conversion",
|
||||
)
|
||||
parser.set_defaults(disable_auto_mixed_precision=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--force_fp16_io",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Force to convert all float inputs and outputs to fp16 when precision is fp16.",
|
||||
)
|
||||
parser.set_defaults(force_fp16_io=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_int64_inputs",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
|
||||
)
|
||||
parser.set_defaults(use_int64_inputs=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--state_dict_path",
|
||||
type=str,
|
||||
default="",
|
||||
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder_decoder_init",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Combine encoder and decoder kv cache initialization into one model. It is legacy format that will be deprecated.",
|
||||
)
|
||||
parser.set_defaults(encoder_decoder_init=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def export_onnx_models(
|
||||
model_name_or_path: str,
|
||||
cache_dir: str,
|
||||
output_dir: str,
|
||||
use_gpu: bool = False,
|
||||
use_external_data_format: bool = False,
|
||||
optimize_onnx: bool = False,
|
||||
precision: str = Precision.FLOAT32.value,
|
||||
verbose: bool = False,
|
||||
use_decoder_start_token: bool = False,
|
||||
overwrite: bool = False,
|
||||
disable_auto_mixed_precision: bool = False,
|
||||
use_int32_inputs: bool = True,
|
||||
model_type: str = "t5",
|
||||
state_dict_path: str = "",
|
||||
encoder_decoder_init: bool = False,
|
||||
force_fp16_io: bool = False,
|
||||
shape_infer_before_optimization: bool = False,
|
||||
):
|
||||
assert precision in [Precision.FLOAT32.value, Precision.FLOAT16.value], (
|
||||
f"Invalid precision: {precision}. Use 'fp32' or 'fp16'."
|
||||
)
|
||||
device = torch.device("cuda:0" if use_gpu else "cpu")
|
||||
|
||||
models = T5Helper.load_model(
|
||||
model_name_or_path,
|
||||
cache_dir,
|
||||
device,
|
||||
model_type,
|
||||
state_dict_path,
|
||||
encoder_decoder_init=encoder_decoder_init,
|
||||
)
|
||||
config: T5Config | MT5Config = models["decoder"].config
|
||||
|
||||
if (not use_external_data_format) and (config.num_layers > 24):
|
||||
logger.info("Try use_external_data_format when model size > 2GB")
|
||||
|
||||
output_paths = []
|
||||
for name, model in models.items():
|
||||
model.to(device)
|
||||
filename_suffix = "_" + name
|
||||
|
||||
onnx_path = T5Helper.get_onnx_path(
|
||||
output_dir,
|
||||
model_name_or_path,
|
||||
suffix=filename_suffix,
|
||||
new_folder=False,
|
||||
)
|
||||
|
||||
if overwrite or not os.path.exists(onnx_path):
|
||||
logger.info(f"Exporting ONNX model to {onnx_path}")
|
||||
# We have to clone the model before exporting to ONNX; otherwise verify_onnx will report a large difference.
|
||||
cloned_model = copy.deepcopy(model).to(device)
|
||||
T5Helper.export_onnx(
|
||||
cloned_model,
|
||||
device,
|
||||
onnx_path,
|
||||
verbose,
|
||||
use_external_data_format,
|
||||
use_decoder_input_ids=not use_decoder_start_token,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
|
||||
|
||||
# Optimize ONNX graph.
|
||||
# The precision is compared as a string value because the Precision enum loaded from a local file
|
||||
# (e.g., by the transformers tests in the CI pipeline) is not the same class as the Precision enum from the package.
|
||||
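# Editor's note: enum members from two different Precision classes (e.g., one loaded from a
# local file and one from the installed package) never compare equal even when their values
# match, while the underlying strings ("fp32", "fp16") do, hence the string-value comparison.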
if optimize_onnx or precision != Precision.FLOAT32.value:
|
||||
onnx_shape_path = None
|
||||
if shape_infer_before_optimization:
|
||||
onnx_shape_path = T5Helper.get_onnx_path(
|
||||
output_dir,
|
||||
model_name_or_path,
|
||||
suffix=filename_suffix + "_shape",
|
||||
new_folder=False,
|
||||
)
|
||||
infer_shapes_path(onnx_path, onnx_shape_path)
|
||||
|
||||
output_path = T5Helper.get_onnx_path(
|
||||
output_dir,
|
||||
model_name_or_path,
|
||||
suffix=filename_suffix + "_" + str(precision),
|
||||
new_folder=False,
|
||||
)
|
||||
|
||||
if overwrite or not os.path.exists(output_path):
|
||||
logger.info(f"Optimizing model to {output_path}")
|
||||
T5Helper.optimize_onnx(
|
||||
onnx_shape_path or onnx_path,
|
||||
output_path,
|
||||
precision == Precision.FLOAT16.value,
|
||||
config.num_heads,
|
||||
config.hidden_size,
|
||||
use_external_data_format,
|
||||
auto_mixed_precision=not disable_auto_mixed_precision,
|
||||
use_gpu=use_gpu,
|
||||
force_fp16_io=force_fp16_io,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Skip optimizing: existed ONNX model {output_path}")
|
||||
else:
|
||||
output_path = onnx_path
|
||||
|
||||
ort_session = create_onnxruntime_session(
|
||||
output_path,
|
||||
use_gpu=use_gpu,
|
||||
verbose=verbose,
|
||||
)
|
||||
if ort_session is None:
|
||||
break
|
||||
|
||||
with torch.no_grad():
|
||||
max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
|
||||
|
||||
# This threshold does not apply to fp16 models, which need a larger threshold.
|
||||
if precision == Precision.FLOAT32.value and max_diff > 1e-4:
|
||||
logger.warning("PyTorch and OnnxRuntime results are NOT close")
|
||||
|
||||
output_paths.append(output_path)
|
||||
|
||||
return output_paths
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
|
||||
setup_logger(args.verbose)
|
||||
|
||||
logger.info(f"Arguments:{args}")
|
||||
|
||||
cache_dir = args.cache_dir
|
||||
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
|
||||
prepare_environment(cache_dir, output_dir, args.use_gpu)
|
||||
|
||||
if args.precision != Precision.FLOAT32.value:
|
||||
assert args.optimize_onnx, "fp16 requires --optimize_onnx"
|
||||
|
||||
if args.precision == Precision.FLOAT16.value:
|
||||
assert args.use_gpu, "fp16 requires --use_gpu"
|
||||
|
||||
output_paths = export_onnx_models(
|
||||
args.model_name_or_path,
|
||||
cache_dir,
|
||||
output_dir,
|
||||
args.use_gpu,
|
||||
args.use_external_data_format,
|
||||
args.optimize_onnx,
|
||||
args.precision,
|
||||
args.verbose,
|
||||
args.use_decoder_start_token,
|
||||
args.overwrite,
|
||||
args.disable_auto_mixed_precision,
|
||||
not args.use_int64_inputs,
|
||||
args.model_type,
|
||||
encoder_decoder_init=args.encoder_decoder_init,
|
||||
force_fp16_io=args.force_fp16_io,
|
||||
)
|
||||
|
||||
logger.info(f"Done! Outputs: {output_paths}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+437
@@ -0,0 +1,437 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy
|
||||
import onnx
|
||||
import torch
|
||||
from io_binding_helper import TypeHelper
|
||||
from onnx_model import OnnxModel
|
||||
from past_helper import PastKeyValuesHelper
|
||||
from t5_encoder import T5EncoderInputs
|
||||
from torch_onnx_export_helper import torch_onnx_export
|
||||
from transformers import MT5Config, T5Config
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class T5DecoderInit(torch.nn.Module):
|
||||
"""A T5 decoder with LM head to create initial past key values.
|
||||
This model is only called once, at the start of decoding.
|
||||
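It returns the logits together with the initial self- and cross-attention past key/values
that T5Decoder consumes in subsequent decoding steps.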
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decoder: torch.nn.Module,
|
||||
lm_head: torch.nn.Module,
|
||||
config: T5Config | MT5Config,
|
||||
decoder_start_token_id: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.lm_head = lm_head
|
||||
self.config = config
|
||||
self.decoder_start_token_id = (
|
||||
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
||||
)
|
||||
self.tie_word_embeddings = (
|
||||
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_input_ids: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
):
|
||||
if decoder_input_ids is None:
|
||||
batch_size = encoder_attention_mask.shape[0]
|
||||
decoder_input_ids = (
|
||||
torch.ones(
|
||||
(batch_size, 1),
|
||||
dtype=torch.long,
|
||||
device=encoder_attention_mask.device,
|
||||
)
|
||||
* self.decoder_start_token_id
|
||||
)
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = decoder_outputs.last_hidden_state
|
||||
present_key_values = decoder_outputs.past_key_values
|
||||
|
||||
if self.tie_word_embeddings:
|
||||
sequence_output = sequence_output * (self.config.d_model**-0.5)
|
||||
|
||||
lm_logits = self.lm_head(sequence_output)
|
||||
past_self, past_cross = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
|
||||
return lm_logits, past_self, past_cross
|
||||
|
||||
|
||||
class T5Decoder(torch.nn.Module):
|
||||
"""A T5 decoder with LM head and past key values"""
|
||||
|
||||
def __init__(self, decoder, lm_head, config):
|
||||
super().__init__()
|
||||
self.decoder = decoder
|
||||
self.lm_head = lm_head
|
||||
self.config = config
|
||||
self.tie_word_embeddings = (
|
||||
self.config.tie_word_embeddings if hasattr(self.config, "tie_word_embeddings") else True
|
||||
)
|
||||
|
||||
def forward(self, decoder_input_ids, encoder_attention_mask, *past):
|
||||
num_decoder_layers = self.config.num_decoder_layers
|
||||
past_key_values = PastKeyValuesHelper.group_by_layer(past, num_decoder_layers)
|
||||
|
||||
# This is a hack since only the third dimension of encoder_hidden_states is used here
|
||||
dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
past_key_values=past_key_values,
|
||||
encoder_hidden_states=dummy_encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = decoder_outputs.last_hidden_state
|
||||
present_key_values = decoder_outputs.past_key_values
|
||||
|
||||
if self.tie_word_embeddings:
|
||||
sequence_output = sequence_output * (self.config.d_model**-0.5)
|
||||
|
||||
lm_logits = self.lm_head(sequence_output)
|
||||
present_self, _ = PastKeyValuesHelper.group_by_self_or_cross(present_key_values)
|
||||
|
||||
# Do not return present_cross since they are identical to corresponding past_cross input
|
||||
return lm_logits, present_self
|
||||
|
||||
|
||||
class T5DecoderInputs:
|
||||
def __init__(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
encoder_attention_mask,
|
||||
past_key_values=None,
|
||||
):
|
||||
self.decoder_input_ids: torch.LongTensor = decoder_input_ids
|
||||
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
|
||||
self.past_key_values: list[torch.FloatTensor] | list[torch.HalfTensor] | None = past_key_values
|
||||
|
||||
@staticmethod
|
||||
def create_dummy(
|
||||
config: T5Config | MT5Config,
|
||||
batch_size: int,
|
||||
encode_sequence_length: int,
|
||||
past_decode_sequence_length: int,
|
||||
device: torch.device,
|
||||
float16: bool = False,
|
||||
use_int32_inputs: bool = False,
|
||||
): # -> T5DecoderInputs:
|
||||
"""Create dummy inputs for T5Decoder.
|
||||
|
||||
Args:
|
||||
config (T5Config | MT5Config): model configuration
|
||||
batch_size (int): batch size
|
||||
encode_sequence_length (int): sequence length of input_ids for encoder
|
||||
past_decode_sequence_length (int): past sequence length of input_ids for decoder
|
||||
device (torch.device): device of output tensors
|
||||
float16 (bool): whether the inputs use float16 instead of float32
|
||||
use_int32_inputs (bool): whether to use int32 instead of int64 for some inputs
|
||||
|
||||
Returns:
|
||||
T5DecoderInputs: dummy inputs for decoder
|
||||
"""
|
||||
num_attention_heads: int = config.num_heads
|
||||
num_layers: int = config.num_decoder_layers
|
||||
vocab_size: int = config.vocab_size
|
||||
|
||||
# Do not use head_size = hidden_size / num_attention_heads here.
|
||||
# For example, mt5-small has d_model=512 and num_heads=6, so hidden_size / num_attention_heads does not match d_kv.
|
||||
head_size: int = config.d_kv
|
||||
|
||||
sequence_length: int = 1 # fixed for decoding
|
||||
decoder_input_ids = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size - 1,
|
||||
size=(batch_size, sequence_length),
|
||||
dtype=(torch.int32 if use_int32_inputs else torch.int64),
|
||||
device=device,
|
||||
)
|
||||
|
||||
encoder_inputs = T5EncoderInputs.create_dummy(
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
vocab_size,
|
||||
device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
float_type = torch.float16 if float16 else torch.float32
|
||||
|
||||
if past_decode_sequence_length > 0:
|
||||
self_attention_past_shape = [
|
||||
batch_size,
|
||||
num_attention_heads,
|
||||
past_decode_sequence_length,
|
||||
head_size,
|
||||
]
|
||||
cross_attention_past_shape = [
|
||||
batch_size,
|
||||
num_attention_heads,
|
||||
encode_sequence_length,
|
||||
head_size,
|
||||
]
|
||||
|
||||
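# Past layout: the first 2 * num_layers tensors are self-attention key/value,
# followed by 2 * num_layers cross-attention key/value tensors.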
past = []
|
||||
for _ in range(2 * num_layers):
|
||||
past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))
|
||||
|
||||
for _ in range(2 * num_layers):
|
||||
past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
|
||||
else:
|
||||
past = None
|
||||
|
||||
return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)
|
||||
|
||||
def to_list(self) -> list:
|
||||
input_list = [
|
||||
self.decoder_input_ids,
|
||||
self.encoder_attention_mask,
|
||||
]
|
||||
if self.past_key_values:
|
||||
input_list.extend(self.past_key_values)
|
||||
return input_list
|
||||
|
||||
def to_fp32(self):
|
||||
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
|
||||
return T5DecoderInputs(
|
||||
self.decoder_input_ids.clone(),
|
||||
self.encoder_attention_mask.clone(),
|
||||
past,
|
||||
)
|
||||
|
||||
|
||||
class T5DecoderHelper:
|
||||
@staticmethod
|
||||
def export_onnx(
|
||||
decoder: T5Decoder | T5DecoderInit,
|
||||
device: torch.device,
|
||||
onnx_model_path: str,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_int32_inputs: bool = False,
|
||||
):
|
||||
"""Export decoder to ONNX
|
||||
|
||||
Args:
|
||||
decoder (Union[T5Decoder, T5DecoderInit]): decoder object
|
||||
device (torch.device): device of decoder object
|
||||
onnx_model_path (str): onnx path
|
||||
verbose (bool, optional): print verbose information. Defaults to True.
|
||||
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
|
||||
use_int32_inputs (bool, optional): use int32 instead of int64 for integer inputs. Defaults to False.
|
||||
"""
|
||||
assert isinstance(decoder, (T5Decoder, T5DecoderInit))
|
||||
|
||||
inputs = T5DecoderInputs.create_dummy(
|
||||
decoder.config,
|
||||
batch_size=2,
|
||||
encode_sequence_length=3,
|
||||
past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
input_list = inputs.to_list()
|
||||
|
||||
num_decoder_layers = decoder.config.num_decoder_layers
|
||||
|
||||
past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
|
||||
present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
|
||||
present_self_names = present_names[: 2 * num_decoder_layers]
|
||||
|
||||
input_past_names = past_names if isinstance(decoder, T5Decoder) else []
|
||||
output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
|
||||
output_names = ["logits", *output_present_names]
|
||||
|
||||
# Shape of input tensors (sequence_length==1):
|
||||
# input_ids: (batch_size, sequence_length)
|
||||
# encoder_attention_mask: (batch_size, encode_sequence_length)
|
||||
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
|
||||
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
|
||||
|
||||
# Shape of output tensors:
|
||||
# logits: (batch_size, sequence_length, vocab_size)
|
||||
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
|
||||
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
|
||||
|
||||
input_names = ["input_ids"]
|
||||
input_names.append("encoder_attention_mask")
|
||||
input_names.extend(input_past_names)
|
||||
|
||||
dynamic_axes = {
|
||||
"input_ids": {
|
||||
0: "batch_size",
|
||||
# 1: 'sequence_length'
|
||||
},
|
||||
"encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
|
||||
"encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
|
||||
"logits": {
|
||||
0: "batch_size",
|
||||
# 1: 'sequence_length'
|
||||
},
|
||||
}
|
||||
|
||||
for name in input_past_names:
|
||||
dynamic_axes[name] = {
|
||||
0: "batch_size",
|
||||
2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
|
||||
}
|
||||
|
||||
for name in output_present_names:
|
||||
if "cross" in name:
|
||||
dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
|
||||
else: # self attention past state
|
||||
if isinstance(decoder, T5Decoder):
|
||||
dynamic_axes[name] = {
|
||||
0: "batch_size",
|
||||
2: "past_decode_sequence_length + 1",
|
||||
}
|
||||
else:
|
||||
dynamic_axes[name] = {
|
||||
0: "batch_size",
|
||||
# 2: 'sequence_length'
|
||||
}
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
|
||||
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch_onnx_export(
|
||||
decoder,
|
||||
args=tuple(input_list),
|
||||
f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=use_external_data_format,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
if use_external_data_format:
|
||||
model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
|
||||
OnnxModel.save(
|
||||
model,
|
||||
onnx_model_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
|
||||
"""Run inference of ONNX model."""
|
||||
logger.debug("start onnxruntime_inference")
|
||||
|
||||
ort_inputs = {
|
||||
"input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
|
||||
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
|
||||
}
|
||||
|
||||
if inputs.past_key_values:
|
||||
assert len(inputs.past_key_values) % 4 == 0
|
||||
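# Each decoder layer contributes 4 past tensors: self-attention key/value and cross-attention key/value.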
num_layers = int(len(inputs.past_key_values) / 4)
|
||||
past_names = PastKeyValuesHelper.get_past_names(num_layers)
|
||||
for i, past_tensor in enumerate(inputs.past_key_values):
|
||||
ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())
|
||||
|
||||
ort_outputs = ort_session.run(None, ort_inputs)
|
||||
return ort_outputs
|
||||
|
||||
@staticmethod
|
||||
def verify_onnx(
|
||||
model: T5Decoder | T5DecoderInit,
|
||||
ort_session: InferenceSession,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool,
|
||||
max_cases: int = 4,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"
|
||||
|
||||
test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
|
||||
test_cases_max_diff = []
|
||||
for (
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
past_decode_sequence_length,
|
||||
) in test_cases[:max_cases]:
|
||||
if isinstance(model, T5DecoderInit):
|
||||
past_decode_sequence_length = 0 # noqa: PLW2901
|
||||
|
||||
inputs = T5DecoderInputs.create_dummy(
|
||||
model.config,
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
past_decode_sequence_length,
|
||||
device=device,
|
||||
float16=float16,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
# We use the fp32 PyTorch model as the baseline even when the ONNX model is fp16
|
||||
input_list = inputs.to_fp32().to_list()
|
||||
|
||||
# Run inference of PyTorch model
|
||||
with torch.no_grad():
|
||||
torch_outputs = model(*input_list)
|
||||
|
||||
ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)
|
||||
num_decoder_layers = model.config.num_decoder_layers
|
||||
|
||||
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
|
||||
max_diff_all = max_diff
|
||||
logger.debug(f"logits max_diff={max_diff}")
|
||||
|
||||
for i in range(2 * num_decoder_layers):
|
||||
max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
|
||||
logger.debug(f"self attention past state {i} max_diff={max_diff}")
|
||||
max_diff_all = max(max_diff_all, max_diff)
|
||||
|
||||
if isinstance(model, T5DecoderInit):
|
||||
for i in range(2 * num_decoder_layers):
|
||||
max_diff = numpy.amax(
|
||||
numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * num_decoder_layers + i])
|
||||
)
|
||||
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
|
||||
max_diff_all = max(max_diff_all, max_diff)
|
||||
|
||||
test_cases_max_diff.append(max_diff_all)
|
||||
logger.info(
|
||||
"batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
past_decode_sequence_length,
|
||||
max_diff_all,
|
||||
)
|
||||
|
||||
return max_diff_all
|
||||
+70
@@ -0,0 +1,70 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
import torch
|
||||
from transformers import MT5Config, T5Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class T5Encoder(torch.nn.Module):
|
||||
"""T5 encoder outputs only the last hidden state"""
|
||||
|
||||
def __init__(self, encoder, config: T5Config | MT5Config):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.config = config
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
return self.encoder(input_ids, attention_mask)[0]
|
||||
|
||||
|
||||
class T5EncoderInputs:
|
||||
def __init__(self, input_ids, attention_mask):
|
||||
self.input_ids: torch.LongTensor = input_ids
|
||||
self.attention_mask: torch.LongTensor = attention_mask
|
||||
|
||||
@staticmethod
|
||||
def create_dummy(
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool = False,
|
||||
): # -> T5EncoderInputs
|
||||
"""Create dummy inputs for T5 encoder.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size
|
||||
sequence_length (int): sequence length
|
||||
vocab_size (int): vocabulary size
|
||||
device (torch.device): device of output tensors
|
||||
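use_int32_inputs (bool, optional): use int32 instead of int64 for inputs. Defaults to False.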
|
||||
Returns:
|
||||
T5EncoderInputs: dummy inputs for encoder
|
||||
"""
|
||||
dtype = torch.int32 if use_int32_inputs else torch.int64
|
||||
|
||||
input_ids = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size - 1,
|
||||
size=(batch_size, sequence_length),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
|
||||
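# Simulate padding: zero out a random-length prefix of the attention mask for each example.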
if sequence_length >= 2:
|
||||
for i in range(batch_size):
|
||||
padding_position = random.randint(0, sequence_length - 1)
|
||||
attention_mask[i, :padding_position] = 0
|
||||
return T5EncoderInputs(input_ids, attention_mask)
|
||||
|
||||
def to_list(self) -> list:
|
||||
input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None]
|
||||
return input_list
|
||||
+361
@@ -0,0 +1,361 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy
|
||||
import onnx
|
||||
import torch
|
||||
from onnx_model import OnnxModel
|
||||
from past_helper import PastKeyValuesHelper
|
||||
from t5_decoder import T5DecoderInit
|
||||
from t5_encoder import T5Encoder, T5EncoderInputs
|
||||
from torch_onnx_export_helper import torch_onnx_export
|
||||
from transformers import MT5Config, T5Config
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class T5EncoderDecoderInit(torch.nn.Module):
|
||||
"""A combination of T5Encoder and T5DecoderInit."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: torch.nn.Module,
|
||||
decoder: torch.nn.Module,
|
||||
lm_head: torch.nn.Linear,
|
||||
config: T5Config | MT5Config,
|
||||
decoder_start_token_id: int | None = None,
|
||||
output_cross_only: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.config: T5Config | MT5Config = config
|
||||
self.t5_encoder = T5Encoder(encoder, config)
|
||||
self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
|
||||
self.output_cross_only = output_cross_only
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_input_ids: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
decoder_input_ids: torch.Tensor | None = None,
|
||||
):
|
||||
encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)
|
||||
|
||||
lm_logits, past_self, past_cross = self.t5_decoder_init(
|
||||
decoder_input_ids, encoder_attention_mask, encoder_hidden_states
|
||||
)
|
||||
|
||||
if self.output_cross_only:
|
||||
return past_cross
|
||||
else:
|
||||
return lm_logits, encoder_hidden_states, past_self, past_cross
|
||||
|
||||
|
||||
class T5EncoderDecoderInitInputs:
|
||||
def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
|
||||
self.encoder_input_ids: torch.LongTensor = encoder_input_ids
|
||||
self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
|
||||
self.decoder_input_ids: torch.LongTensor | None = decoder_input_ids
|
||||
|
||||
@staticmethod
|
||||
def create_dummy(
|
||||
config: T5Config | MT5Config,
|
||||
batch_size: int,
|
||||
encode_sequence_length: int,
|
||||
use_decoder_input_ids: bool,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool = False,
|
||||
): # -> T5EncoderDecoderInitInputs:
|
||||
encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
config.vocab_size,
|
||||
device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
decoder_input_ids = None
|
||||
if use_decoder_input_ids:
|
||||
dtype = torch.int32 if use_int32_inputs else torch.int64
|
||||
decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
|
||||
|
||||
return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
|
||||
|
||||
def to_list(self) -> list:
|
||||
input_list = [self.encoder_input_ids, self.encoder_attention_mask]
|
||||
if self.decoder_input_ids is not None:
|
||||
input_list.append(self.decoder_input_ids)
|
||||
return input_list
|
||||
|
||||
|
||||
class T5EncoderDecoderInitHelper:
|
||||
@staticmethod
|
||||
def export_onnx(
|
||||
model: T5EncoderDecoderInit,
|
||||
device: torch.device,
|
||||
onnx_model_path: str,
|
||||
use_decoder_input_ids: bool = True,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_int32_inputs: bool = False,
|
||||
):
|
||||
"""Export decoder to ONNX
|
||||
|
||||
Args:
|
||||
model (T5EncoderDecoderInit): the model to export
|
||||
device (torch.device): device of decoder object
|
||||
onnx_model_path (str): onnx path
|
||||
verbose (bool, optional): print verbose information. Defaults to True.
|
||||
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
|
||||
use_int32_inputs (bool, optional): use int32 instead of int64 for integer inputs. Defaults to False.
|
||||
"""
|
||||
assert isinstance(model, T5EncoderDecoderInit)
|
||||
|
||||
# Do not exclude the decoder in the torch ONNX export so that the cross-attention outputs show up.
|
||||
output_cross_only = model.output_cross_only
|
||||
model.output_cross_only = False
|
||||
|
||||
inputs = T5EncoderDecoderInitInputs.create_dummy(
|
||||
model.config,
|
||||
batch_size=2,
|
||||
encode_sequence_length=3,
|
||||
use_decoder_input_ids=use_decoder_input_ids,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
input_list = inputs.to_list()
|
||||
|
||||
present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
|
||||
|
||||
output_names = ["logits", "encoder_hidden_states", *present_names]
|
||||
|
||||
# Shape of input tensors (sequence_length==1):
|
||||
# input_ids: (batch_size, sequence_length)
|
||||
# encoder_attention_mask: (batch_size, encode_sequence_length)
|
||||
# encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
|
||||
# past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
|
||||
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
|
||||
|
||||
# Shape of output tensors:
|
||||
# logits: (batch_size, sequence_length, vocab_size)
|
||||
# past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
|
||||
# past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
|
||||
|
||||
input_names = ["encoder_input_ids", "encoder_attention_mask"]
|
||||
|
||||
# ONNX exporter might mark dimension like 'present_value_self_1_dim_2' in shape inference.
|
||||
# We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
|
||||
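# The numeric strings below serve as dim_param placeholders in dynamic_axes and are
# replaced with fixed dim_value entries in the post-processing loop after export.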
sequence_length = "1"
|
||||
num_heads = str(model.config.num_heads)
|
||||
hidden_size = str(model.config.d_model)
|
||||
head_size = str(model.config.d_kv)
|
||||
|
||||
dynamic_axes = {
|
||||
"encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
|
||||
"encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
|
||||
"encoder_hidden_states": {
|
||||
0: "batch_size",
|
||||
1: "encode_sequence_length",
|
||||
2: hidden_size,
|
||||
},
|
||||
"logits": {
|
||||
0: "batch_size",
|
||||
1: sequence_length,
|
||||
},
|
||||
}
|
||||
|
||||
if use_decoder_input_ids:
|
||||
input_names.append("decoder_input_ids")
|
||||
dynamic_axes["decoder_input_ids"] = {
|
||||
0: "batch_size",
|
||||
1: sequence_length,
|
||||
}
|
||||
|
||||
for name in present_names:
|
||||
if "cross" in name:
|
||||
dynamic_axes[name] = {
|
||||
0: "batch_size",
|
||||
1: num_heads,
|
||||
2: "encode_sequence_length",
|
||||
3: head_size,
|
||||
}
|
||||
|
||||
else: # self attention past state
|
||||
dynamic_axes[name] = {
|
||||
0: "batch_size",
|
||||
1: num_heads,
|
||||
2: sequence_length,
|
||||
3: head_size,
|
||||
}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
|
||||
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch_onnx_export(
|
||||
model,
|
||||
args=tuple(input_list),
|
||||
f=temp_onnx_model_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=use_external_data_format,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
# Restore output_cross_only setting.
|
||||
model.output_cross_only = output_cross_only
|
||||
|
||||
# Workaround as mentioned earlier: change numeric dim_param to dim_value
|
||||
exported_model: onnx.ModelProto = onnx.load(temp_onnx_model_path)
|
||||
for tensor in exported_model.graph.output:
|
||||
for dim_proto in tensor.type.tensor_type.shape.dim:
|
||||
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
|
||||
sequence_length,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
head_size,
|
||||
]:
|
||||
dim_value = int(dim_proto.dim_param)
|
||||
dim_proto.Clear()
|
||||
dim_proto.dim_value = dim_value
|
||||
|
||||
if output_cross_only:
|
||||
# Rewrite onnx graph to only keep present_[key|value]_cross_* outputs.
|
||||
onnx_model = OnnxModel(exported_model)
|
||||
output_name_to_node = onnx_model.output_name_to_node()
|
||||
|
||||
for output in exported_model.graph.output:
|
||||
if "cross" in output.name:
|
||||
assert output.name in output_name_to_node
|
||||
|
||||
transpose_node = output_name_to_node[output.name]
|
||||
assert transpose_node and transpose_node.op_type == "Transpose"
|
||||
|
||||
permutation = OnnxModel.get_node_attribute(transpose_node, "perm")
|
||||
assert isinstance(permutation, list)
|
||||
assert permutation == [0, 2, 1, 3]
|
||||
|
||||
matched_nodes = onnx_model.match_parent_path(
|
||||
transpose_node,
|
||||
["Reshape", "MatMul"],
|
||||
[0, 0],
|
||||
output_name_to_node,
|
||||
)
|
||||
assert matched_nodes is not None
|
||||
|
||||
reshape_node, matmul_node = matched_nodes
|
||||
assert "encoder_hidden_states" in matmul_node.input
|
||||
|
||||
if not onnx_model.get_initializer("cross_reshape_shape"):
|
||||
shape_tensor = onnx.helper.make_tensor(
|
||||
name="cross_reshape_shape",
|
||||
data_type=onnx.TensorProto.INT64,
|
||||
dims=[4],
|
||||
vals=[0, 0, int(num_heads), int(head_size)],
|
||||
raw=False,
|
||||
)
|
||||
onnx_model.add_initializer(shape_tensor)
|
||||
|
||||
reshape_node.input[1] = "cross_reshape_shape"
|
||||
|
||||
cross_outputs = [output.name for output in exported_model.graph.output if "cross" in output.name]
|
||||
onnx_model.prune_graph(cross_outputs, allow_remove_graph_inputs=True)
|
||||
|
||||
OnnxModel.save(
|
||||
exported_model,
|
||||
onnx_model_path,
|
||||
save_as_external_data=use_external_data_format,
|
||||
all_tensors_to_one_file=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
|
||||
"""Run inference of ONNX model."""
|
||||
logger.debug("start onnxruntime_inference")
|
||||
|
||||
ort_inputs = {
|
||||
"encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
|
||||
"encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
|
||||
}
|
||||
if inputs.decoder_input_ids is not None:
|
||||
ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
|
||||
|
||||
ort_outputs = ort_session.run(None, ort_inputs)
|
||||
return ort_outputs
|
||||
|
||||
@staticmethod
|
||||
def verify_onnx(
|
||||
model: T5EncoderDecoderInit,
|
||||
ort_session: InferenceSession,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool,
|
||||
max_cases: int = 4,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
ort_inputs = ort_session.get_inputs()
|
||||
use_decoder_input_ids = len(ort_inputs) == 3
|
||||
|
||||
test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
|
||||
test_cases_max_diff = []
|
||||
for batch_size, encode_sequence_length in test_cases[:max_cases]:
|
||||
inputs = T5EncoderDecoderInitInputs.create_dummy(
|
||||
model.config,
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
use_decoder_input_ids=use_decoder_input_ids,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
|
||||
|
||||
# Run inference of PyTorch model
|
||||
input_list = inputs.to_list()
|
||||
torch_outputs = model(*input_list)
|
||||
|
||||
num_decoder_layers = model.config.num_decoder_layers
|
||||
|
||||
if not model.output_cross_only:
|
||||
assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
|
||||
max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
|
||||
logger.debug(f"logits max_diff={max_diff}")
|
||||
max_diff_all = max_diff
|
||||
|
||||
assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
|
||||
max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
|
||||
logger.debug(f"encoder_hidden_states max_diff={max_diff}")
|
||||
max_diff_all = max(max_diff_all, max_diff)
|
||||
|
||||
for i in range(2 * num_decoder_layers):
|
||||
max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
|
||||
logger.debug(f"self attention past state {i} max_diff={max_diff}")
|
||||
|
||||
for i in range(2 * num_decoder_layers):
|
||||
max_diff = numpy.amax(
|
||||
numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
|
||||
)
|
||||
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
|
||||
max_diff_all = max(max_diff_all, max_diff)
|
||||
else:
|
||||
max_diff_all = -float("inf")
|
||||
for i in range(2 * num_decoder_layers):
|
||||
max_diff = numpy.amax(numpy.abs(torch_outputs[i].cpu().numpy() - ort_outputs[i]))
|
||||
logger.debug(f"cross attention past state {i} max_diff={max_diff}")
|
||||
max_diff_all = max(max_diff_all, max_diff)
|
||||
|
||||
test_cases_max_diff.append(max_diff_all)
|
||||
logger.info(
|
||||
f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
|
||||
)
|
||||
|
||||
return max(test_cases_max_diff)
|
||||
+302
@@ -0,0 +1,302 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from float16 import float_to_float16_max_diff
|
||||
from onnx_model import OnnxModel
|
||||
from optimizer import optimize_model
|
||||
from t5_decoder import T5Decoder, T5DecoderHelper
|
||||
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
|
||||
from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
|
||||
PRETRAINED_MT5_MODELS = [
|
||||
"google/mt5-small",
|
||||
"google/mt5-base",
|
||||
"google/mt5-large",
|
||||
"google/mt5-xl",
|
||||
"google/mt5-xxl",
|
||||
]
|
||||
|
||||
|
||||
class T5Helper:
|
||||
@staticmethod
|
||||
def get_onnx_path(
|
||||
output_dir: str,
|
||||
model_name_or_path: str,
|
||||
suffix: str = "",
|
||||
new_folder: bool = False,
|
||||
) -> str:
|
||||
"""Build onnx path
|
||||
|
||||
Args:
|
||||
output_dir (str): output directory
|
||||
model_name_or_path (str): pretrained model name, or path to the model checkpoint
|
||||
suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" appended to the file name. Defaults to empty string.
|
||||
new_folder (bool, optional): create a new directory for the model. Defaults to False.
|
||||
|
||||
Returns:
|
||||
str: path of onnx model
|
||||
"""
|
||||
model_name = model_name_or_path
|
||||
if os.path.isdir(model_name_or_path):
|
||||
model_name = Path(model_name_or_path).parts[-1]
|
||||
else:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
model_name += suffix
|
||||
|
||||
directory = os.path.join(output_dir, model_name) if new_folder else output_dir
|
||||
return os.path.join(directory, model_name + ".onnx")
|
||||
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_name_or_path: str,
|
||||
cache_dir: str,
|
||||
device: torch.device,
|
||||
model_type: str = "t5",
|
||||
state_dict_path: str = "",
|
||||
encoder_decoder_init: bool = False,
|
||||
) -> dict[str, T5EncoderDecoderInit | T5Decoder]:
|
||||
"""Load model given a pretrained name or path, then build models for ONNX conversion.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): pretrained model name or path
|
||||
cache_dir (str): cache directory
|
||||
device (torch.device): device to run the model
|
||||
model_type (str, optional): model type "t5" or "mt5"
|
||||
state_dict_path(str, optional): state dictionary path
|
||||
encoder_decoder_init (bool, optional): combine encoder and decoder kv cache initialization into one model.
|
||||
Returns:
|
||||
Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
|
||||
"""
|
||||
if model_type == "t5":
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
||||
elif model_type == "mt5":
|
||||
model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
|
||||
else:
|
||||
raise ValueError("only support mode_type=t5 or mt5")
|
||||
|
||||
if state_dict_path:
|
||||
model.load_state_dict(torch.load(state_dict_path))
|
||||
|
||||
decoder = T5Decoder(model.decoder, model.lm_head, model.config)
|
||||
decoder.eval().to(device)
|
||||
|
||||
encoder = T5EncoderDecoderInit(
|
||||
model.encoder,
|
||||
model.decoder,
|
||||
model.lm_head,
|
||||
model.config,
|
||||
decoder_start_token_id=None,
|
||||
output_cross_only=not encoder_decoder_init,
|
||||
)
|
||||
|
||||
encoder_name = "encoder_decoder_init" if encoder_decoder_init else "encoder"
|
||||
return {encoder_name: encoder, "decoder": decoder}
|
||||
|
||||
@staticmethod
|
||||
def export_onnx(
|
||||
model: T5Decoder | T5EncoderDecoderInit,
|
||||
device: torch.device,
|
||||
onnx_model_path: str,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_decoder_input_ids: bool = True,
|
||||
use_int32_inputs: bool = False,
|
||||
):
|
||||
if isinstance(model, T5EncoderDecoderInit):
|
||||
T5EncoderDecoderInitHelper.export_onnx(
|
||||
model,
|
||||
device,
|
||||
onnx_model_path,
|
||||
use_decoder_input_ids,
|
||||
verbose,
|
||||
use_external_data_format,
|
||||
use_int32_inputs,
|
||||
)
|
||||
else:
|
||||
T5DecoderHelper.export_onnx(
|
||||
model,
|
||||
device,
|
||||
onnx_model_path,
|
||||
verbose,
|
||||
use_external_data_format,
|
||||
use_int32_inputs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_mixed_precision(
|
||||
onnx_model: OnnxModel,
|
||||
op_block_list: list[str] | None = None,
|
||||
force_fp16_logits: bool = False,
|
||||
use_symbolic_shape_infer: bool = True,
|
||||
):
|
||||
"""Convert model to mixed precision.
|
||||
It detects whether the original model has fp16 weights, and sets parameters for float16 conversion automatically.
|
||||
Args:
|
||||
onnx_model (OnnxModel): optimized ONNX model
|
||||
op_block_list (List[str], optional): operators that need to run in fp32.
|
||||
force_fp16_logits (bool, optional): force logits and last MatMul node to be in float16. Defaults to False.
|
||||
use_symbolic_shape_infer (bool, optional): use symbolic shape inference to convert float to float16. Defaults to True.
|
||||
Returns:
|
||||
parameters(dict): a dictionary of parameters used in float16 conversion
|
||||
"""
|
||||
if op_block_list is None:
|
||||
op_block_list = [
|
||||
"SimplifiedLayerNormalization",
|
||||
"SkipSimplifiedLayerNormalization",
|
||||
"Relu",
|
||||
"Add",
|
||||
]
|
||||
|
||||
op_full_set = {node.op_type for node in onnx_model.nodes()}
|
||||
fp32_op_set = set(op_block_list)
|
||||
fp16_op_set = op_full_set.difference(fp32_op_set)
|
||||
logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
|
||||
|
||||
# logits is the first output
|
||||
logits_output_name = onnx_model.graph().output[0].name
|
||||
|
||||
# We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
|
||||
is_weight_fp16_precision = False
|
||||
output_name_to_node = onnx_model.output_name_to_node()
|
||||
assert logits_output_name in output_name_to_node
|
||||
node = output_name_to_node[logits_output_name]
|
||||
last_matmul_node = None
|
||||
if node.op_type == "MatMul":
|
||||
last_matmul_node = node
|
||||
logger.info(f"Found last MatMul node for logits: {node.name}")
|
||||
initializer = None
|
||||
for input in node.input:
|
||||
initializer = onnx_model.get_initializer(input)
|
||||
if initializer is not None:
|
||||
break
|
||||
|
||||
# When the maximum difference after converting float to float16 is lower than a threshold (1e-6),
|
||||
# we can deduce that the weights are stored in float16 precision.
|
||||
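# Editor's assumption: float_to_float16_max_diff returns the maximum absolute difference
# between the initializer and its float16 round-trip; weights that originated as fp32 would
# typically show a difference well above the 1e-6 threshold used below.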
max_diff = float_to_float16_max_diff(initializer)
|
||||
logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
|
||||
is_weight_fp16_precision = max_diff < 1e-6
|
||||
else:
|
||||
logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
|
||||
|
||||
keep_io_types = []
|
||||
node_block_list = []
|
||||
if (not is_weight_fp16_precision) and (last_matmul_node is not None) and not force_fp16_logits:
|
||||
# When the original weights are float32, keeping the logits output and the last MatMul in float32 gives better precision.
|
||||
keep_io_types = [logits_output_name]
|
||||
node_block_list = [last_matmul_node.name]
|
||||
|
||||
if "Add" not in op_block_list:
|
||||
input_name_to_nodes = onnx_model.input_name_to_nodes()
|
||||
fp32_add = 0
|
||||
changed = True
|
||||
add_nodes = onnx_model.get_nodes_by_op_type("Add")
|
||||
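# Fixed-point sweep: moving one Add node to the fp32 block list can make its neighbors
# eligible as well, so keep iterating until no node changes.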
while changed:
|
||||
changed = False
|
||||
for node in add_nodes:
|
||||
if node.name not in node_block_list:
|
||||
parents = onnx_model.get_parents(node, output_name_to_node)
|
||||
children = onnx_model.get_children(node, input_name_to_nodes)
|
||||
blocked_children = [
|
||||
child for child in children if child.op_type in op_block_list or child in node_block_list
|
||||
]
|
||||
blocked_parents = [
|
||||
parent for parent in parents if parent.op_type in op_block_list or parent in node_block_list
|
||||
]
|
||||
# If any child or parent is in fp32, we place the Add node in fp32 as well.
|
||||
if (len(blocked_children) + len(blocked_parents)) > 0:
|
||||
node_block_list.append(node.name)
|
||||
fp32_add += 1
|
||||
changed = True
|
||||
fp16_add = len(add_nodes) - fp32_add
|
||||
logger.info(f"node counter of Add operator: fp32={fp32_add} fp16={fp16_add}")
|
||||
|
||||
logger.info(f"node_block_list: {node_block_list}")
|
||||
|
||||
parameters = {
|
||||
"keep_io_types": keep_io_types,
|
||||
"op_block_list": op_block_list,
|
||||
"node_block_list": node_block_list,
|
||||
"force_fp16_initializers": is_weight_fp16_precision,
|
||||
}
|
||||
|
||||
logger.info(f"auto_mixed_precision parameters: {parameters}")
|
||||
if use_symbolic_shape_infer:
|
||||
onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
|
||||
else:
|
||||
# Workaround when symbolic shape inference fails.
|
||||
# Need to enable shape_infer_before_optimization in convert_to_onnx.py as well.
|
||||
from float16 import convert_float_to_float16 # noqa: PLC0415
|
||||
|
||||
convert_float_to_float16(
|
||||
onnx_model.model,
|
||||
disable_shape_infer=True,
|
||||
**parameters,
|
||||
)
|
||||
|
||||
return parameters
|
||||
|
||||
@staticmethod
|
||||
def optimize_onnx(
|
||||
onnx_model_path: str,
|
||||
optimized_model_path: str,
|
||||
is_float16: bool,
|
||||
num_attention_heads: int,
|
||||
hidden_size: int,
|
||||
use_external_data_format: bool = False,
|
||||
auto_mixed_precision: bool = True,
|
||||
use_gpu: bool = False,
|
||||
force_fp16_io: bool = False,
|
||||
):
|
||||
"""Optimize ONNX model with an option to convert it to use mixed precision."""
|
||||
|
||||
from fusion_options import FusionOptions # noqa: PLC0415
|
||||
|
||||
optimization_options = None
|
||||
if is_float16:
|
||||
optimization_options = FusionOptions("t5")
|
||||
# SkipLayerNormalization is faster but might cause an accuracy drop since it uses fp16 accumulation.
|
||||
optimization_options.enable_skip_layer_norm = not auto_mixed_precision
|
||||
|
||||
m = optimize_model(
|
||||
onnx_model_path,
|
||||
model_type="t5",
|
||||
num_heads=num_attention_heads,
|
||||
hidden_size=hidden_size,
|
||||
opt_level=0,
|
||||
optimization_options=optimization_options,
|
||||
use_gpu=use_gpu,
|
||||
)
|
||||
|
||||
if is_float16:
|
||||
if auto_mixed_precision:
|
||||
T5Helper.auto_mixed_precision(m, force_fp16_logits=force_fp16_io)
|
||||
else:
|
||||
m.convert_model_float32_to_float16(cast_input_output=force_fp16_io)
|
||||
|
||||
m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
|
||||
|
||||
@staticmethod
|
||||
def verify_onnx(
|
||||
model: T5Decoder | T5EncoderDecoderInit,
|
||||
ort_session: InferenceSession,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
if isinstance(model, T5EncoderDecoderInit):
|
||||
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
|
||||
return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
|
||||
transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
if transformers_dir not in sys.path:
|
||||
sys.path.append(transformers_dir)
|
||||
+610
@@ -0,0 +1,610 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import datetime
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
import whisper
|
||||
from benchmark_helper import measure_memory, setup_logger
|
||||
from onnxruntime_extensions import get_library_path
|
||||
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from tqdm import trange
|
||||
from transformers import AutoModelForSpeechSeq2Seq, WhisperConfig, WhisperProcessor
|
||||
|
||||
import onnxruntime as ort
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_inputs(args: argparse.Namespace):
|
||||
if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}:
|
||||
raise Exception("Unable to auto-detect inputs for provided model")
|
||||
|
||||
def load_via_ffmpeg():
|
||||
audio = whisper.load_audio(args.audio_path)
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
return audio
|
||||
|
||||
def load_via_numpy():
|
||||
with open(args.audio_path, "rb") as f:
|
||||
audio = np.asarray(list(f.read()), dtype=np.uint8)
|
||||
audio = np.array([audio])
|
||||
return audio
|
||||
|
||||
inputs = {
|
||||
"max_length": args.max_length,
|
||||
"min_length": args.min_length,
|
||||
"num_beams": args.num_beams,
|
||||
"num_return_sequences": args.num_return_sequences,
|
||||
"length_penalty": args.length_penalty,
|
||||
"repetition_penalty": args.repetition_penalty,
|
||||
}
|
||||
if args.benchmark_type == "ort":
|
||||
# convert_to_onnx export or ONNX E2E solution created by Olive
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = np.array([v], dtype=np.float32 if "penalty" in k else np.int32)
|
||||
if args.has_decoder_input_ids:
|
||||
inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32)
|
||||
if args.has_logits_processor:
|
||||
inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32)
|
||||
if args.has_temperature:
|
||||
inputs["temperature"] = np.array([args.temperature], dtype=np.float32)
|
||||
|
||||
# Measure time taken to load audio file
|
||||
logger.info(f"Load audio: {args.audio_path}")
|
||||
load_audio_fn = lambda onnx_e2e: load_via_numpy() if onnx_e2e else load_via_ffmpeg() # noqa: E731
|
||||
time_fn(args, load_audio_fn, args.has_audio_stream)
|
||||
audio_data = load_audio_fn(args.has_audio_stream)
|
||||
|
||||
if args.has_audio_stream:
|
||||
# ONNX E2E solution created by Olive
|
||||
inputs["audio_stream"] = audio_data
|
||||
return inputs
|
||||
|
||||
# Measure time taken to get input features
|
||||
logger.info("Feature extraction: ")
|
||||
return_type = "np" if args.benchmark_type == "ort" else "pt"
|
||||
processor_fn = lambda audio: args.processor.feature_extractor( # noqa: E731
|
||||
[audio], return_tensors=return_type, sampling_rate=args.sampling_rate
|
||||
).input_features
|
||||
time_fn(args, processor_fn, audio_data)
|
||||
input_features = processor_fn(audio_data)
|
||||
|
||||
if args.benchmark_type == "ort":
|
||||
# convert_to_onnx export
|
||||
inputs["input_features"] = input_features
|
||||
return inputs
|
||||
|
||||
inputs["inputs"] = input_features.to(
|
||||
dtype=torch.float16 if args.use_fp16 else torch.float32, device=args.target_device
|
||||
)
|
||||
inputs["no_repeat_ngram_size"] = args.no_repeat_ngram_size
|
||||
inputs["early_stopping"] = True
|
||||
inputs["use_cache"] = True
|
||||
|
||||
if args.decoder_input_ids:
|
||||
inputs["forced_decoder_ids"] = args.decoder_input_ids
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def get_model(args: argparse.Namespace):
|
||||
model, sess_options = None, None
|
||||
start_time, end_time = None, None
|
||||
|
||||
# There are multiple sources that the model could come from:
|
||||
# 1) Benchmark Whisper from Hugging Face
|
||||
# 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing)
|
||||
# 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing)
|
||||
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
|
||||
source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name
|
||||
start_time = time.time()
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
source,
|
||||
torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
|
||||
use_cache=True,
|
||||
).to(args.target_device)
|
||||
end_time = time.time()
|
||||
|
||||
if args.benchmark_type == "hf-pt-compile":
|
||||
model = torch.compile(model)
|
||||
|
||||
elif args.benchmark_type in {"hf-ort", "ort"}:
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.enable_profiling = args.profile
|
||||
sess_options.register_custom_ops_library(get_library_path())
|
||||
if args.verbose:
|
||||
sess_options.log_verbosity_level = 1
|
||||
sess_options.log_severity_level = 1
|
||||
if args.tune:
|
||||
ort.set_default_logger_severity(0)
|
||||
ort.set_default_logger_verbosity(0)
|
||||
|
||||
else:
|
||||
raise Exception(f"Cannot recognize {args.benchmark_type}")
|
||||
|
||||
if args.benchmark_type == "hf-ort":
|
||||
# Optimum export
|
||||
provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
|
||||
provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
|
||||
|
||||
start_time = time.time()
|
||||
model = ORTModelForSpeechSeq2Seq.from_pretrained(
|
||||
args.hf_ort_dir_path,
|
||||
provider=provider,
|
||||
provider_options=provider_options,
|
||||
session_options=sess_options,
|
||||
use_io_binding=True, # Avoid memory copy overhead
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
if args.benchmark_type == "ort":
|
||||
# convert_to_onnx.py export
|
||||
logger.info(f"Loading model from {args.ort_model_path}")
|
||||
start_time = time.time()
|
||||
model = ort.InferenceSession(
|
||||
args.ort_model_path,
|
||||
sess_options,
|
||||
providers=[args.execution_provider],
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
logger.info(f"Loaded model in {end_time - start_time} s")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def time_fn(args, fn, inputs):
|
||||
warmup_inputs = inputs[0] if type(inputs) is tuple else inputs
|
||||
benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs
|
||||
torch_device = torch.device(args.target_device)
|
||||
|
||||
# Warm up
|
||||
warmup_range = (
|
||||
range(args.warmup_runs)
|
||||
if args.benchmark_type == "ort"
|
||||
else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
|
||||
)
|
||||
|
||||
if args.verbose:
|
||||
outputs = fn(warmup_inputs)
|
||||
logger.info(outputs)
|
||||
|
||||
for _ in warmup_range:
|
||||
fn(warmup_inputs)
|
||||
|
||||
# Benchmark
|
||||
if args.device != "cpu":
|
||||
torch.cuda.synchronize(torch_device)
|
||||
start_time = time.time()
|
||||
|
||||
bench_range = (
|
||||
range(args.num_runs)
|
||||
if args.benchmark_type == "ort"
|
||||
else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
|
||||
)
|
||||
for _ in bench_range:
|
||||
fn(benchmark_inputs)
|
||||
|
||||
if args.device != "cpu":
|
||||
torch.cuda.synchronize(torch_device)
|
||||
end_time = time.time()
|
||||
|
||||
# Print a newline after trange so metrics are logged on their own lines instead of on the progress bar's line
|
||||
if args.benchmark_type != "ort":
|
||||
logger.info("")
|
||||
|
||||
batch_size = 1
|
||||
latency = (end_time - start_time) / args.num_runs
|
||||
throughput = batch_size / latency
|
||||
|
||||
logger.info(f"Latency: {latency} s")
|
||||
logger.info(f"Throughput: {throughput} qps")
|
||||
return
|
||||
|
||||
|
||||
def profile_fn(args, fn, inputs, inputs_type):
|
||||
# Filename prefix format:
|
||||
# "<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
|
||||
prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
|
||||
filename = None
|
||||
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
|
||||
# Profile PyTorch kernels
|
||||
with profile( # noqa: SIM117
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
|
||||
) as prof:
|
||||
with record_function("model_inference"):
|
||||
fn(inputs)
|
||||
prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
|
||||
|
||||
filename = os.path.join(args.log_folder, f"{prefix}.log")
|
||||
with open(filename, "w") as f:
|
||||
f.write(prof_data)
|
||||
|
||||
else:
|
||||
# Profile ORT kernels
|
||||
fn(inputs)
|
||||
|
||||
# Set new log name for ORT profile log generated
|
||||
filename = f"{prefix}.json"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def measure_fn(args, fn, inputs):
|
||||
# Measure CPU usage
|
||||
pid = os.getpid()
|
||||
process = psutil.Process(pid)
|
||||
process.cpu_percent(interval=0.1)
|
||||
|
||||
fn(inputs)
|
||||
logger.info(f"CPU usage: {process.cpu_percent(interval=None)}%")
|
||||
|
||||
# Measure memory usage
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs), monitor_type=args.monitor_type)
|
||||
|
||||
# Flush output so memory usage is printed
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def run_hf_inference(args, inputs, model):
|
||||
# Inference steps to measure
|
||||
def get_pred_ids(inputs):
|
||||
# Inference pass with predicted token ids generation
|
||||
predicted_ids = model.generate(**inputs)
|
||||
return predicted_ids
|
||||
|
||||
def gen_and_dec(inputs):
|
||||
# Inference pass with generation and decoding
|
||||
predicted_ids = get_pred_ids(inputs)
|
||||
transcription = []
|
||||
for _ in range(args.num_return_sequences):
|
||||
transcription.append(args.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
|
||||
return predicted_ids, transcription
|
||||
|
||||
# Examples of other inference steps that can be measured:
|
||||
# To use, uncomment the function and assign it to `generate_fn`
|
||||
|
||||
# def get_logits(inputs):
|
||||
# # Inference pass without decoding
|
||||
# outputs = model(**inputs)
|
||||
# return outputs
|
||||
|
||||
generate_fn = gen_and_dec
|
||||
|
||||
if args.benchmark_type == "hf-pt-compile":
|
||||
# Run forward pass once with each set of inputs to process through Dynamo
|
||||
generate_fn(inputs)
|
||||
|
||||
if args.profile:
|
||||
new_logname = profile_fn(args, generate_fn, inputs, "gen-and-dec")
|
||||
if args.benchmark_type == "hf-ort":
|
||||
# Rename log files per model component and turn profiling off to stop appending to log
|
||||
new_prefix = new_logname[: -len(".json")]
|
||||
|
||||
old_logname = model.encoder.session.end_profiling()
|
||||
new_logname = new_prefix + "-encoder.json"
|
||||
if os.path.isfile(old_logname):
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
|
||||
old_logname = model.decoder.session.end_profiling()
|
||||
new_logname = new_prefix + "-decoder.json"
|
||||
if os.path.isfile(old_logname):
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
|
||||
old_logname = model.decoder_with_past.session.end_profiling()
|
||||
new_logname = new_prefix + "-decoder-with-past.json"
|
||||
if os.path.isfile(old_logname):
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
|
||||
return
|
||||
|
||||
# PyTorch evaluations
|
||||
logger.info("\nEvaluating PyTorch...")
|
||||
time_fn(args, generate_fn, inputs)
|
||||
predicted_ids, transcription = generate_fn(inputs)
|
||||
logger.info(f"Generated token length: {len(predicted_ids[0])} tokens")
|
||||
logger.info(f"Transcription: {transcription[0]}")
|
||||
measure_fn(args, generate_fn, inputs)
|
||||
|
||||
|
||||
def run_ort_inference(args, inputs, model):
|
||||
def prepare_ort_inputs(inputs, warmup=False):
|
||||
# Check that all model inputs will be provided
|
||||
model_inputs = {model_input.name for model_input in model.get_inputs()}
|
||||
user_inputs = set(inputs.keys())
|
||||
missing_inputs = model_inputs - user_inputs
|
||||
if len(missing_inputs):
|
||||
logger.error(f"The following model inputs are missing: {missing_inputs}")
|
||||
raise Exception("There are missing inputs to the model. Please add them and try again.")
|
||||
|
||||
if warmup and args.tune:
|
||||
inputs["min_length"] = inputs["max_length"]
|
||||
|
||||
# Remove unnecessary inputs from model inputs
|
||||
unnecessary_inputs = user_inputs - model_inputs
|
||||
if len(unnecessary_inputs):
|
||||
for unnecessary_input in unnecessary_inputs:
|
||||
logger.info(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
|
||||
del inputs[unnecessary_input]
|
||||
|
||||
# Add IO bindings for non-CPU execution providers
|
||||
if args.device != "cpu":
|
||||
io_binding = model.io_binding()
|
||||
for k, v in inputs.items():
|
||||
io_binding.bind_cpu_input(k, v)
|
||||
for output in model.get_outputs():
|
||||
io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id)
|
||||
return io_binding
|
||||
|
||||
return inputs
|
||||
|
||||
def with_io_binding(io_binding):
|
||||
# Inference pass with IO binding
|
||||
model.run_with_iobinding(io_binding)
|
||||
return io_binding
|
||||
|
||||
def without_io_binding(inputs):
|
||||
# Inference pass without IO binding
|
||||
outputs = model.run(None, inputs)
|
||||
return outputs
|
||||
|
||||
def handle_output(output):
|
||||
if args.eos_token_id in output:
|
||||
first_end = np.where(output == args.eos_token_id)[0][0]
|
||||
return output[: first_end + 1]
|
||||
|
||||
return output
|
||||
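# Illustrative behavior with hypothetical token ids: if args.eos_token_id were 50257, then
# handle_output(np.array([50258, 400, 50257, 0, 0])) would return [50258, 400, 50257],
# i.e. the output truncated just after the first end-of-sequence token.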
|
||||
generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
|
||||
ort_inputs = prepare_ort_inputs(inputs)
|
||||
|
||||
if args.profile:
|
||||
new_logname = profile_fn(args, generate_fn, ort_inputs, "e2e")
|
||||
|
||||
# Turn profiling off to stop appending to log file
|
||||
old_logname = model.end_profiling()
|
||||
logger.warning(f"Renaming {old_logname} to {new_logname}")
|
||||
os.rename(old_logname, os.path.join(args.log_folder, new_logname))
|
||||
|
||||
return
|
||||
|
||||
# ORT evaluation
|
||||
logger.info("\nEvaluating ONNX Runtime...")
|
||||
ort_evaluate_inputs = ort_inputs
|
||||
if args.tune:
|
||||
ort_warmup_inputs = prepare_ort_inputs(inputs, warmup=True)
|
||||
ort_evaluate_inputs = (ort_warmup_inputs, ort_inputs)
|
||||
|
||||
time_fn(args, generate_fn, ort_evaluate_inputs)
|
||||
ort_outputs = generate_fn(ort_inputs)
|
||||
if args.device != "cpu":
|
||||
ort_outputs = ort_outputs.copy_outputs_to_cpu()
|
||||
ort_outputs = ort_outputs[0]
|
||||
|
||||
if args.has_audio_stream:
|
||||
# ONNX E2E model from Olive produces transcribed output
|
||||
logger.info(f"Transcription: {ort_outputs[0][0]}")
|
||||
else:
|
||||
# convert_to_onnx model produces generated ids
|
||||
actual_output = handle_output(ort_outputs[0][0])
|
||||
logger.info(f"Generated token length: {len(actual_output)} tokens")
|
||||
transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0]
|
||||
# print to stdout as the output for comparison
|
||||
print(f"{transcription}")
|
||||
|
||||
measure_fn(args, generate_fn, ort_inputs)
|
||||
|
||||
|
||||
def run_inference(args, inputs, model):
|
||||
if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
|
||||
run_hf_inference(args, inputs, model)
|
||||
elif args.benchmark_type == "ort":
|
||||
run_ort_inference(args, inputs, model)
|
||||
else:
|
||||
raise Exception(f"Cannot recognize {args.benchmark_type}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-bt",
|
||||
"--benchmark-type",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Hugging Face name of model (e.g. 'openai/whisper-large-v2')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
type=str,
|
||||
required=True,
|
||||
default="fp32",
|
||||
choices=["int8", "fp16", "fp32"],
|
||||
help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-pt-model-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-ort-dir-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ort-model-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Path to ONNX model",
|
||||
)
|
||||
|
||||
# Args for running and evaluating the model
|
||||
parser.add_argument("-a", "--audio-path", type=str, required=True, help="Path to audio file for E2E evaluation")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
choices=["cpu", "cuda", "rocm"],
|
||||
)
|
||||
parser.add_argument("-id", "--device-id", type=int, default=0)
|
||||
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
|
||||
parser.add_argument("-n", "--num-runs", type=int, default=10)
|
||||
parser.add_argument("--seed", type=int, default=2)
|
||||
|
||||
# Optional args:
|
||||
parser.add_argument("--sampling-rate", type=int, default=16000, help="Sampling rate for audio (in Hz)")
|
||||
|
||||
# Args for decoding logic
|
||||
# Required args:
|
||||
parser.add_argument("--max-length", type=int, default=448)
|
||||
parser.add_argument("--min-length", type=int, default=0)
|
||||
parser.add_argument("--num-beams", type=int, default=1)
|
||||
parser.add_argument("--num-return-sequences", type=int, default=1)
|
||||
parser.add_argument("--length-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||
parser.add_argument("--no-repeat-ngram-size", type=int, default=3)
|
||||
|
||||
# Optional args for E2E solution:
|
||||
parser.add_argument(
|
||||
"--decoder-input-ids",
|
||||
type=str,
|
||||
default="[]",
|
||||
help="The forced decoder ids for generation. Format is [start token, timestamp token, language token, task token]. Default is [start token]. See `decoder_input_ids` in https://github.com/microsoft/Olive/tree/main/examples/whisper for details.",
|
||||
)
|
||||
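# Illustrative usage (hypothetical token ids, in the order described in the help text above):
#   --decoder-input-ids "[50258, 50363, 50259, 50359]"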
parser.add_argument(
|
||||
"--logits-processor",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Whether to use timestamps logits processor or not (0 for false, 1 for true).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Temperature value for generation.",
|
||||
)
|
||||
|
||||
# Args for accessing detailed info
|
||||
parser.add_argument("--profile", default=False, action="store_true")
|
||||
parser.add_argument(
|
||||
"--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
|
||||
)
|
||||
parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
|
||||
parser.add_argument("--verbose", default=False, action="store_true")
|
||||
parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
|
||||
parser.add_argument(
|
||||
"--tune",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Only used by ROCm EP, enable TunableOp tuning to select fastest kernel",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set seed properties
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
|
||||
args.monitor_type = args.device
|
||||
# Set runtime properties
|
||||
if "ort" in args.benchmark_type:
|
||||
args.execution_provider = f"{args.device.upper()}ExecutionProvider"
|
||||
if args.execution_provider == "CUDAExecutionProvider":
|
||||
args.execution_provider = (args.execution_provider, {"device_id": args.device_id})
|
||||
elif args.execution_provider == "ROCMExecutionProvider":
|
||||
args.execution_provider = (
|
||||
args.execution_provider,
|
||||
{
|
||||
"device_id": args.device_id,
|
||||
"tunable_op_enable": 1,
|
||||
"tunable_op_tuning_enable": 1 if args.tune else 0,
|
||||
},
|
||||
)
|
||||
args.device = "cuda"
|
||||
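# Illustrative result of the branch above: with --device rocm --device-id 1 --tune, args.execution_provider
# becomes ("ROCMExecutionProvider", {"device_id": 1, "tunable_op_enable": 1, "tunable_op_tuning_enable": 1})
# and args.device is then set to "cuda" so later torch.device(...) calls target the GPU.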
|
||||
# Check that model paths have been specified for any benchmarking with ORT
|
||||
if args.benchmark_type == "hf-ort":
|
||||
assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
|
||||
if args.benchmark_type == "ort":
|
||||
assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
|
||||
|
||||
# Convert decoder_input_ids string to list of ids
|
||||
# (e.g. "[1, 50257]" for Hugging Face or "[50257]" for ORT)
|
||||
args.decoder_input_ids = ast.literal_eval(args.decoder_input_ids)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
setup_logger(args.verbose)
|
||||
logger.info(args.__dict__)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
config = WhisperConfig.from_pretrained(args.model_name)
|
||||
processor = WhisperProcessor.from_pretrained(args.model_name)
|
||||
target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device
|
||||
use_fp16 = args.precision == "fp16"
|
||||
|
||||
setattr(args, "processor", processor) # noqa: B010
|
||||
setattr(args, "target_device", target_device) # noqa: B010
|
||||
setattr(args, "use_fp16", use_fp16) # noqa: B010
|
||||
setattr(args, "has_audio_stream", False) # noqa: B010
|
||||
setattr(args, "eos_token_id", config.eos_token_id) # noqa: B010
|
||||
|
||||
logger.info(f"Forced decoder prompt ids: {args.decoder_input_ids}")
|
||||
|
||||
# Measure cost to transcribe audio
|
||||
model = get_model(args)
|
||||
if args.benchmark_type == "ort":
|
||||
# Check for optional inputs that could have been added during export
|
||||
ort_model_inputs = {model_input.name for model_input in model.get_inputs()}
|
||||
args.has_audio_stream = "audio_stream" in ort_model_inputs
|
||||
setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010
|
||||
setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010
|
||||
setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010
|
||||
|
||||
if args.decoder_input_ids == []:
|
||||
args.decoder_input_ids = [config.decoder_start_token_id]
|
||||
|
||||
inputs = get_inputs(args)
|
||||
run_inference(args, inputs, model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+526
@@ -0,0 +1,526 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
from benchmark_helper import setup_logger
|
||||
from metrics import BenchmarkRecord
|
||||
from transformers import WhisperConfig, WhisperProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--audio-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to folder of audio files for E2E evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--language",
|
||||
default=None,
|
||||
help="Language of audio file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--task",
|
||||
default=None,
|
||||
choices=["transcribe", "translate"],
|
||||
help="Task to complete",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--warmup-runs",
|
||||
type=int,
|
||||
default=5,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--num-runs",
|
||||
type=int,
|
||||
default=10,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-pt-eager",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Benchmark in PyTorch without `torch.compile`",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-pt-compile",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Benchmark in PyTorch with `torch.compile`",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-ort-dir-path",
|
||||
type=str,
|
||||
help="Path to folder containing ONNX models for Optimum + ORT benchmarking",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ort-model-path",
|
||||
type=str,
|
||||
help="Path to ONNX model for ORT benchmarking",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model name in Hugging Face (e.g. openai/whisper-large-v2)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["int8", "fp16", "fp32"],
|
||||
help="Precision to run model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["cpu", "cuda", "rocm"],
|
||||
help="Device to benchmark models",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--device-id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="GPU device ID",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Print detailed logs",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of mins to attempt the benchmark before moving on",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to folder to save logs and results",
|
||||
)
|
||||
|
||||
parser.add_argument("--tune", default=False, action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setattr(args, "model_size", args.model_name.split("/")[-1].replace(".", "-")) # noqa: B010
|
||||
log_folder_name = f"./{args.model_size}-{args.precision}"
|
||||
if not args.log_folder:
|
||||
args.log_folder = log_folder_name
|
||||
os.makedirs(args.log_folder, exist_ok=True)
|
||||
|
||||
# Convert timeout value to secs
|
||||
args.timeout *= 60
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def process_log_file(device_id, log_file, base_results):
|
||||
entries = []
|
||||
|
||||
# Detect steps in speech pipeline
|
||||
step = None
|
||||
load_audio_pattern = "Load audio: "
|
||||
feat_ext_pattern = "Feature extraction: "
|
||||
pytorch_pattern = "Evaluating PyTorch..."
|
||||
onnxruntime_pattern = "Evaluating ONNX Runtime..."
|
||||
|
||||
load_audio_latency_s, load_audio_throughput_s = None, None
|
||||
feat_ext_latency_s, feat_ext_throughput_s = None, None
|
||||
token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None
|
||||
throughput, memory = None, None
|
||||
|
||||
# Detect metrics
|
||||
latency_pattern = "Latency: "
|
||||
throughput_pattern = "Throughput: "
|
||||
token_length_pattern = "Generated token length: "
|
||||
memory_pattern = "peak="
|
||||
|
||||
with open(log_file) as f:
|
||||
for input_line in f:
|
||||
line = input_line.replace("\n", "")
|
||||
|
||||
# Get step in speech recognition pipeline
|
||||
if load_audio_pattern in line:
|
||||
step = "load-audio"
|
||||
elif feat_ext_pattern in line:
|
||||
step = "feature-extraction"
|
||||
elif pytorch_pattern in line or onnxruntime_pattern in line:
|
||||
step = "process"
|
||||
|
||||
# Check metrics
|
||||
if latency_pattern in line:
|
||||
latency_s = float(line[len(latency_pattern) : line.rfind(" ")])
|
||||
elif throughput_pattern in line:
|
||||
throughput = float(line[len(throughput_pattern) : line.rfind(" ")])
|
||||
if step == "load-audio":
|
||||
load_audio_latency_s, load_audio_throughput_s = latency_s, throughput
|
||||
step = None
|
||||
if step == "feature-extraction":
|
||||
feat_ext_latency_s, feat_ext_throughput_s = latency_s, throughput
|
||||
step = None
|
||||
elif token_length_pattern in line:
|
||||
token_length = int(line[len(token_length_pattern) : line.rfind(" ")])
|
||||
per_token_latency_s = latency_s / token_length
|
||||
per_token_latency_ms = per_token_latency_s * 1000
|
||||
elif memory_pattern in line:
|
||||
if "CPU" in line:
|
||||
# Example format for log entry:
|
||||
# CPU memory usage: before=1000.0 MB, peak=2000.0 MB
|
||||
memory = float(line[line.rfind("=") + 1 : line.rfind(" MB")]) / 1000
|
||||
else:
|
||||
# Example format for log entry:
|
||||
# GPU memory usage: before=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1638.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}, peak=[{'device_id': 0, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 1780.875}, {'device_id': 1, 'name': 'Tesla V100-PCIE-16GB', 'max_used_MB': 236.875}]
|
||||
peak = line[line.find(memory_pattern) + len(memory_pattern) :].replace("'", '"')
|
||||
usage = json.loads(peak)[device_id]["max_used_MB"]
|
||||
memory = float(usage) / 1000
|
||||
|
||||
# Calculate real-time factor (RTF):
|
||||
# RTF = total latency / audio duration
|
||||
total_latency = (
|
||||
(load_audio_latency_s if load_audio_latency_s else 0)
|
||||
+ (feat_ext_latency_s if feat_ext_latency_s else 0)
|
||||
+ (latency_s if latency_s else 0)
|
||||
)
|
||||
audio_duration = base_results[-1]
|
||||
rtf = (total_latency / audio_duration) if audio_duration else -1
|
||||
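# Worked example with hypothetical timings: load audio 0.01 s + feature extraction 0.02 s
# + processing 1.50 s = 1.53 s total for a 30 s clip, giving RTF = 1.53 / 30 ≈ 0.051.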
logger.info(f"Total latency: {total_latency} s")
|
||||
logger.info(f"Audio duration: {audio_duration} s")
|
||||
logger.info(f"Real-time factor: {rtf}")
|
||||
|
||||
# Append log entry to list of entries
|
||||
entry = base_results + [ # noqa: RUF005
|
||||
token_length,
|
||||
load_audio_latency_s,
|
||||
load_audio_throughput_s,
|
||||
feat_ext_latency_s if feat_ext_latency_s else -1,
|
||||
feat_ext_throughput_s if feat_ext_throughput_s else -1,
|
||||
latency_s,
|
||||
per_token_latency_ms,
|
||||
throughput,
|
||||
memory,
|
||||
rtf,
|
||||
]
|
||||
entries.append(entry)
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def save_results(results, filename):
|
||||
import pandas as pd # noqa: PLC0415
|
||||
|
||||
df = pd.DataFrame(
|
||||
results,
|
||||
columns=[
|
||||
"Warmup Runs",
|
||||
"Measured Runs",
|
||||
"Model Name",
|
||||
"Engine",
|
||||
"Precision",
|
||||
"Device",
|
||||
"Audio File",
|
||||
"Duration (s)",
|
||||
"Token Length",
|
||||
"Load Audio Latency (s)",
|
||||
"Load Audio Throughput (qps)",
|
||||
"Feature Extractor Latency (s)",
|
||||
"Feature Extractor Throughput (qps)",
|
||||
"Latency (s)",
|
||||
"Per Token Latency (ms/token)",
|
||||
"Throughput (qps)",
|
||||
"Memory (GB)",
|
||||
"Real Time Factor (RTF)",
|
||||
],
|
||||
)
|
||||
|
||||
# Set column types
|
||||
df["Warmup Runs"] = df["Warmup Runs"].astype("int")
|
||||
df["Measured Runs"] = df["Measured Runs"].astype("int")
|
||||
df["Duration (s)"] = df["Duration (s)"].astype("float")
|
||||
df["Token Length"] = df["Token Length"].astype("int")
|
||||
df["Load Audio Latency (s)"] = df["Load Audio Latency (s)"].astype("float")
|
||||
df["Load Audio Throughput (qps)"] = df["Load Audio Throughput (qps)"].astype("float")
|
||||
df["Feature Extractor Latency (s)"] = df["Feature Extractor Latency (s)"].astype("float")
|
||||
df["Feature Extractor Throughput (qps)"] = df["Feature Extractor Throughput (qps)"].astype("float")
|
||||
df["Latency (s)"] = df["Latency (s)"].astype("float")
|
||||
df["Per Token Latency (ms/token)"] = df["Per Token Latency (ms/token)"].astype("float")
|
||||
df["Throughput (qps)"] = df["Throughput (qps)"].astype("float")
|
||||
df["Memory (GB)"] = df["Memory (GB)"].astype("float")
|
||||
df["Real Time Factor (RTF)"] = df["Real Time Factor (RTF)"].astype("float")
|
||||
|
||||
# get package name and version
|
||||
import pkg_resources # noqa: PLC0415
|
||||
|
||||
installed_packages = pkg_resources.working_set
|
||||
installed_packages_list = sorted(
|
||||
[f"{i.key}=={i.version}" for i in installed_packages if i.key in ["onnxruntime", "onnxruntime-gpu"]]
|
||||
)
|
||||
ort_pkg_name = ""
|
||||
ort_pkg_version = ""
|
||||
if installed_packages_list:
|
||||
ort_pkg_name = installed_packages_list[0].split("==")[0]
|
||||
ort_pkg_version = installed_packages_list[0].split("==")[1]
|
||||
|
||||
# Save results to csv with standard format
|
||||
records = []
|
||||
for _, row in df.iterrows():
|
||||
if row["Engine"] == "onnxruntime":
|
||||
record = BenchmarkRecord(
|
||||
row["Model Name"], row["Precision"], row["Engine"], row["Device"], ort_pkg_name, ort_pkg_version
|
||||
)
|
||||
else:
|
||||
record = BenchmarkRecord(
|
||||
row["Model Name"], row["Precision"], row["Engine"], row["Device"], torch.__name__, torch.__version__
|
||||
)
|
||||
record.config.customized["audio_file"] = row["Audio File"]
|
||||
record.config.warmup_runs = row["Warmup Runs"]
|
||||
record.config.measured_runs = row["Measured Runs"]
|
||||
|
||||
record.metrics.customized["duration"] = row["Duration (s)"]
|
||||
record.metrics.customized["token_length"] = row["Token Length"]
|
||||
record.metrics.customized["load_audio_latency"] = row["Load Audio Latency (s)"]
|
||||
record.metrics.customized["load_audio_throughput"] = row["Load Audio Throughput (qps)"]
|
||||
record.metrics.customized["feature_extractor_latency_s"] = row["Feature Extractor Latency (s)"]
|
||||
record.metrics.customized["feature_extractor_throughput_qps"] = row["Feature Extractor Throughput (qps)"]
|
||||
record.metrics.customized["per_token_latency_ms"] = row["Per Token Latency (ms/token)"]
|
||||
record.metrics.customized["rtf"] = row["Real Time Factor (RTF)"]
|
||||
|
||||
record.metrics.latency_ms_mean = row["Latency (s)"] * 1000
|
||||
record.metrics.throughput_qps = row["Throughput (qps)"]
|
||||
record.metrics.max_memory_usage_GB = row["Memory (GB)"]
|
||||
|
||||
records.append(record)
|
||||
|
||||
BenchmarkRecord.save_as_csv(filename, records)
|
||||
BenchmarkRecord.save_as_json(filename.replace(".csv", ".json"), records)
|
||||
logger.info(f"Results saved in {filename}!")
|
||||
|
||||
|
||||
def benchmark(args, benchmark_cmd, engine, audio_file, duration):
|
||||
log_filename = f"{engine}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.log"
|
||||
log_path = os.path.join(args.log_folder, log_filename)
|
||||
with open(log_path, "w") as log_file:
|
||||
process = subprocess.Popen(benchmark_cmd, stdout=log_file, stderr=log_file)
|
||||
try:
|
||||
process.wait(args.timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.kill()
|
||||
|
||||
# Create entries for csv
|
||||
logger.info("Gathering data from log files...")
|
||||
base_results = [
|
||||
args.warmup_runs,
|
||||
args.num_runs,
|
||||
args.model_name,
|
||||
engine,
|
||||
args.precision,
|
||||
args.device,
|
||||
audio_file,
|
||||
duration,
|
||||
]
|
||||
results = process_log_file(args.device_id, log_path, base_results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
setup_logger(args.verbose)
|
||||
logger.info(args.__dict__)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
config = WhisperConfig.from_pretrained(args.model_name)
|
||||
processor = WhisperProcessor.from_pretrained(args.model_name)
|
||||
|
||||
# Calculate forced decoder input ids
|
||||
hf_forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
|
||||
ort_forced_decoder_ids = [config.decoder_start_token_id] + [token_id[1] for token_id in hf_forced_decoder_ids]
|
||||
hf_decoder_input_ids_cmd = (
|
||||
["--decoder-input-ids", str(hf_forced_decoder_ids)] if args.language and args.task else []
|
||||
)
|
||||
ort_decoder_input_ids_cmd = (
|
||||
["--decoder-input-ids", str(ort_forced_decoder_ids)] if args.language and args.task else []
|
||||
)
|
||||
ort_tune_cmd = ["--tune"] if args.tune else []
|
||||
|
||||
all_results = []
|
||||
for audio_file in os.listdir(args.audio_path):
|
||||
audio_path = os.path.join(args.audio_path, audio_file)
|
||||
try:
|
||||
duration = librosa.get_duration(path=audio_path)
|
||||
except Exception as e:
|
||||
duration = -1
|
||||
logger.warning(f"An error occurred while trying to calculate the audio duration: {e}", exc_info=True)
|
||||
logger.warning(
|
||||
f"If you get an error that says:\n\tsoundfile.LibsndfileError: Error opening '{audio_file}': File contains data in an unknown format.\nyou may not have installed `ffmpeg` in addition to installing `librosa`."
|
||||
)
|
||||
logger.info(f"Testing {audio_path}...")
|
||||
|
||||
# Benchmark PyTorch without torch.compile
|
||||
if args.hf_pt_eager:
|
||||
benchmark_cmd = [ # noqa: RUF005
|
||||
"python",
|
||||
"-m",
|
||||
"models.whisper.benchmark",
|
||||
"--audio-path",
|
||||
audio_path,
|
||||
"--benchmark-type",
|
||||
"hf-pt-eager",
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
] + hf_decoder_input_ids_cmd
|
||||
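# Illustrative command assembled above (hypothetical audio path and argument values):
#   python -m models.whisper.benchmark --audio-path ./audio/sample.wav --benchmark-type hf-pt-eager
#     --model-name openai/whisper-large-v2 --precision fp32 --device cuda --device-id 0
#     --warmup-runs 5 --num-runs 10 --log-folder ./whisper-large-v2-fp32
# plus --decoder-input-ids "[...]" when both --language and --task are given.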
logger.info("Benchmark PyTorch without torch.compile")
|
||||
results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration)
|
||||
all_results.extend(results)
|
||||
|
||||
# Benchmark PyTorch with torch.compile
|
||||
if args.hf_pt_compile:
|
||||
benchmark_cmd = [ # noqa: RUF005
|
||||
"python",
|
||||
"-m",
|
||||
"models.whisper.benchmark",
|
||||
"--audio-path",
|
||||
audio_path,
|
||||
"--benchmark-type",
|
||||
"hf-pt-compile",
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
] + hf_decoder_input_ids_cmd
|
||||
logger.info("Benchmark PyTorch with torch.compile")
|
||||
results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration)
|
||||
all_results.extend(results)
|
||||
|
||||
# Benchmark Optimum + ONNX Runtime
|
||||
if args.hf_ort_dir_path:
|
||||
benchmark_cmd = [ # noqa: RUF005
|
||||
"python",
|
||||
"-m",
|
||||
"models.whisper.benchmark",
|
||||
"--audio-path",
|
||||
audio_path,
|
||||
"--benchmark-type",
|
||||
"hf-ort",
|
||||
"--hf-ort-dir-path",
|
||||
args.hf_ort_dir_path,
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
] + hf_decoder_input_ids_cmd
|
||||
logger.info("Benchmark Optimum + ONNX Runtime")
|
||||
results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration)
|
||||
all_results.extend(results)
|
||||
|
||||
# Benchmark ONNX Runtime
|
||||
if args.ort_model_path:
|
||||
benchmark_cmd = (
|
||||
[ # noqa: RUF005
|
||||
"python",
|
||||
"-m",
|
||||
"models.whisper.benchmark",
|
||||
"--audio-path",
|
||||
audio_path,
|
||||
"--benchmark-type",
|
||||
"ort",
|
||||
"--ort-model-path",
|
||||
args.ort_model_path,
|
||||
"--model-name",
|
||||
args.model_name,
|
||||
"--precision",
|
||||
args.precision,
|
||||
"--device",
|
||||
args.device,
|
||||
"--device-id",
|
||||
str(args.device_id),
|
||||
"--warmup-runs",
|
||||
str(args.warmup_runs),
|
||||
"--num-runs",
|
||||
str(args.num_runs),
|
||||
"--log-folder",
|
||||
args.log_folder,
|
||||
]
|
||||
+ ort_decoder_input_ids_cmd
|
||||
+ ort_tune_cmd
|
||||
)
|
||||
logger.info("Benchmark ONNX Runtime")
|
||||
results = benchmark(args, benchmark_cmd, "onnxruntime", audio_file, duration)
|
||||
all_results.extend(results)
|
||||
|
||||
csv_file = f"{args.model_size}-{args.precision}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
|
||||
save_results(all_results, os.path.join(args.log_folder, csv_file))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+573
@@ -0,0 +1,573 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
|
||||
from whisper_chain import chain_model
|
||||
from whisper_encoder import WhisperEncoder
|
||||
from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
|
||||
|
||||
from onnxruntime import quantization
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
PROVIDERS = {
|
||||
"cpu": "CPUExecutionProvider",
|
||||
"cuda": "CUDAExecutionProvider",
|
||||
"rocm": "ROCMExecutionProvider",
|
||||
}
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
conversion_args = parser.add_argument_group("Conversion Process Args")
|
||||
optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
|
||||
optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
|
||||
quant_args = parser.add_argument_group("INT8 Quantization Args")
|
||||
|
||||
#################################
|
||||
# Conversion options for Whisper
|
||||
#################################
|
||||
|
||||
conversion_args.add_argument(
|
||||
"-m",
|
||||
"--model_name_or_path",
|
||||
required=False,
|
||||
default=PRETRAINED_WHISPER_MODELS[0],
|
||||
type=str,
|
||||
help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
|
||||
)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--model_impl",
|
||||
required=False,
|
||||
default="hf",
|
||||
choices=["hf", "openai"],
|
||||
type=str,
|
||||
help="Select implementation for export of encoder and decoder subgraphs",
|
||||
)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--cache_dir",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "cache_models"),
|
||||
help="Directory to cache pre-trained models",
|
||||
)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--output",
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join(".", "onnx_models"),
|
||||
help="Output directory",
|
||||
)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"-o",
|
||||
"--optimize_onnx",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use optimizer.py to optimize onnx model",
|
||||
)
|
||||
conversion_args.set_defaults(optimize_onnx=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--use_gpu",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use GPU for model inference",
|
||||
)
|
||||
conversion_args.set_defaults(use_gpu=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"-p",
|
||||
"--precision",
|
||||
required=False,
|
||||
type=Precision,
|
||||
default=Precision.FLOAT32,
|
||||
choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
|
||||
help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
|
||||
)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--use_int64_inputs",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use int64 instead of int32 for input_ids and attention_mask.",
|
||||
)
|
||||
conversion_args.set_defaults(use_int64_inputs=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"-r",
|
||||
"--provider",
|
||||
required=False,
|
||||
type=str,
|
||||
default="cpu",
|
||||
choices=list(PROVIDERS.keys()),
|
||||
help="Provider to benchmark. Default is CPUExecutionProvider.",
|
||||
)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--verbose",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Enable verbose logging",
|
||||
)
|
||||
conversion_args.set_defaults(verbose=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"-e",
|
||||
"--use_external_data_format",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
|
||||
)
|
||||
conversion_args.set_defaults(use_external_data_format=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"-w",
|
||||
"--overwrite",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Overwrite existing ONNX model",
|
||||
)
|
||||
conversion_args.set_defaults(overwrite=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--separate_encoder_and_decoder_init",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
|
||||
)
|
||||
conversion_args.set_defaults(separate_encoder_and_decoder_init=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--no_beam_search_op",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
|
||||
)
|
||||
conversion_args.set_defaults(no_beam_search_op=False)
|
||||
|
||||
conversion_args.add_argument(
|
||||
"--use_decoder_masked_mha",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use DecoderMaskedMultiHeadAttention kernel for improved performance. This is currently an experimental feature.",
|
||||
)
|
||||
conversion_args.set_defaults(use_decoder_masked_mha=False)
|
||||
|
||||
#############################################################
|
||||
# Optional inputs for Whisper
|
||||
# (listed below in the order that WhisperBeamSearch expects)
|
||||
#############################################################
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"-v",
|
||||
"--use_vocab_mask",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use vocab_mask as an extra graph input to enable specific logits processing",
|
||||
)
|
||||
optional_inputs.set_defaults(use_vocab_mask=False)
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"-u",
|
||||
"--use_prefix_vocab_mask",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
|
||||
)
|
||||
optional_inputs.set_defaults(use_prefix_vocab_mask=False)
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"-f",
|
||||
"--use_forced_decoder_ids",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use decoder_input_ids as an extra graph input to the beam search op",
|
||||
)
|
||||
optional_inputs.set_defaults(use_forced_decoder_ids=False)
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"-l",
|
||||
"--use_logits_processor",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use logits_processor as an extra graph input to enable specific logits processing",
|
||||
)
|
||||
optional_inputs.set_defaults(use_logits_processor=False)
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"--collect_cross_qk",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Beam search model collect stacked cross QK.",
|
||||
)
|
||||
optional_inputs.set_defaults(collect_cross_qk=False)
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"--extra_decoding_ids",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Need extra starting decoding ids for some feature like cross qk. Default if false.",
|
||||
)
|
||||
optional_inputs.set_defaults(extra_decoding_ids=False)
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"-t",
|
||||
"--use_temperature",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use temperature as an extra graph input for the WhisperBeamSearch op",
|
||||
)
|
||||
optional_inputs.set_defaults(use_temperature=False)
|
||||
|
||||
optional_inputs.add_argument(
|
||||
"--no_repeat_ngram_size",
|
||||
type=int,
|
||||
default=0,
|
||||
help="default to 0",
|
||||
)
|
||||
|
||||
#############################################################
|
||||
# Optional outputs for Whisper
|
||||
# (listed below in the order that WhisperBeamSearch expects)
|
||||
#############################################################
|
||||
|
||||
optional_outputs.add_argument(
|
||||
"--output_sequence_scores",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Beam search model output scores for each generated sequence.",
|
||||
)
|
||||
optional_outputs.set_defaults(output_sequence_scores=False)
|
||||
|
||||
optional_outputs.add_argument(
|
||||
"--output_scores",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Beam search model output scores over vocab per generated token.",
|
||||
)
|
||||
optional_outputs.set_defaults(output_scores=False)
|
||||
|
||||
optional_outputs.add_argument(
|
||||
"--output_cross_qk",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Beam search model output collected qk as output. Also hint collect_cross_qk",
|
||||
)
|
||||
optional_outputs.set_defaults(output_cross_qk=False)
|
||||
|
||||
optional_outputs.add_argument(
|
||||
"--cross_qk_onnx_model",
|
||||
required=False,
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model which consumes cross_qk outputs.",
|
||||
)
|
||||
|
||||
optional_outputs.add_argument(
|
||||
"--output_no_speech_probs",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.",
|
||||
)
|
||||
optional_outputs.set_defaults(output_no_speech_probs=False)
|
||||
|
||||
###################################
|
||||
# Quantization options for Whisper
|
||||
###################################
|
||||
|
||||
quant_args.add_argument(
|
||||
"--quantize_embedding_layer",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Quantize MatMul, GEMM, and Gather.",
|
||||
)
|
||||
quant_args.set_defaults(quantize_embedding_layer=False)
|
||||
|
||||
quant_args.add_argument(
|
||||
"--quantize_per_channel",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Quantize weights per each channel.",
|
||||
)
|
||||
quant_args.set_defaults(quantize_per_channel=False)
|
||||
|
||||
quant_args.add_argument(
|
||||
"--quantize_reduce_range",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Quantize weights with 7 bits.",
|
||||
)
|
||||
quant_args.set_defaults(quantize_reduce_range=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
# Collect cross QKs if either flag is enabled
|
||||
args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk
|
||||
|
||||
# FP32 CPU can be supported here once the DMMHA CPU kernel bugs are fixed
|
||||
args.use_decoder_masked_mha = args.use_decoder_masked_mha and args.provider == "cuda"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def export_onnx_models(
|
||||
model_name_or_path,
|
||||
model_impl,
|
||||
cache_dir,
|
||||
output_dir,
|
||||
use_gpu,
|
||||
use_external_data_format,
|
||||
optimize_onnx,
|
||||
precision,
|
||||
verbose,
|
||||
use_forced_decoder_ids: bool = False,
|
||||
merge_encoder_and_decoder_init: bool = True,
|
||||
no_beam_search_op: bool = False,
|
||||
use_decoder_masked_mha: bool = False,
|
||||
output_qk: bool = False,
|
||||
overwrite: bool = False,
|
||||
use_int32_inputs: bool = True,
|
||||
quantize_embedding_layer: bool = False,
|
||||
quantize_per_channel: bool = False,
|
||||
quantize_reduce_range: bool = False,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
device = torch.device("cuda" if use_gpu else "cpu")
|
||||
|
||||
models = WhisperHelper.load_model(
|
||||
model_name_or_path,
|
||||
model_impl,
|
||||
cache_dir,
|
||||
device,
|
||||
torch.float16 if precision == Precision.FLOAT16 else torch.float32,
|
||||
merge_encoder_and_decoder_init,
|
||||
no_beam_search_op,
|
||||
output_qk,
|
||||
)
|
||||
config = models["decoder"].config
|
||||
|
||||
if (not use_external_data_format) and (config.num_hidden_layers > 24):
|
||||
logger.warning("You MUST pass `--use_external_data_format` because model size > 2GB")
|
||||
raise Exception("Please pass `--use_external_data_format` for this model.")
|
||||
|
||||
output_paths = []
|
||||
for name, model in models.items():
|
||||
print(f"========> Handling {name} model......")
|
||||
filename_suffix = "_" + name
|
||||
|
||||
onnx_path = WhisperHelper.get_onnx_path(
|
||||
output_dir,
|
||||
model_name_or_path,
|
||||
suffix=filename_suffix,
|
||||
new_folder=False,
|
||||
)
|
||||
|
||||
# Export to ONNX
|
||||
if overwrite or not os.path.exists(onnx_path):
|
||||
logger.info(f"Exporting ONNX model to {onnx_path}")
|
||||
WhisperHelper.export_onnx(
|
||||
model,
|
||||
onnx_path,
|
||||
PROVIDERS[provider],
|
||||
verbose,
|
||||
use_external_data_format,
|
||||
use_fp16_inputs=(precision == Precision.FLOAT16),
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
use_encoder_hidden_states=(name == "decoder_init"),
|
||||
use_kv_cache_inputs=(name == "decoder"),
|
||||
)
|
||||
else:
|
||||
logger.info(f"Skip exporting: existing ONNX model {onnx_path}")
|
||||
|
||||
# Optimize ONNX model
|
||||
if optimize_onnx or precision != Precision.FLOAT32:
|
||||
output_path = WhisperHelper.get_onnx_path(
|
||||
output_dir,
|
||||
model_name_or_path,
|
||||
suffix=filename_suffix + "_" + str(precision),
|
||||
new_folder=False,
|
||||
)
|
||||
|
||||
if overwrite or not os.path.exists(output_path):
|
||||
if optimize_onnx:
|
||||
logger.info(f"Optimizing model to {output_path}")
|
||||
WhisperHelper.optimize_onnx(
|
||||
onnx_path,
|
||||
output_path,
|
||||
precision == Precision.FLOAT16,
|
||||
model.config.encoder_attention_heads,
|
||||
model.config.d_model,
|
||||
model.config.decoder_layers,
|
||||
use_external_data_format,
|
||||
use_gpu=use_gpu,
|
||||
provider=provider,
|
||||
is_decoder=(name == "decoder"),
|
||||
no_beam_search_op=no_beam_search_op,
|
||||
use_decoder_masked_mha=use_decoder_masked_mha,
|
||||
output_qk=output_qk,
|
||||
)
|
||||
# Remove old ONNX model and old data file
|
||||
if os.path.exists(onnx_path):
|
||||
os.remove(onnx_path)
|
||||
if os.path.exists(onnx_path + ".data"):
|
||||
os.remove(onnx_path + ".data")
|
||||
onnx_path = output_path
|
||||
|
||||
if isinstance(model, WhisperEncoder):
|
||||
model.verify_onnx(
|
||||
onnx_path,
|
||||
PROVIDERS[provider],
|
||||
use_fp16_inputs=(precision == Precision.FLOAT16),
|
||||
)
|
||||
else:
|
||||
model.verify_onnx(
|
||||
onnx_path,
|
||||
PROVIDERS[provider],
|
||||
use_fp16_inputs=(precision == Precision.FLOAT16),
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
if precision == Precision.INT8:
|
||||
quantization.quantize_dynamic(
|
||||
onnx_path,
|
||||
output_path,
|
||||
op_types_to_quantize=(
|
||||
["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
|
||||
),
|
||||
use_external_data_format=use_external_data_format,
|
||||
per_channel=quantize_per_channel,
|
||||
reduce_range=quantize_reduce_range,
|
||||
extra_options={"MatMulConstBOnly": True},
|
||||
)
|
||||
else:
|
||||
logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
|
||||
else:
|
||||
output_path = onnx_path
|
||||
|
||||
output_paths.append(output_path)
|
||||
|
||||
return output_paths
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
args = parse_arguments(argv)
|
||||
|
||||
setup_logger(args.verbose)
|
||||
|
||||
logger.info(f"Arguments:{args}")
|
||||
|
||||
cache_dir = args.cache_dir
|
||||
output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
|
||||
prepare_environment(cache_dir, output_dir, args.use_gpu)
|
||||
|
||||
if args.precision == Precision.FLOAT16:
|
||||
assert args.use_gpu, "fp16 requires --use_gpu"
|
||||
|
||||
output_paths = export_onnx_models(
|
||||
args.model_name_or_path,
|
||||
args.model_impl,
|
||||
cache_dir,
|
||||
output_dir,
|
||||
args.use_gpu,
|
||||
args.use_external_data_format,
|
||||
args.optimize_onnx,
|
||||
args.precision,
|
||||
args.verbose,
|
||||
args.use_forced_decoder_ids,
|
||||
not args.separate_encoder_and_decoder_init,
|
||||
args.no_beam_search_op,
|
||||
args.use_decoder_masked_mha,
|
||||
args.output_cross_qk,
|
||||
args.overwrite,
|
||||
not args.use_int64_inputs,
|
||||
args.quantize_embedding_layer,
|
||||
args.quantize_per_channel,
|
||||
args.quantize_reduce_range,
|
||||
args.provider,
|
||||
)
|
||||
|
||||
max_diff = 0
|
||||
if not args.no_beam_search_op:
|
||||
logger.info("Chaining model ... :")
|
||||
args.beam_model_output_dir = WhisperHelper.get_onnx_path(
|
||||
output_dir,
|
||||
args.model_name_or_path,
|
||||
suffix="_beamsearch",
|
||||
new_folder=False,
|
||||
)
|
||||
for path in output_paths:
|
||||
if "encoder_decoder" in path or "encoder" in path:
|
||||
args.encoder_path = path
|
||||
elif "decoder" in path:
|
||||
args.decoder_path = path
|
||||
chain_model(args)
|
||||
output_paths.append(args.beam_model_output_dir)
|
||||
|
||||
# Check chained model
|
||||
ort_session = create_onnxruntime_session(
|
||||
args.beam_model_output_dir,
|
||||
use_gpu=args.use_gpu,
|
||||
provider=args.provider,
|
||||
)
|
||||
device = torch.device("cuda" if args.use_gpu else "cpu")
|
||||
|
||||
# Wrap parity check in try-except to allow export to continue in case this produces an error
|
||||
try:
|
||||
with torch.no_grad():
|
||||
# Verify batched decoding with prompts for OpenAI implementation
|
||||
if args.model_impl == "openai" and args.use_forced_decoder_ids:
|
||||
max_diff = WhisperHelper.verify_onnx(
|
||||
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
|
||||
)
|
||||
else:
|
||||
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
|
||||
if max_diff > 1e-4:
|
||||
logger.warning("PyTorch and ONNX Runtime results are NOT close")
|
||||
else:
|
||||
logger.info("PyTorch and ONNX Runtime results are close")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# Remove extra ONNX models saved in output directory
|
||||
for _file in os.listdir(output_dir):
|
||||
if "_beamsearch" not in _file and "_jump_times" not in _file:
|
||||
path = os.path.join(output_dir, _file)
|
||||
os.remove(path)
|
||||
if path in output_paths:
|
||||
output_paths.remove(path)
|
||||
|
||||
else:
|
||||
# Create ancillary JSON files for ONNX Runtime GenAI and/or Hugging Face's Optimum
|
||||
WhisperHelper.save_processing(
|
||||
args.model_name_or_path,
|
||||
args.provider,
|
||||
args.separate_encoder_and_decoder_init,
|
||||
args.use_decoder_masked_mha,
|
||||
args.output_cross_qk,
|
||||
next(iter(filter(lambda path: "encoder" in path, output_paths))),
|
||||
next(iter(filter(lambda path: "decoder" in path, output_paths))),
|
||||
output_dir,
|
||||
cache_dir,
|
||||
)
|
||||
|
||||
logger.info(f"Done! Outputs: {output_paths}")
|
||||
return max_diff
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+331
@@ -0,0 +1,331 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import onnx
|
||||
from benchmark_helper import Precision
|
||||
from convert_generation import (
|
||||
get_shared_initializers,
|
||||
update_decoder_subgraph_output_cross_attention,
|
||||
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
|
||||
)
|
||||
from onnx import TensorProto, helper
|
||||
from transformers import WhisperConfig, WhisperTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def verify_inputs(beam_inputs, graph_inputs):
|
||||
# Verify that ONNX graph's inputs match beam search op's inputs
|
||||
beam_required_inputs = list(filter(lambda beam_input: beam_input, beam_inputs))
|
||||
assert len(graph_inputs) == len(beam_required_inputs)
|
||||
for graph_input, beam_input in zip(graph_inputs, beam_required_inputs, strict=False):
|
||||
# Check if graph_input is in beam_input to handle beam_input names with the "_fp16" suffix
|
||||
assert graph_input.name in beam_input
|
||||
|
||||
|
||||
def clean_list(arr, remove_all_strings=True):
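# Illustrative behavior (examples not from the original source):
#   clean_list(["a", "", "b", ""], remove_all_strings=True)  -> ["a", "b"]
#   clean_list(["a", "", "b", ""], remove_all_strings=False) -> ["a", "", "b"]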
|
||||
if remove_all_strings:
|
||||
# Remove all empty strings in list
|
||||
return list(filter(lambda elm: elm != "", arr))
|
||||
|
||||
# Remove empty strings at end of list
|
||||
while len(arr) > 0:
|
||||
if arr[-1] == "":
|
||||
arr.pop()
|
||||
else:
|
||||
break
|
||||
return arr
|
||||
|
||||
|
||||
def chain_model(args):
|
||||
# Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op
|
||||
encoder_model = onnx.load_model(args.encoder_path, load_external_data=True)
|
||||
encoder_model.graph.name = "encoderdecoderinit subgraph"
|
||||
|
||||
decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
|
||||
decoder_model.graph.name = "decoder subgraph"
|
||||
|
||||
config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
|
||||
# Create inputs/outputs for WhisperBeamSearch op
|
||||
temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
|
||||
beam_inputs = [
|
||||
"input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features",
|
||||
"max_length",
|
||||
"min_length",
|
||||
"num_beams",
|
||||
"num_return_sequences",
|
||||
"length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty",
|
||||
"repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty",
|
||||
"vocab_mask" if args.use_vocab_mask else "",
|
||||
"prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
|
||||
"", # attention mask
|
||||
"decoder_input_ids" if args.use_forced_decoder_ids else "",
|
||||
"logits_processor" if args.use_logits_processor else "",
|
||||
"cross_qk_layer_head" if args.collect_cross_qk else "",
|
||||
"extra_decoding_ids" if args.extra_decoding_ids else "",
|
||||
temperature_name if args.use_temperature else "",
|
||||
]
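# Note: empty strings mark optional inputs of the WhisperBeamSearch op that are intentionally left unset;
# clean_list(..., remove_all_strings=False) below only trims the trailing ones.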
|
||||
|
||||
sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores"
|
||||
scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores"
|
||||
beam_outputs = [
|
||||
"sequences",
|
||||
sequence_scores_name if args.output_sequence_scores else "",
|
||||
scores_name if args.output_scores else "",
|
||||
"cross_qk" if args.collect_cross_qk else "",
|
||||
"no_speech_probs_beam" if args.output_no_speech_probs else "",
|
||||
]
|
||||
|
||||
graph_nodes = []
|
||||
if args.precision == Precision.FLOAT16:
|
||||
input_features_cast_node = helper.make_node(
|
||||
"Cast",
|
||||
inputs=["input_features"],
|
||||
outputs=["input_features_fp16"],
|
||||
name="CastInputFeaturesToFp16",
|
||||
to=TensorProto.FLOAT16,
|
||||
)
|
||||
len_pen_cast_node = helper.make_node(
|
||||
"Cast",
|
||||
inputs=["length_penalty"],
|
||||
outputs=["length_penalty_fp16"],
|
||||
name="CastLengthPenaltyToFp16",
|
||||
to=TensorProto.FLOAT16,
|
||||
)
|
||||
rep_pen_cast_node = helper.make_node(
|
||||
"Cast",
|
||||
inputs=["repetition_penalty"],
|
||||
outputs=["repetition_penalty_fp16"],
|
||||
name="CastRepetitionPenaltyToFp16",
|
||||
to=TensorProto.FLOAT16,
|
||||
)
|
||||
graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])
|
||||
|
||||
if args.use_temperature:
|
||||
temp_cast_node = helper.make_node(
|
||||
"Cast",
|
||||
inputs=["temperature"],
|
||||
outputs=["temperature_fp16"],
|
||||
name="temperature_to_fp16",
|
||||
to=TensorProto.FLOAT16,
|
||||
)
|
||||
graph_nodes.append(temp_cast_node)
|
||||
|
||||
if args.output_sequence_scores:
|
||||
output_sequence_scores_cast_node = helper.make_node(
|
||||
"Cast",
|
||||
inputs=["sequence_scores_fp16"],
|
||||
outputs=["sequence_scores"],
|
||||
name="CastOutputSequenceScoresToFp32",
|
||||
to=TensorProto.FLOAT,
|
||||
)
|
||||
graph_nodes.append(output_sequence_scores_cast_node)
|
||||
|
||||
if args.output_scores:
|
||||
output_scores_cast_node = helper.make_node(
|
||||
"Cast",
|
||||
inputs=["scores_fp16"],
|
||||
outputs=["scores"],
|
||||
name="CastScoresToFp32",
|
||||
to=TensorProto.FLOAT,
|
||||
)
|
||||
graph_nodes.append(output_scores_cast_node)
|
||||
|
||||
# Create WhisperBeamSearch op
|
||||
beam_search_attrs = [
|
||||
helper.make_attribute("eos_token_id", config.eos_token_id),
|
||||
helper.make_attribute("pad_token_id", config.pad_token_id),
|
||||
helper.make_attribute(
|
||||
"decoder_start_token_id", config.decoder_start_token_id
|
||||
), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0]
|
||||
helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
|
||||
helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]),
|
||||
helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]),
|
||||
(
|
||||
helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0])
|
||||
if args.output_no_speech_probs
|
||||
else ""
|
||||
),
|
||||
helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]),
|
||||
helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]),
|
||||
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
helper.make_attribute("early_stopping", True),
|
||||
helper.make_attribute("model_type", 2),
|
||||
helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "",
|
||||
]
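# Note: model_type=2 above selects the Whisper variant of the beam search implementation
# (other values select GPT-2/T5 style models in the com.microsoft generation ops).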
|
||||
node = helper.make_node(
|
||||
"WhisperBeamSearch",
|
||||
inputs=clean_list(beam_inputs, remove_all_strings=False),
|
||||
outputs=clean_list(beam_outputs, remove_all_strings=False),
|
||||
name="BeamSearch",
|
||||
domain="com.microsoft",
|
||||
)
|
||||
node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True))
|
||||
|
||||
# Graph inputs
|
||||
input_features = helper.make_tensor_value_info(
|
||||
"input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
|
||||
)
|
||||
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
|
||||
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
|
||||
num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
|
||||
num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
|
||||
length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
|
||||
repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
|
||||
vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
|
||||
prefix_vocab_mask = helper.make_tensor_value_info(
|
||||
"prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
|
||||
)
|
||||
decoder_input_ids = helper.make_tensor_value_info(
|
||||
"decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
|
||||
)
|
||||
logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
|
||||
cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2])
|
||||
extra_decoding_ids = helper.make_tensor_value_info(
|
||||
"extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
|
||||
)
|
||||
temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
|
||||
|
||||
graph_inputs = clean_list(
|
||||
[
|
||||
input_features,
|
||||
max_length,
|
||||
min_length,
|
||||
num_beams,
|
||||
num_return_sequences,
|
||||
length_penalty,
|
||||
repetition_penalty,
|
||||
vocab_mask if args.use_vocab_mask else "",
|
||||
prefix_vocab_mask if args.use_prefix_vocab_mask else "",
|
||||
decoder_input_ids if args.use_forced_decoder_ids else "",
|
||||
logits_processor if args.use_logits_processor else "",
|
||||
cross_qk_layer_head if args.collect_cross_qk else "",
|
||||
extra_decoding_ids if args.extra_decoding_ids else "",
|
||||
temperature if args.use_temperature else "",
|
||||
]
|
||||
)
|
||||
|
||||
# Graph outputs
|
||||
sequences = helper.make_tensor_value_info(
|
||||
"sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
|
||||
)
|
||||
sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
|
||||
scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
|
||||
cross_qk = helper.make_tensor_value_info(
|
||||
"cross_qk",
|
||||
TensorProto.FLOAT,
|
||||
["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
|
||||
)
|
||||
no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])
|
||||
|
||||
graph_outputs = clean_list(
|
||||
[
|
||||
sequences,
|
||||
sequence_scores if args.output_sequence_scores else "",
|
||||
scores if args.output_scores else "",
|
||||
cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "",
|
||||
no_speech_probs if args.output_no_speech_probs else "",
|
||||
]
|
||||
)
|
||||
|
||||
# Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
|
||||
if hasattr(args, "use_gpu") and args.use_gpu:
|
||||
if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
|
||||
logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
|
||||
else:
|
||||
logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
|
||||
if hasattr(args, "collect_cross_qk") and args.collect_cross_qk:
|
||||
update_decoder_subgraph_output_cross_attention(decoder_model.graph)
|
||||
|
||||
# Initializers/opsets
|
||||
# Delete shared data between decoder/encoder and move to larger graph initializers
|
||||
initializers = get_shared_initializers(encoder_model, decoder_model)
|
||||
node.attribute.extend(
|
||||
[
|
||||
helper.make_attribute("decoder", decoder_model.graph),
|
||||
helper.make_attribute("encoder", encoder_model.graph),
|
||||
]
|
||||
)
|
||||
|
||||
opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
|
||||
|
||||
graph_nodes.append(node)
|
||||
if args.output_no_speech_probs:
|
||||
prob_cast_node = helper.make_node(
|
||||
"Cast",
|
||||
inputs=["no_speech_probs_beam"],
|
||||
outputs=["no_speech_probs"],
|
||||
name="no_speech_probs_cast_to_fp32",
|
||||
to=TensorProto.FLOAT,
|
||||
)
|
||||
graph_nodes.append(prob_cast_node)
|
||||
|
||||
# Make graph with WhisperBeamSearch op
|
||||
beam_graph = helper.make_graph(
|
||||
graph_nodes,
|
||||
name="WhisperBeamSearch Graph",
|
||||
inputs=graph_inputs,
|
||||
outputs=graph_outputs,
|
||||
initializer=initializers,
|
||||
)
|
||||
beam_graph_input_names = [gi.name for gi in graph_inputs]
|
||||
beam_graph_output_names = [go.name for go in graph_outputs]
|
||||
|
||||
if args.cross_qk_onnx_model:
|
||||
post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True)
|
||||
post_qk_graph = post_qk_model.graph
|
||||
beam_graph.initializer.extend(post_qk_graph.initializer)
|
||||
beam_graph.node.extend(post_qk_graph.node)
|
||||
# If a tensor from cross_qk_onnx_model has the same name as a tensor in the beam search graph, they are treated as the same tensor.
|
||||
# Users should keep this rule in mind when providing a cross_qk_onnx_model to append to the beam search node.
|
||||
for pgi in post_qk_graph.input:
|
||||
if (
|
||||
(pgi.name not in beam_graph_input_names)
|
||||
and (pgi.name not in beam_graph_output_names)
|
||||
and (pgi.name != "cross_qk")
|
||||
):
|
||||
beam_graph.input.extend([pgi])
|
||||
beam_graph.output.extend(post_qk_graph.output)
|
||||
|
||||
# Verify graph's inputs match beam search's inputs
|
||||
verify_inputs(beam_inputs, graph_inputs)
|
||||
|
||||
assert decoder_model.ir_version == encoder_model.ir_version
|
||||
logger.info(f"Using IR version {decoder_model.ir_version} for chained model")
|
||||
|
||||
# Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
|
||||
beam_model = helper.make_model_gen_version(
|
||||
beam_graph,
|
||||
producer_name="onnxruntime.transformers",
|
||||
opset_imports=opset_import,
|
||||
ir_version=decoder_model.ir_version,
|
||||
)
|
||||
|
||||
# Save WhisperBeamSearch graph and external data
|
||||
if os.path.isfile(args.beam_model_output_dir):
|
||||
logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
|
||||
if os.path.exists(args.beam_model_output_dir):
|
||||
os.remove(args.beam_model_output_dir)
|
||||
if os.path.exists(args.beam_model_output_dir + ".data"):
|
||||
os.remove(args.beam_model_output_dir + ".data")
|
||||
|
||||
onnx.save(
|
||||
beam_model,
|
||||
args.beam_model_output_dir,
|
||||
save_as_external_data=args.use_external_data_format,
|
||||
all_tensors_to_one_file=True,
|
||||
convert_attribute=True,
|
||||
location=f"{os.path.basename(args.beam_model_output_dir)}.data",
|
||||
)
|
||||
try:
|
||||
onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while running the ONNX checker: {e}", exc_info=True) # noqa: G201
|
||||
+464
@@ -0,0 +1,464 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from float16 import convert_float_to_float16
|
||||
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
|
||||
from onnx import ModelProto, ValueInfoProto
|
||||
from onnx_model import OnnxModel
|
||||
from past_helper import PastKeyValuesHelper
|
||||
from transformers import WhisperConfig
|
||||
from whisper_inputs import (
|
||||
convert_inputs_for_ort,
|
||||
get_model_dynamic_axes,
|
||||
get_sample_decoder_inputs,
|
||||
group_past_key_values,
|
||||
)
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhisperDecoder(torch.nn.Module):
|
||||
"""A Whisper decoder with optional past key values"""
|
||||
|
||||
def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.device = model.device
|
||||
self.model_impl = model_impl
|
||||
self.no_beam_search_op = no_beam_search_op
|
||||
|
||||
self.decoder = None if model_impl == "openai" else model.model.decoder
|
||||
self.proj_out = None if model_impl == "openai" else model.proj_out
|
||||
self.model = model if model_impl == "openai" else None
|
||||
|
||||
self.max_source_positions = self.config.max_source_positions
|
||||
self.num_heads = self.config.decoder_attention_heads
|
||||
self.head_size = self.config.d_model // self.num_heads
|
||||
|
||||
def hf_forward(
|
||||
self,
|
||||
decoder_input_ids: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
past_key_values: list[tuple[torch.Tensor]] | None = None,
|
||||
):
|
||||
outputs = self.decoder(
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
input_ids=decoder_input_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
logits = self.proj_out(outputs.last_hidden_state)
|
||||
present_key_values = outputs.past_key_values
|
||||
|
||||
if past_key_values is None:
|
||||
# Return present_self_* and present_cross_* for decoder-init
|
||||
return logits, present_key_values
|
||||
|
||||
# Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
|
||||
# (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
|
||||
# After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1), ...,
|
||||
# (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1), ...
|
||||
present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(present_key_values)
|
||||
|
||||
# Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
|
||||
return logits, present_self
|
||||
|
||||
def oai_forward(
|
||||
self,
|
||||
decoder_input_ids: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
past_key_values: list[tuple[torch.Tensor]] | None = None,
|
||||
):
|
||||
past_kv_cache = {}
|
||||
if past_key_values is not None:
|
||||
# Convert past KV caches (BxNxSxH --> BxSxNxH --> BxSxD) for OpenAI's forward pass
|
||||
self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
|
||||
self_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in self_attn_kv_caches]
|
||||
self_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in self_attn_kv_caches]
|
||||
cross_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in cross_attn_kv_caches]
|
||||
cross_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in cross_attn_kv_caches]
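# For example, for whisper-tiny (d_model=384, 6 decoder heads, head_size=64):
# (batch, 6, seq, 64) -> (batch, seq, 6, 64) -> (batch, seq, 384)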
|
||||
|
||||
for idx, block in enumerate(self.model.decoder.blocks):
|
||||
past_kv_cache[block.attn.key] = self_attn_kv_caches[2 * idx]
|
||||
past_kv_cache[block.attn.value] = self_attn_kv_caches[2 * idx + 1]
|
||||
past_kv_cache[block.cross_attn.key] = cross_attn_kv_caches[2 * idx]
|
||||
past_kv_cache[block.cross_attn.value] = cross_attn_kv_caches[2 * idx + 1]
|
||||
|
||||
# Install OpenAI's hooks on the forward pass of each nn.Linear for key and value
|
||||
# since the hooks will capture the output of the key and value MatMuls, which
|
||||
# represent the current keys and values.
|
||||
#
|
||||
# For OpenAI's forward pass, the hook function will also perform the concat
|
||||
# operation (past_kv + curr_kv --> pres_kv) if needed. However, the ONNX model
|
||||
# will not contain this concat operation because the present KV caches aren't
|
||||
# returned by OpenAI's forward pass.
|
||||
kv_cache, hooks = self.model.install_kv_cache_hooks()
|
||||
|
||||
# Run forward pass
|
||||
# NOTE: There is a bug with openai-whisper==20240930 with the introduction of SDPA.
|
||||
# In the Whisper codebase, the following line
|
||||
#
|
||||
# is_causal = mask is not None and n_ctx > 1
|
||||
#
|
||||
# has been added where `mask` is a torch tensor. The right-hand side evaluates to `tensor(True/False)`
|
||||
# but `is_causal` only accepts the boolean value. The fix is to apply `.item()` after the right-hand
|
||||
# side has been evaluated. In other words, the line should be
|
||||
#
|
||||
# is_causal = (mask is not None and n_ctx > 1).item()
|
||||
#
|
||||
# instead.
|
||||
logits = self.model.decoder(x=decoder_input_ids, xa=encoder_hidden_states, kv_cache=past_kv_cache)
|
||||
|
||||
# Re-do concat operation on self attention KV caches for ONNX export (if past self attention KV caches exist)
|
||||
if past_key_values is not None:
|
||||
for block in self.model.decoder.blocks:
|
||||
kv_cache[block.attn.key] = torch.cat(
|
||||
[past_kv_cache[block.attn.key], kv_cache[block.attn.key]], dim=1
|
||||
).detach()
|
||||
kv_cache[block.attn.value] = torch.cat(
|
||||
[past_kv_cache[block.attn.value], kv_cache[block.attn.value]], dim=1
|
||||
).detach()
|
||||
|
||||
present_self, present_cross = [], []
|
||||
for block in self.model.decoder.blocks:
|
||||
# Group self and cross values
|
||||
present_self.append(kv_cache[block.attn.key])
|
||||
present_self.append(kv_cache[block.attn.value])
|
||||
if past_key_values is None:
|
||||
# Return present_self_* and present_cross_* for decoder-init
|
||||
present_cross.append(kv_cache[block.cross_attn.key])
|
||||
present_cross.append(kv_cache[block.cross_attn.value])
|
||||
|
||||
# Convert present KV caches (BxSxD --> BxSxNxH --> BxNxSxH) after OpenAI's forward pass
|
||||
present_self = [
|
||||
present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
|
||||
for present_kv in present_self
|
||||
]
|
||||
present_cross = [
|
||||
present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
|
||||
for present_kv in present_cross
|
||||
]
|
||||
|
||||
# Remove OpenAI's hooks since they can persist after this function completes
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
if past_key_values is None:
|
||||
# Return present_self_* and present_cross_* for decoder-init
|
||||
present_key_values = PastKeyValuesHelper.group_by_layer(
|
||||
present_self + present_cross, len(present_self) // 2
|
||||
)
|
||||
return logits, present_key_values
|
||||
|
||||
# Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
|
||||
return logits, present_self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_input_ids: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
past_key_values: list[tuple[torch.Tensor]] | None = None,
|
||||
):
|
||||
if self.model_impl == "openai":
|
||||
return self.oai_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
|
||||
return self.hf_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
|
||||
|
||||
def input_names(self):
|
||||
if self.first_pass:
|
||||
input_names = ["input_ids", "encoder_hidden_states"]
|
||||
else:
|
||||
input_names = [
|
||||
"input_ids",
|
||||
"encoder_hidden_states",
|
||||
*list(
|
||||
chain.from_iterable(
|
||||
(f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
|
||||
for i in range(self.config.decoder_layers)
|
||||
)
|
||||
),
|
||||
]
|
||||
return input_names
|
||||
|
||||
def output_names(self):
|
||||
if self.first_pass:
|
||||
output_names = [
|
||||
"logits",
|
||||
*list(
|
||||
chain.from_iterable(
|
||||
(
|
||||
f"present_key_self_{i}",
|
||||
f"present_value_self_{i}",
|
||||
f"present_key_cross_{i}",
|
||||
f"present_value_cross_{i}",
|
||||
)
|
||||
for i in range(self.config.decoder_layers)
|
||||
)
|
||||
),
|
||||
]
|
||||
else:
|
||||
output_names = [
|
||||
"logits",
|
||||
*list(
|
||||
chain.from_iterable(
|
||||
(f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
|
||||
)
|
||||
),
|
||||
]
|
||||
return output_names
|
||||
|
||||
def dynamic_axes(self, input_names, output_names):
|
||||
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
|
||||
if "input_ids" in dynamic_axes and not self.no_beam_search_op:
|
||||
# Set dynamic axes for `input_ids` when using beam search op to {0: "batch_size"} only
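# e.g. {"input_ids": {0: "batch_size", 1: "sequence_length"}} becomes {"input_ids": {0: "batch_size"}}
# (axis names here are illustrative)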
|
||||
del dynamic_axes["input_ids"][1]
|
||||
return dynamic_axes
|
||||
|
||||
def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
|
||||
inputs = get_sample_decoder_inputs(
|
||||
self.config,
|
||||
self.device,
|
||||
batch_size=2,
|
||||
past_sequence_length=(0 if self.first_pass else 6),
|
||||
sequence_length=(6 if self.first_pass else 1),
|
||||
use_fp16=use_fp16_inputs,
|
||||
use_int32=use_int32_inputs,
|
||||
)
|
||||
if return_dict:
|
||||
if self.first_pass:
|
||||
del inputs["past_key_values"]
|
||||
return inputs
|
||||
|
||||
if self.first_pass:
|
||||
return (
|
||||
inputs["decoder_input_ids"],
|
||||
inputs["encoder_hidden_states"],
|
||||
)
|
||||
return (
|
||||
inputs["decoder_input_ids"],
|
||||
inputs["encoder_hidden_states"],
|
||||
inputs["past_key_values"],
|
||||
)
|
||||
|
||||
def fix_key_value_cache_dims(self, io: ValueInfoProto, is_cross: bool = False, is_output: bool = False):
|
||||
# Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
|
||||
# and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
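# Dim names containing "_dim_" are auto-generated symbolic dims from the ONNX exporter
# (e.g. "Transposepresent_value_self_1_dim_2"); replace them with concrete values or stable names.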
|
||||
num_heads = io.type.tensor_type.shape.dim[1]
|
||||
if "_dim_" in num_heads.dim_param:
|
||||
num_heads.Clear()
|
||||
num_heads.dim_value = self.num_heads
|
||||
sequence_length = io.type.tensor_type.shape.dim[2]
|
||||
if "_dim_" in sequence_length.dim_param:
|
||||
sequence_length.Clear()
|
||||
if is_cross:
|
||||
sequence_length.dim_value = self.max_source_positions
|
||||
else:
|
||||
sequence_length.dim_param = "total_sequence_length" if is_output else "past_sequence_length"
|
||||
head_size = io.type.tensor_type.shape.dim[3]
|
||||
if "_dim_" in head_size.dim_param:
|
||||
head_size.Clear()
|
||||
head_size.dim_value = self.head_size
|
||||
return io
|
||||
|
||||
def fix_io(self, io_list: RepeatedCompositeFieldContainer, is_output: bool = False):
|
||||
# Fix order of inputs/outputs and each dim_value of input/output
|
||||
reordered_io = []
|
||||
self_attn_kv_caches = []
|
||||
cross_attn_kv_caches = []
|
||||
|
||||
for io in io_list:
|
||||
if "past" not in io.name and "present" not in io.name:
|
||||
reordered_io.append(io)
|
||||
elif "self" in io.name:
|
||||
# Self attention KV caches
|
||||
new_io = self.fix_key_value_cache_dims(io, is_cross=False, is_output=is_output)
|
||||
if self.no_beam_search_op:
|
||||
reordered_io.append(new_io)
|
||||
else:
|
||||
self_attn_kv_caches.append(new_io)
|
||||
else:
|
||||
# Cross attention KV caches
|
||||
new_io = self.fix_key_value_cache_dims(io, is_cross=True, is_output=is_output)
|
||||
if self.no_beam_search_op:
|
||||
reordered_io.append(new_io)
|
||||
else:
|
||||
cross_attn_kv_caches.append(new_io)
|
||||
|
||||
if not self.no_beam_search_op:
|
||||
reordered_io += self_attn_kv_caches + cross_attn_kv_caches
|
||||
return reordered_io
|
||||
|
||||
def fix_inputs_and_outputs(self, model: ModelProto):
|
||||
# ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
|
||||
# We now change these dim_values to the correct ones.
|
||||
reordered_inputs = self.fix_io(model.graph.input, is_output=False)
|
||||
while len(model.graph.input) > 0:
|
||||
model.graph.input.pop()
|
||||
model.graph.input.extend(reordered_inputs)
|
||||
|
||||
reordered_outputs = self.fix_io(model.graph.output, is_output=True)
|
||||
while len(model.graph.output) > 0:
|
||||
model.graph.output.pop()
|
||||
model.graph.output.extend(reordered_outputs)
|
||||
return model
|
||||
|
||||
def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
|
||||
if self.model_impl == "openai" and use_fp16_inputs:
|
||||
# Cast ONNX model to float16 to ensure LayerNorm weights are converted from
|
||||
# float32 to float16 since exported model already has float16 weights everywhere
|
||||
# except for LayerNorm ops. This happens because OpenAI always upcasts to float32
|
||||
# when computing LayerNorm.
|
||||
#
|
||||
# Reference:
|
||||
# https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
|
||||
model = convert_float_to_float16(model)
|
||||
return model
|
||||
|
||||
def export_onnx(
|
||||
self,
|
||||
onnx_model_path: str,
|
||||
provider: str,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_fp16_inputs: bool = False,
|
||||
use_int32_inputs: bool = True,
|
||||
use_encoder_hidden_states: bool = False,
|
||||
use_kv_cache_inputs: bool = True,
|
||||
):
|
||||
"""Export decoder to ONNX
|
||||
|
||||
Args:
|
||||
onnx_model_path (str): path to save ONNX model
|
||||
provider (str): provider to use for verifying parity on ONNX model
|
||||
verbose (bool, optional): print verbose information. Defaults to True.
|
||||
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
|
||||
use_fp16_inputs (bool, optional): use float16 inputs for the KV caches. Defaults to False.
|
||||
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
|
||||
use_encoder_hidden_states (bool, optional): use encoder_hidden_states as model input for decoder-init/decoder-without-past models. Defaults to False.
|
||||
use_kv_cache_inputs (bool, optional): use KV caches as model inputs for decoder-with-past models. Defaults to True.
|
||||
"""
|
||||
# Shape of decoder's tensors:
|
||||
# Required Inputs:
|
||||
# decoder_input_ids: (batch_size, sequence_length)
|
||||
# Optional Inputs:
|
||||
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
|
||||
# past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
|
||||
# past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
|
||||
# Outputs:
|
||||
# logits: (batch_size, sequence_length, vocab_size)
|
||||
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
|
||||
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
|
||||
|
||||
# For the first pass through the decoder (i.e. decoder-init/decoder-without-past)
|
||||
self.first_pass = use_encoder_hidden_states and not use_kv_cache_inputs
|
||||
|
||||
# For subsequent passes through the decoder (i.e. decoder-with-past)
|
||||
self.later_pass = not use_encoder_hidden_states and use_kv_cache_inputs
|
||||
|
||||
assert self.first_pass or self.later_pass, (
|
||||
"Only one of `use_encoder_hidden_states` and `use_kv_cache_inputs` can be true at once."
|
||||
)
|
||||
|
||||
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
|
||||
input_names = self.input_names()
|
||||
output_names = self.output_names()
|
||||
dynamic_axes = self.dynamic_axes(input_names, output_names)
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
|
||||
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
|
||||
|
||||
torch.onnx.export(
|
||||
self,
|
||||
args=inputs,
|
||||
f=out_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=17,
|
||||
do_constant_folding=True,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
|
||||
model = self.fix_inputs_and_outputs(model)
|
||||
model = self.fix_layernorm_weights(model, use_fp16_inputs)
|
||||
OnnxModel.save(
|
||||
model,
|
||||
onnx_model_path,
|
||||
save_as_external_data=use_external_data_format,
|
||||
all_tensors_to_one_file=True,
|
||||
)
|
||||
|
||||
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
|
||||
|
||||
def verify_onnx(
|
||||
self,
|
||||
onnx_model_path: str,
|
||||
provider: str,
|
||||
use_fp16_inputs: bool,
|
||||
use_int32_inputs: bool,
|
||||
):
|
||||
"""Verify ONNX model outputs and PyTorch model outputs match
|
||||
|
||||
Args:
|
||||
onnx_model_path (str): path to save ONNX model
|
||||
provider (str): execution provider for ONNX model
|
||||
use_fp16_inputs (bool, optional): use float16 inputs for the KV caches
|
||||
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
|
||||
"""
|
||||
# Shape of decoder's tensors:
|
||||
# Required Inputs:
|
||||
# decoder_input_ids: (batch_size, sequence_length)
|
||||
# Optional Inputs:
|
||||
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
|
||||
# past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
|
||||
# past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
|
||||
# Outputs:
|
||||
# logits: (batch_size, sequence_length, vocab_size)
|
||||
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
|
||||
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
|
||||
|
||||
# Run PyTorch model
|
||||
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
|
||||
pt_outputs = []
|
||||
if self.first_pass:
|
||||
out = self.forward(**inputs)
|
||||
pt_outputs.append(out[0].detach().cpu().numpy())
|
||||
for present_key_value_layer in out[1]:
|
||||
for present_key_value in present_key_value_layer:
|
||||
pt_outputs.append(present_key_value.detach().cpu().numpy())
|
||||
else:
|
||||
out = self.forward(**inputs)
|
||||
pt_outputs.append(out[0].detach().cpu().numpy())
|
||||
for present_self_key_value in out[1]:
|
||||
pt_outputs.append(present_self_key_value.detach().cpu().numpy())
|
||||
|
||||
# Run ONNX model
|
||||
sess = InferenceSession(onnx_model_path, providers=[provider])
|
||||
ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
|
||||
|
||||
# Calculate output difference
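# (best-effort: comparison errors are swallowed by the bare except below so export can proceed)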
|
||||
try:
|
||||
for i, output_name in enumerate(self.output_names()):
|
||||
diff = np.abs(pt_outputs[i] - ort_outputs[i])
|
||||
logger.warning(f"Comparing {output_name}...")
|
||||
logger.warning(f"Max diff: {np.max(diff)}")
|
||||
except: # noqa: E722
|
||||
pass
|
||||
+164
@@ -0,0 +1,164 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from float16 import convert_float_to_float16
|
||||
from onnx import ModelProto
|
||||
from onnx_model import OnnxModel
|
||||
from transformers import WhisperConfig
|
||||
from whisper_inputs import get_model_dynamic_axes, get_sample_encoder_inputs
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhisperEncoder(torch.nn.Module):
|
||||
"""Whisper encoder component"""
|
||||
|
||||
def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.device = model.device
|
||||
self.model_impl = model_impl
|
||||
|
||||
self.encoder = model.encoder if model_impl == "openai" else model.model.encoder
|
||||
|
||||
def forward(self, audio_features: torch.Tensor):
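# The Hugging Face encoder returns a model output object (use .last_hidden_state),
# while the OpenAI encoder returns the hidden states tensor directly.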
|
||||
outputs = self.encoder(audio_features)
|
||||
return outputs if self.model_impl == "openai" else outputs.last_hidden_state
|
||||
|
||||
def input_names(self):
|
||||
input_names = ["audio_features"]
|
||||
return input_names
|
||||
|
||||
def output_names(self):
|
||||
output_names = ["encoder_hidden_states"]
|
||||
return output_names
|
||||
|
||||
def dynamic_axes(self, input_names, output_names):
|
||||
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
|
||||
return dynamic_axes
|
||||
|
||||
def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
|
||||
if self.model_impl == "openai" and use_fp16_inputs:
|
||||
# Cast ONNX model to float16 to ensure LayerNorm weights are converted from
|
||||
# float32 to float16 since exported model already has float16 weights everywhere
|
||||
# except for LayerNorm ops. This happens because OpenAI always upcasts to float32
|
||||
# when computing LayerNorm.
|
||||
#
|
||||
# Reference:
|
||||
# https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
|
||||
model = convert_float_to_float16(model)
|
||||
return model
|
||||
|
||||
def export_onnx(
|
||||
self,
|
||||
onnx_model_path: str,
|
||||
provider: str,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_fp16_inputs: bool = False,
|
||||
):
|
||||
"""Export encoder to ONNX
|
||||
|
||||
Args:
|
||||
onnx_model_path (str): path to save ONNX model
|
||||
provider (str): provider to use for verifying parity on ONNX model
|
||||
verbose (bool, optional): print verbose information. Defaults to True.
|
||||
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
|
||||
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
|
||||
"""
|
||||
# Shape of encoder's tensors:
|
||||
# Inputs:
|
||||
# audio_features: (batch_size, num_mels, num_frames)
|
||||
# Outputs:
|
||||
# encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
|
||||
|
||||
inputs = get_sample_encoder_inputs(
|
||||
self.config,
|
||||
self.device,
|
||||
batch_size=2,
|
||||
use_fp16=use_fp16_inputs,
|
||||
)
|
||||
|
||||
input_names = self.input_names()
|
||||
output_names = self.output_names()
|
||||
dynamic_axes = self.dynamic_axes(input_names, output_names)
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
|
||||
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
|
||||
|
||||
torch.onnx.export(
|
||||
self,
|
||||
args=(inputs["audio_features"]),
|
||||
f=out_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=17,
|
||||
do_constant_folding=True,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
|
||||
model = self.fix_layernorm_weights(model, use_fp16_inputs)
|
||||
OnnxModel.save(
|
||||
model,
|
||||
onnx_model_path,
|
||||
save_as_external_data=use_external_data_format,
|
||||
all_tensors_to_one_file=True,
|
||||
)
|
||||
|
||||
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs)
|
||||
|
||||
def verify_onnx(
|
||||
self,
|
||||
onnx_model_path: str,
|
||||
provider: str,
|
||||
use_fp16_inputs: bool,
|
||||
):
|
||||
"""Verify ONNX model outputs and PyTorch model outputs match
|
||||
|
||||
Args:
|
||||
onnx_model_path (str): path to save ONNX model
|
||||
provider (str): execution provider for ONNX model
|
||||
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features
|
||||
"""
|
||||
# Shape of encoder's tensors:
|
||||
# Inputs:
|
||||
# audio_features: (batch_size, num_mels, num_frames)
|
||||
# Outputs:
|
||||
# encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
|
||||
inputs = get_sample_encoder_inputs(
|
||||
self.config,
|
||||
self.device,
|
||||
batch_size=2,
|
||||
use_fp16=use_fp16_inputs,
|
||||
)
|
||||
|
||||
# Run PyTorch model
|
||||
pt_outputs = self.forward(inputs["audio_features"]).detach().cpu().numpy()
|
||||
|
||||
# Run ONNX model
|
||||
sess = InferenceSession(onnx_model_path, providers=[provider])
|
||||
ort_outputs = sess.run(None, {"audio_features": inputs["audio_features"].detach().cpu().numpy()})[0]
|
||||
|
||||
# Calculate output difference
|
||||
diff = np.abs(pt_outputs - ort_outputs)
|
||||
logger.warning("Comparing encoder_hidden_states...")
|
||||
logger.warning(f"Max diff: {np.max(diff)}")
|
||||
+371
@@ -0,0 +1,371 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from float16 import convert_float_to_float16
|
||||
from onnx import ModelProto, ValueInfoProto
|
||||
from onnx_model import OnnxModel
|
||||
from transformers import WhisperConfig
|
||||
from whisper_decoder import WhisperDecoder
|
||||
from whisper_encoder import WhisperEncoder
|
||||
from whisper_inputs import (
|
||||
convert_inputs_for_ort,
|
||||
get_model_dynamic_axes,
|
||||
get_sample_encoder_decoder_init_inputs,
|
||||
group_past_key_values,
|
||||
)
|
||||
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WhisperEncoderDecoderInit(torch.nn.Module):
|
||||
"""Whisper encoder component + first pass through Whisper decoder component to initialize KV caches"""
|
||||
|
||||
def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.device = model.device
|
||||
self.model_impl = model_impl
|
||||
self.no_beam_search_op = no_beam_search_op
|
||||
|
||||
self.encoder = WhisperEncoder(config, model, model_impl)
|
||||
self.decoder = WhisperDecoder(config, model, model_impl, no_beam_search_op)
|
||||
|
||||
self.max_source_positions = self.config.max_source_positions
|
||||
self.num_heads = self.config.decoder_attention_heads
|
||||
self.head_size = self.config.d_model // self.num_heads
|
||||
|
||||
def hf_forward_for_beam_search_op(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor):
|
||||
encoder_hidden_states = self.encoder(audio_features)
|
||||
logits, present_key_values = self.decoder(decoder_input_ids, encoder_hidden_states)
|
||||
return logits, encoder_hidden_states, present_key_values
|
||||
|
||||
def hf_forward_for_no_beam_search_op(self, audio_features: torch.Tensor):
|
||||
encoder_hidden_states = self.encoder(audio_features)
|
||||
|
||||
# Get cross attention KV caches and return them for this model
|
||||
# We do this because these MatMuls are only run once before their outputs are being re-used in the decoder
|
||||
present_cross_attention_key_value_caches = []
|
||||
for layer in self.decoder.decoder.layers:
|
||||
cross_attn_key_cache = (
|
||||
layer.encoder_attn.k_proj(encoder_hidden_states)
|
||||
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
cross_attn_value_cache = (
|
||||
layer.encoder_attn.v_proj(encoder_hidden_states)
|
||||
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
present_cross_attention_key_value_caches.append(cross_attn_key_cache)
|
||||
present_cross_attention_key_value_caches.append(cross_attn_value_cache)
|
||||
|
||||
return encoder_hidden_states, present_cross_attention_key_value_caches
|
||||
|
||||
def oai_forward_for_beam_search_op(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor):
|
||||
encoder_hidden_states = self.encoder(audio_features)
|
||||
logits, present_key_values = self.decoder(decoder_input_ids, encoder_hidden_states)
|
||||
return logits, encoder_hidden_states, present_key_values
|
||||
|
||||
def oai_forward_for_no_beam_search_op(self, audio_features: torch.Tensor):
|
||||
encoder_hidden_states = self.encoder(audio_features)
|
||||
|
||||
# Get cross attention KV caches and return them for this model
|
||||
# We do this because these MatMuls are only run once before their outputs are being re-used in the decoder
|
||||
present_cross_attention_key_value_caches = []
|
||||
for block in self.decoder.model.decoder.blocks:
|
||||
cross_attn_key_cache = (
|
||||
block.cross_attn.key(encoder_hidden_states)
|
||||
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
cross_attn_value_cache = (
|
||||
block.cross_attn.value(encoder_hidden_states)
|
||||
.view(-1, self.max_source_positions, self.num_heads, self.head_size)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
present_cross_attention_key_value_caches.append(cross_attn_key_cache)
|
||||
present_cross_attention_key_value_caches.append(cross_attn_value_cache)
|
||||
|
||||
return encoder_hidden_states, present_cross_attention_key_value_caches
|
||||
|
||||
def forward(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor | None = None):
|
||||
if self.model_impl == "openai":
|
||||
if self.no_beam_search_op:
|
||||
return self.oai_forward_for_no_beam_search_op(audio_features)
|
||||
return self.oai_forward_for_beam_search_op(audio_features, decoder_input_ids)
|
||||
|
||||
# Hugging Face implementation
|
||||
if self.no_beam_search_op:
|
||||
return self.hf_forward_for_no_beam_search_op(audio_features)
|
||||
return self.hf_forward_for_beam_search_op(audio_features, decoder_input_ids)
|
||||
|
||||
def input_names(self):
|
||||
if self.no_beam_search_op:
|
||||
input_names = ["audio_features"]
|
||||
else:
|
||||
input_names = ["encoder_input_ids", "decoder_input_ids"]
|
||||
return input_names
|
||||
|
||||
def output_names(self):
|
||||
if self.no_beam_search_op:
|
||||
output_names = [
|
||||
"encoder_hidden_states",
|
||||
*list(
|
||||
chain.from_iterable(
|
||||
(f"present_key_cross_{i}", f"present_value_cross_{i}")
|
||||
for i in range(self.config.decoder_layers)
|
||||
)
|
||||
),
|
||||
]
|
||||
else:
|
||||
output_names = [
|
||||
"logits",
|
||||
"encoder_hidden_states",
|
||||
*list(
|
||||
chain.from_iterable(
|
||||
(
|
||||
f"present_key_self_{i}",
|
||||
f"present_value_self_{i}",
|
||||
f"present_key_cross_{i}",
|
||||
f"present_value_cross_{i}",
|
||||
)
|
||||
for i in range(self.config.decoder_layers)
|
||||
)
|
||||
),
|
||||
]
|
||||
return output_names
|
||||
|
||||
def dynamic_axes(self, input_names, output_names):
|
||||
dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
|
||||
return dynamic_axes
|
||||
|
||||
def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
|
||||
inputs = get_sample_encoder_decoder_init_inputs(
|
||||
self.config,
|
||||
self.device,
|
||||
batch_size=2,
|
||||
decoder_sequence_length=6,
|
||||
use_fp16=use_fp16_inputs,
|
||||
use_int32=use_int32_inputs,
|
||||
)
|
||||
if return_dict:
|
||||
if self.no_beam_search_op:
|
||||
del inputs["decoder_input_ids"]
|
||||
return inputs
|
||||
|
||||
if self.no_beam_search_op:
|
||||
return (inputs["audio_features"],)
|
||||
return (
|
||||
inputs["audio_features"],
|
||||
inputs["decoder_input_ids"],
|
||||
)
|
||||
|
||||
def fix_key_value_cache_dims(self, output: ValueInfoProto, is_cross: bool = False):
|
||||
# Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
|
||||
# and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
|
||||
num_heads = output.type.tensor_type.shape.dim[1]
|
||||
if "_dim_" in num_heads.dim_param:
|
||||
num_heads.Clear()
|
||||
num_heads.dim_value = self.num_heads
|
||||
sequence_length = output.type.tensor_type.shape.dim[2]
|
||||
if "_dim_" in sequence_length.dim_param:
|
||||
sequence_length.Clear()
|
||||
if is_cross:
|
||||
sequence_length.dim_value = self.max_source_positions
|
||||
else:
|
||||
sequence_length.dim_param = "total_sequence_length"
|
||||
head_size = output.type.tensor_type.shape.dim[3]
|
||||
if "_dim_" in head_size.dim_param:
|
||||
head_size.Clear()
|
||||
head_size.dim_value = self.head_size
|
||||
return output
|
||||
|
||||
def fix_outputs(self, model: ModelProto):
|
||||
# ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
|
||||
# We now change these dim_values to the correct ones.
|
||||
reordered_outputs = []
|
||||
self_attn_kv_caches = []
|
||||
cross_attn_kv_caches = []
|
||||
|
||||
for output in model.graph.output:
|
||||
if "present" not in output.name:
|
||||
reordered_outputs.append(output)
|
||||
|
||||
elif "self" in output.name:
|
||||
# Self attention KV caches
|
||||
new_output = self.fix_key_value_cache_dims(output, is_cross=False)
|
||||
if self.no_beam_search_op:
|
||||
reordered_outputs.append(new_output)
|
||||
else:
|
||||
self_attn_kv_caches.append(new_output)
|
||||
else:
|
||||
# Cross attention KV caches
|
||||
new_output = self.fix_key_value_cache_dims(output, is_cross=True)
|
||||
if self.no_beam_search_op:
|
||||
reordered_outputs.append(new_output)
|
||||
else:
|
||||
cross_attn_kv_caches.append(new_output)
|
||||
|
||||
if not self.no_beam_search_op:
|
||||
reordered_outputs += self_attn_kv_caches + cross_attn_kv_caches
|
||||
|
||||
while len(model.graph.output) > 0:
|
||||
model.graph.output.pop()
|
||||
model.graph.output.extend(reordered_outputs)
|
||||
return model
|
||||
|
||||
def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
|
||||
if self.model_impl == "openai" and use_fp16_inputs:
|
||||
# Cast ONNX model to float16 to ensure LayerNorm weights are converted from
|
||||
# float32 to float16 since exported model already has float16 weights everywhere
|
||||
# except for LayerNorm ops. This happens because OpenAI always upcasts to float32
|
||||
# when computing LayerNorm.
|
||||
#
|
||||
# Reference:
|
||||
# https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
|
||||
model = convert_float_to_float16(model)
|
||||
return model
|
||||
|
||||
def export_onnx(
|
||||
self,
|
||||
onnx_model_path: str,
|
||||
provider: str,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_fp16_inputs: bool = False,
|
||||
use_int32_inputs: bool = True,
|
||||
):
|
||||
"""Export encoder-decoder-init to ONNX
|
||||
|
||||
Args:
|
||||
onnx_model_path (str): path to save ONNX model
|
||||
provider (str): provider to use for verifying parity on ONNX model
|
||||
verbose (bool, optional): print verbose information. Defaults to True.
|
||||
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
|
||||
use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
|
||||
use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
|
||||
"""
|
||||
# Shape of encoder's tensors:
|
||||
# Inputs:
|
||||
# audio_features: (batch_size, num_mels, num_frames)
|
||||
# Outputs:
|
||||
# encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
|
||||
|
||||
# Shape of decoder's tensors:
|
||||
# Inputs:
|
||||
# decoder_input_ids: (batch_size, sequence_length)
|
||||
# encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
|
||||
# Outputs:
|
||||
# logits: (batch_size, sequence_length, vocab_size)
|
||||
# present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
|
||||
# present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
|
||||
|
||||
inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
|
||||
input_names = self.input_names()
|
||||
output_names = self.output_names()
|
||||
dynamic_axes = self.dynamic_axes(input_names, output_names)
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
|
||||
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
|
||||
|
||||
torch.onnx.export(
|
||||
self,
|
||||
args=inputs,
|
||||
f=out_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=17,
|
||||
do_constant_folding=True,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
model = onnx.load_model(out_path, load_external_data=use_external_data_format)
|
||||
model = self.fix_outputs(model)
|
||||
model = self.fix_layernorm_weights(model, use_fp16_inputs)
|
||||
OnnxModel.save(
|
||||
model,
|
||||
onnx_model_path,
|
||||
save_as_external_data=use_external_data_format,
|
||||
all_tensors_to_one_file=True,
|
||||
)
|
||||
|
||||
self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
|
||||
|
||||
    def verify_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        use_fp16_inputs: bool,
        use_int32_inputs: bool,
    ):
        """Verify ONNX model outputs and PyTorch model outputs match

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): execution provider for ONNX model
            use_fp16_inputs (bool, optional): use float16 inputs for the audio_features
            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
        """
        # Shape of encoder's tensors:
        # Inputs:
        #    audio_features: (batch_size, num_mels, num_frames)
        # Outputs:
        #    encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)

        # Shape of decoder's tensors:
        # Inputs:
        #    decoder_input_ids: (batch_size, sequence_length)
        #    encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
        # Outputs:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        #    present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)

        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)

        # Run PyTorch model
        pt_outputs = []
        if self.no_beam_search_op:
            out = self.forward(**inputs)
            pt_outputs.append(out[0].detach().cpu().numpy())
            for present_cross_attn_cache in out[1]:
                pt_outputs.append(present_cross_attn_cache.detach().cpu().numpy())
        else:
            out = self.forward(**inputs)
            pt_outputs.append(out[0].detach().cpu().numpy())
            pt_outputs.append(out[1].detach().cpu().numpy())

            (self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(out[2])
            pt_outputs.extend([self_attn_kv_cache.detach().cpu().numpy() for self_attn_kv_cache in self_attn_kv_caches])
            pt_outputs.extend(
                [cross_attn_kv_cache.detach().cpu().numpy() for cross_attn_kv_cache in cross_attn_kv_caches]
            )

        # Run ONNX model
        sess = InferenceSession(onnx_model_path, providers=[provider])
        ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))

        # Calculate output difference
        for i, output_name in enumerate(self.output_names()):
            diff = np.abs(pt_outputs[i] - ort_outputs[i])
            logger.warning(f"Comparing {output_name}...")
            logger.warning(f"Max diff: {np.max(diff)}")

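# The loop above only logs the per-output max absolute difference. A minimal sketch (added example,
# not part of the original file) of turning that log into an explicit pass/fail parity check,
# assuming a tolerance such as 1e-3 is acceptable for the exported precision:
def assert_parity(pt_outputs, ort_outputs, output_names, atol=1e-3):
    # Compare each PyTorch output against the matching ONNX Runtime output by position/name.
    for name, pt_out, ort_out in zip(output_names, pt_outputs, ort_outputs):
        max_diff = np.max(np.abs(pt_out - ort_out))
        assert max_diff <= atol, f"{name} differs from PyTorch by {max_diff}"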
+1035
File diff suppressed because it is too large
+380
@@ -0,0 +1,380 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging

import numpy as np
import torch
from transformers import WhisperConfig

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


# Create audio_features for encoder
# Shape is (batch_size, feature_size, sequence_length) = (batch_size, num_mel_filters, num_frames)
# where num_mel_filters is a model attribute and num_frames = (chunk_length * sample_rate) // hop_length.
#
# Hard-coded audio hyperparameters:
# SAMPLE_RATE = 16000
# N_FFT = 400
# HOP_LENGTH = 160
# CHUNK_LENGTH = 30 (i.e. 30-second chunk of audio)
# N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000 (i.e. 480,000 samples in a 30-second chunk of audio)
# N_FRAMES = N_SAMPLES // HOP_LENGTH = 480000 // 160 = 3000 (i.e. 3000 frames in a mel spectrogram input)
#
# N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 = 160 * 2 = 320
# FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH = 16000 // 160 = 100 (i.e. 10 ms per audio frame)
# TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN = 16000 // 320 = 50 (i.e. 20 ms per audio token)
def get_sample_audio_features(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    audio_features = torch.randn(batch_size, config.num_mel_bins, sequence_length, device=device, dtype=torch_dtype)
    return audio_features

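# Illustrative usage (added example, not part of the original file): building dummy encoder features
# for one 30-second chunk. The frame count follows the hyperparameters documented above; a default
# WhisperConfig() stands in for a real checkpoint's config.
example_config = WhisperConfig()
example_num_frames = (30 * 16000) // 160  # CHUNK_LENGTH * SAMPLE_RATE // HOP_LENGTH = 3000
example_audio_features = get_sample_audio_features(example_config, torch.device("cpu"), batch_size=2)
assert example_audio_features.shape == (2, example_config.num_mel_bins, example_num_frames)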
# Create input_ids for decoder
# Shape is (batch_size, sequence_length) where sequence_length is the initial decoder sequence length
def get_sample_decoder_input_ids(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_int32: bool = True,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    decoder_input_ids = torch.randint(
        low=0, high=config.vocab_size, size=(batch_size, sequence_length), device=device, dtype=torch_dtype
    )
    return decoder_input_ids


# Create encoder_hidden_states for decoder-init
# Shape is (batch_size, num_frames // 2, hidden_size)
def get_sample_encoder_hidden_states(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    use_fp16: bool = False,
):
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    encoder_hidden_states = torch.randn(
        batch_size, config.max_source_positions, config.d_model, device=device, dtype=torch_dtype
    )
    return encoder_hidden_states


# Create past_key_values
# Self-attention KV caches are of shape (batch_size, num_heads, past_sequence_length, head_size)
# Cross-attention KV caches are of shape (batch_size, num_heads, num_frames // 2, head_size)
def get_sample_past_key_values(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
):
    num_heads = config.decoder_attention_heads
    head_size = config.d_model // num_heads
    max_source_positions = (
        config.max_source_positions
    )  # equal to num_frames // 2 = encoder's sequence_length // 2 = 3000 // 2 = 1500
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    self_attention_kv_caches = [
        (
            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
            torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype),
        )
        for _ in range(config.decoder_layers)
    ]
    cross_attention_kv_caches = [
        (
            torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
            torch.rand(batch_size, num_heads, max_source_positions, head_size, device=device, dtype=torch_dtype),
        )
        for _ in range(config.decoder_layers)
    ]
    return flatten_past_key_values(self_attention_kv_caches, cross_attention_kv_caches)


# Flatten KV caches into pairs-of-4 where each pair is defined as:
# (self_attn_key_cache, self_attn_value_cache, cross_attn_key_cache, cross_attn_value_cache)
def flatten_past_key_values(
    self_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
    cross_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
):
    past_key_values = []
    for (self_k_cache, self_v_cache), (cross_k_cache, cross_v_cache) in zip(
        self_attn_kv_caches, cross_attn_kv_caches, strict=False
    ):
        layer_kv_caches = (self_k_cache, self_v_cache, cross_k_cache, cross_v_cache)
        past_key_values.append(layer_kv_caches)
    return past_key_values


# Group KV caches into two 1D lists where one list contains the self attention KV caches and
# one list contains the cross attention KV caches
def group_past_key_values(
    kv_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
):
    self_attn_kv_caches, cross_attn_kv_caches = [], []
    for self_k_cache, self_v_cache, cross_k_cache, cross_v_cache in kv_caches:
        self_attn_kv_caches.append(self_k_cache)
        self_attn_kv_caches.append(self_v_cache)
        cross_attn_kv_caches.append(cross_k_cache)
        cross_attn_kv_caches.append(cross_v_cache)
    return self_attn_kv_caches, cross_attn_kv_caches

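# Illustrative round trip (added example, not part of the original file): get_sample_past_key_values
# returns the flattened per-layer layout, and group_past_key_values splits it back into two flat
# lists, each holding 2 * decoder_layers tensors.
example_config = WhisperConfig()
example_past = get_sample_past_key_values(example_config, torch.device("cpu"), batch_size=1, past_seq_len=4)
example_self_caches, example_cross_caches = group_past_key_values(example_past)
assert len(example_past) == example_config.decoder_layers
assert len(example_self_caches) == len(example_cross_caches) == 2 * example_config.decoder_layers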
# Create alignment heads for timestamps
# Shape is (num_alignment_heads, 2)
def get_sample_alignment_heads(
    config: WhisperConfig,
    device: torch.device,
    num_alignment_heads: int = 6,
    use_int32: bool = True,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    alignment_heads = torch.ones((num_alignment_heads, 2), device=device, dtype=torch_dtype)
    return alignment_heads


# Create length of start-of-transcription sequence for timestamps
# Shape is (1)
def get_sample_sot_sequence_length(
    device: torch.device,
    sot_sequence_length: int,
    use_int32: bool = False,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    sot_length = torch.tensor([sot_sequence_length], device=device, dtype=torch_dtype)
    return sot_length


# Create segment length for timestamps
# Shape is (1)
def get_sample_segment_length(
    device: torch.device,
    segment_length: int,
    use_int32: bool = False,
):
    torch_dtype = torch.int32 if use_int32 else torch.int64
    segment_size = torch.tensor([segment_length], device=device, dtype=torch_dtype)
    return segment_size


# Create QKs for timestamps
# Shape is (batch_size, num_heads, sequence_length, num_frames // 2)
def get_sample_QKs(  # noqa: N802
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_fp16: bool = False,
):
    num_heads = config.decoder_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    QKs = [  # noqa: N806
        torch.rand(
            batch_size, num_heads, sequence_length, config.max_source_positions, device=device, dtype=torch_dtype
        )
        for _ in range(config.decoder_layers)
    ]
    return QKs


# Create inputs for encoder component of Whisper
def get_sample_encoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    audio_features = get_sample_audio_features(config, device, batch_size, sequence_length, use_fp16)
    return {"audio_features": audio_features}


# Create inputs for encoder component + first pass through decoder component of Whisper
def get_sample_encoder_decoder_init_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    decoder_sequence_length: int,
    encoder_sequence_length: int = 3000,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    audio_features = get_sample_audio_features(config, device, batch_size, encoder_sequence_length, use_fp16)
    decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, decoder_sequence_length, use_int32)
    return {"audio_features": audio_features, "decoder_input_ids": decoder_input_ids}


# Create inputs for decoder component of Whisper
# Inputs for first pass through the decoder (i.e. decoder-init): decoder_input_ids, encoder_hidden_states
# Inputs for subsequent passes through the decoder (i.e. decoder-with-past): decoder_input_ids, past_key_values
def get_sample_decoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_sequence_length: int,
    sequence_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    decoder_input_ids = get_sample_decoder_input_ids(config, device, batch_size, sequence_length, use_int32)
    encoder_hidden_states = get_sample_encoder_hidden_states(config, device, batch_size, use_fp16)
    past_key_values = get_sample_past_key_values(config, device, batch_size, past_sequence_length, use_fp16)
    return {
        "decoder_input_ids": decoder_input_ids,
        "encoder_hidden_states": encoder_hidden_states,
        "past_key_values": past_key_values,
    }


# Create inputs for timestamps component of Whisper
def get_sample_jump_times_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    num_alignment_heads: int,
    sot_sequence_length: int,
    segment_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    alignment_heads = get_sample_alignment_heads(config, device, num_alignment_heads, use_int32)
    # lengths need to be int64 because subsequent 'Slice' ops only take int64 inputs
    sot_sequence_length = get_sample_sot_sequence_length(device, sot_sequence_length)
    segment_length = get_sample_segment_length(device, segment_length)
    QKs = get_sample_QKs(config, device, batch_size, sequence_length, use_fp16)  # noqa: N806
    return {
        "alignment_heads": alignment_heads,
        "sot_sequence_length": sot_sequence_length,
        "segment_length": segment_length,
        "QKs": QKs,
    }


# Convert PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
    inputs: dict,
    model: InferenceSession,
):
    self_attn_kv_caches, cross_attn_kv_caches = None, None
    batch_size, num_heads, past_seq_len, head_size = 0, 0, 0, 0
    num_beams, max_seq_len = 1, 448
    if "past_key_values" in inputs:
        (self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(inputs["past_key_values"])
        batch_size, num_heads, past_seq_len, head_size = self_attn_kv_caches[0].shape

    ort_inputs = {}
    model_inputs = list(map(lambda i: i.name, model.get_inputs()))  # noqa: C417
    use_buffer_sharing = "cache_indirection" in model_inputs
    for name in model_inputs:
        if name in {"audio_features", "encoder_input_ids"}:
            # Encoder input
            ort_inputs[name] = inputs["audio_features"].detach().cpu().numpy()
        elif name == "encoder_hidden_states":
            # Encoder output
            ort_inputs[name] = inputs["encoder_hidden_states"].detach().cpu().numpy()
        elif name in {"decoder_input_ids", "input_ids"}:
            # Decoder input
            ort_inputs[name] = inputs["decoder_input_ids"].detach().cpu().numpy()
        elif "past_key_self" in name or "past_value_self" in name:
            # Decoder input
            orig_kv_cache = self_attn_kv_caches.pop(0).detach().cpu().numpy()
            if use_buffer_sharing:
                new_kv_cache = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=orig_kv_cache.dtype)
                new_kv_cache[:batch_size, :num_heads, :past_seq_len, :head_size] = orig_kv_cache
                ort_inputs[name] = new_kv_cache
            else:
                ort_inputs[name] = orig_kv_cache
        elif "past_key_cross" in name or "past_value_cross" in name:
            # Decoder input
            orig_kv_cache = cross_attn_kv_caches.pop(0).detach().cpu().numpy()
            ort_inputs[name] = orig_kv_cache
        elif name == "past_sequence_length":
            # Decoder input
            ort_inputs[name] = np.array([past_seq_len], dtype=np.int32)
        elif name == "cache_indirection":
            # Decoder input
            ort_inputs[name] = np.zeros((batch_size, num_beams, max_seq_len), dtype=np.int32)
        elif name == "alignment_heads":
            # Jump times input
            ort_inputs[name] = inputs["alignment_heads"].detach().cpu().numpy()
        elif name == "sot_sequence_length":
            # Jump times input
            ort_inputs[name] = inputs["sot_sequence_length"].detach().cpu().numpy()
        elif name == "segment_length":
            # Jump times input
            ort_inputs[name] = inputs["segment_length"].detach().cpu().numpy()
        elif "cross_qk" in name:
            # Jump times input
            ort_inputs[name] = inputs["QKs"].pop(0).detach().cpu().numpy()
        else:
            raise ValueError(f"Unrecognized input name: {name}")

    return ort_inputs

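# Illustrative usage (added example, not part of the original file): feeding sample decoder inputs to
# an exported decoder. The model path below is hypothetical; convert_inputs_for_ort only binds the
# tensors whose input names the session actually exposes.
example_config = WhisperConfig()
example_inputs = get_sample_decoder_inputs(
    example_config, torch.device("cpu"), batch_size=1, past_sequence_length=8, sequence_length=1
)
example_sess = InferenceSession("whisper_decoder.onnx", providers=["CPUExecutionProvider"])  # hypothetical path
example_outputs = example_sess.run(None, convert_inputs_for_ort(example_inputs, example_sess))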
# Get dynamic axes for all inputs and outputs to the model
def get_model_dynamic_axes(
    config: WhisperConfig,
    input_names: list[str],
    output_names: list[str],
):
    dynamic_axes = {}
    for name in input_names + output_names:
        if name in {"audio_features", "encoder_input_ids"}:
            # shape is (batch_size, num_mels, num_frames)
            dynamic_axes[name] = {0: "batch_size"}
        elif name in {"input_ids", "decoder_input_ids"}:
            # shape is (batch_size, sequence_length)
            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
        elif name == "alignment_heads":
            # shape is (num_alignment_heads, 2)
            dynamic_axes[name] = {0: "num_alignment_heads"}
        elif name in {"sot_sequence_length", "segment_length"}:
            # shape is (1)
            pass
        elif name == "logits":
            # shape is (batch_size, sequence_length, vocab_size)
            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
        elif name == "encoder_hidden_states":
            # shape is (batch_size, num_frames // 2, hidden_size)
            dynamic_axes[name] = {0: "batch_size"}
        elif "past_key_self" in name or "past_value_self" in name:
            # shape is (batch_size, num_heads, past_sequence_length, head_size)
            dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
        elif "present_key_self" in name or "present_value_self" in name:
            # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size),
            # which is equal to (batch_size, num_heads, total_sequence_length, head_size)
            dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
        elif (
            "past_key_cross" in name
            or "past_value_cross" in name
            or "present_key_cross" in name
            or "present_value_cross" in name
        ):
            # shape is (batch_size, num_heads, num_frames // 2, head_size)
            dynamic_axes[name] = {0: "batch_size"}
        elif "cross_qk" in name:
            # shape is (batch_size, num_heads, source_sequence_length, target_sequence_length)
            dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
        elif "jump_times" in name:
            # shape is (batch_size, max_length)
            dynamic_axes[name] = {0: "batch_size", 1: "max_length"}
        else:
            raise Exception(f"Unknown input or output name found: {name}")
    return dynamic_axes

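# Illustrative usage (added example, not part of the original file): dynamic axes for a
# decoder-init style export.
example_axes = get_model_dynamic_axes(WhisperConfig(), ["decoder_input_ids", "encoder_hidden_states"], ["logits"])
# example_axes == {
#     "decoder_input_ids": {0: "batch_size", 1: "sequence_length"},
#     "encoder_hidden_states": {0: "batch_size"},
#     "logits": {0: "batch_size", 1: "sequence_length"},
# }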
+477
@@ -0,0 +1,477 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
import tempfile
import textwrap
from pathlib import Path

import numpy as np
import onnx
import torch
import torch.nn.functional as F
import torch.utils.cpp_extension
from onnx_model import OnnxModel
from transformers import WhisperConfig
from whisper_inputs import convert_inputs_for_ort, get_model_dynamic_axes, get_sample_jump_times_inputs

from onnxruntime import InferenceSession
from onnxruntime.tools import pytorch_export_contrib_ops

logger = logging.getLogger(__name__)

##################################################
# Functions that have to be outside of the class
# for torch.jit.script_if_tracing to work
##################################################


@torch.jit.script_if_tracing
def index_QKs(alignment_heads: torch.Tensor, QKs: list[torch.Tensor]):  # noqa: N802
    """
    Compute the following to get stacked QK tensor that has been indexed for the desired attention heads:
    weights = torch.stack([QKs[_l][:, _h] for _l, _h in alignment_heads], dim=1)
    """
    indexed_QKs = []  # noqa: N806
    for pair in alignment_heads:
        # Each QK is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
        # The `QKs[_l]` selects the right QK from the list of QKs
        # The `QKs[_l][:, _h]` selects the right attention heads from the chosen QK. The `:` is to do this for the batch dim.
        #
        # PyTorch:
        # QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
        # QKs[_l][:, _h] is of shape (batch_size, sequence_length, num_frames // 2)
        #
        # ONNX:
        # QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
        # QKs[_l][:, _h] is of shape (batch_size, 1, sequence_length, num_frames // 2) because
        # the `[:, _h]` operation maps to a Gather op and that op does not reduce dimensions
        _l, _h = pair[0], pair[1]
        indexed_QKs.append(QKs[_l][:, _h])

    # PyTorch:
    # torch.stack will return a tensor of shape (batch_size, num_alignment_heads, sequence_length, num_frames // 2).
    #
    # ONNX:
    # torch.stack will return a tensor of shape (batch_size, num_alignment_heads, 1, sequence_length, num_frames // 2)
    # because the Gather op does not reduce dimensions. To remove the unneeded dimension, torch.squeeze with a specified
    # dim (dim = 2) is added. The torch.squeeze op with a specified dim only runs if the specified dim has a size of 1.
    # Since the dim won't be of size 1 in the PyTorch tensor but it is of size 1 in the ONNX tensor, it will be a no-op
    # in PyTorch and an op in ONNX. Thus, the Squeeze op will only affect the ONNX model.
    weights = torch.stack(indexed_QKs, dim=1)
    weights = torch.squeeze(weights, dim=2)
    return weights


def jump_timings(text_indices, time_indices):
    """
    Calculate jump times from text_indices and time_indices where
    text_indices and time_indices are both 1d vectors
    """
    TOKENS_PER_SECOND = 50.0  # noqa: N806
    diff = text_indices[1:] - text_indices[:-1]
    padding = torch.tensor([1], dtype=torch.int32)
    jumps = torch.cat((padding, diff)).to(torch.bool)
    jump_times = time_indices[jumps].to(torch.float) / TOKENS_PER_SECOND
    return jump_times

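# Worked example (added, not part of the original file): with the token-to-frame alignment below,
# the text index advances ("jumps") at positions 0, 2, 3, and 5, so the corresponding time indices
# are kept and divided by TOKENS_PER_SECOND (50) to get times in seconds.
example_text_indices = torch.tensor([0, 0, 1, 2, 2, 3], dtype=torch.int32)
example_time_indices = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32)
print(jump_timings(example_text_indices, example_time_indices))  # tensor([0.0000, 0.0400, 0.0600, 0.1000])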
def padded_jump_from_dtw(matrix_2d: torch.Tensor, max_length: torch.Tensor):
    """
    Run Dynamic Time Warping (DTW) on a 2D cost matrix and pad the resulting jump times to max_length
    """
    trace = torch.ops.onnxruntime.DynamicTimeWarping(matrix_2d)
    text_indices = trace[0, :]
    time_indices = trace[1, :]
    jump_times = jump_timings(text_indices, time_indices)
    return F.pad(jump_times, [0, int((max_length - jump_times.size(-1)).item())], mode="constant", value=-1.0)


@torch.jit.script_if_tracing
def batch_jump_times(matrix: torch.Tensor, max_decoded_length: torch.Tensor):
    """
    Compute the following to calculate jump times for all batches:
    batched_jump_times = torch.stack([padded_jump_from_dtw(matrix[b], max_decoded_length) for b in range(matrix.size(0))])
    """
    list_of_jump_times = []
    for b in range(matrix.size(0)):
        jump_times = padded_jump_from_dtw(matrix[b], max_decoded_length)
        list_of_jump_times.append(jump_times)
    batched_jump_times = torch.stack(list_of_jump_times)
    return batched_jump_times


class WhisperJumpTimes(torch.nn.Module):
    """Whisper jump times component"""

    def __init__(self, config: WhisperConfig, device: torch.device, cache_dir: str | os.PathLike):
        super().__init__()
        self.config = config
        self.device = device
        self.cache_dir = cache_dir

        self.filter_width = 7
        self.qk_scale = 1.0

    def median_filter(self, weights: torch.Tensor):
        """
        Apply a median filter of width `filter_width` along the last dimension of `weights`
        """
        pad_width = self.filter_width // 2
        x = F.pad(weights, (pad_width, pad_width, 0, 0), mode="reflect")
        x_unfolded = torch.ops.onnxruntime.UnfoldTensor(x, -1, self.filter_width, 1)
        result = torch.select(x_unfolded.sort()[0], dim=-1, index=pad_width)
        return result

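    # Added note (not in the original file): for an odd filter_width, the two custom-op lines above
    # are equivalent to the pure-PyTorch expression
    #     x.unfold(-1, self.filter_width, 1).sort(dim=-1)[0][..., self.filter_width // 2]
    # The onnxruntime::UnfoldTensor op is registered in create_torch_ops below, presumably so the
    # sliding window maps to a single ORT contrib op during torch.onnx.export instead of relying on
    # Tensor.unfold being exportable.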
    def forward(
        self,
        alignment_heads: torch.Tensor,
        sot_sequence_length: torch.Tensor,
        segment_length: torch.Tensor,
        QKs: list[torch.Tensor],
    ):
        # Get stacked QKs tensor
        weights = index_QKs(alignment_heads, QKs)
        weights = weights[:, :, : segment_length // 2]
        weights = weights.to(torch.float32)

        weights = (weights * self.qk_scale).softmax(dim=-1)
        std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
        weights = (weights - mean) / std
        weights = self.median_filter(weights)

        matrix = torch.mean(weights, 1)
        matrix = -matrix[:, sot_sequence_length:-1]

        max_decoded_length = torch.tensor([matrix.size(1)], dtype=torch.int64)
        batched_jump_times = batch_jump_times(matrix, max_decoded_length)
        return batched_jump_times

    def input_names(self):
        input_names = [
            "alignment_heads",
            "sot_sequence_length",
            "segment_length",
            *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)],
        ]
        return input_names

    def output_names(self):
        output_names = ["jump_times"]
        return output_names

    def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
        inputs = get_sample_jump_times_inputs(
            self.config,
            self.device,
            batch_size=2,
            sequence_length=8,
            num_alignment_heads=6,
            sot_sequence_length=3,
            segment_length=1332,
            use_fp16=use_fp16_inputs,
            use_int32=use_int32_inputs,
        )
        if return_dict:
            return inputs
        return (
            inputs["alignment_heads"],
            inputs["sot_sequence_length"],
            inputs["segment_length"],
            inputs["QKs"],
        )

    def create_torch_ops(self):
        """
        1) Create UnfoldTensor and DynamicTimeWarping as torch ops
        2) Provide a symbolic mapping from torch ops to ORT contrib ops

        See https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html#building-with-jit-compilation
        for more details on how this works.
        """
        # Set torch extensions directory to cache directory
        os.environ["TORCH_EXTENSIONS_DIR"] = self.cache_dir

        # Verify that the `ninja` build backend is available
        try:
            torch.utils.cpp_extension.verify_ninja_availability()
        except Exception as e:
            logger.error(f"An error occurred while verifying `ninja` is available: {e}", exc_info=True)  # noqa: G201
            install_cmd = "pip install ninja"
            logger.warning(f"Could not import `ninja`. Attempting to install `ninja` via `{install_cmd}`.")
            os.system(install_cmd)

        # Create UnfoldTensor torch op
        unfold_op_source = textwrap.dedent("""\
        #include "torch/script.h"

        torch::Tensor UnfoldTensor(torch::Tensor input, int64_t dim, int64_t size, int64_t step) {
          return input.unfold(dim, size, step);
        }

        // namespace is onnxruntime
        static auto registry = torch::RegisterOperators("onnxruntime::UnfoldTensor", &UnfoldTensor);
        """)

        torch.utils.cpp_extension.load_inline(
            name="UnfoldTensor",
            cpp_sources=unfold_op_source,
            is_python_module=False,
            verbose=True,
        )

        # Create DynamicTimeWarping torch op
        dtw_op_source = textwrap.dedent("""\
        #include "torch/script.h"
        #include "torch/torch.h"
        #include <stdexcept>
        #include <tuple>
        #include <vector>

        torch::Tensor Backtrace(torch::Tensor trace) {
          int64_t i = trace.size(0) - 1;
          int64_t j = trace.size(1) - 1;
          trace.index({0, torch::indexing::Slice()}) = 2;
          trace.index({torch::indexing::Slice(), 0}) = 1;

          std::vector<int32_t> result_vec;
          while (i > 0 || j > 0) {
            result_vec.push_back(static_cast<int32_t>(i - 1));
            result_vec.push_back(static_cast<int32_t>(j - 1));
            int value = trace[i][j].item<int>();

            if (value == 0) {
              i--;
              j--;
            } else if (value == 1) {
              i--;
            } else if (value == 2) {
              j--;
            } else {
              throw std::runtime_error("Unexpected trace[i, j]");
            }
          }

          // Compute result[::-1, :].T
          torch::Tensor result = torch::from_blob(result_vec.data(), {static_cast<long int>(result_vec.size() / 2), 2}, torch::kInt32).clone();
          torch::Tensor reversed = result.flip(0);  // result[::-1, :]
          torch::Tensor transposed = reversed.transpose(0, 1);  // .T
          return transposed;
        }

        torch::Tensor DynamicTimeWarping(torch::Tensor x) {
          int64_t N = x.size(0);
          int64_t M = x.size(1);
          torch::Tensor cost = torch::full({N + 1, M + 1}, std::numeric_limits<float>::infinity(), torch::dtype(torch::kFloat32));
          torch::Tensor trace = torch::full({N + 1, M + 1}, -1, torch::dtype(torch::kFloat32));

          cost[0][0] = 0;
          for (int j = 1; j < M + 1; j++) {
            for (int i = 1; i < N + 1; i++) {
              float c0 = cost[i - 1][j - 1].item<float>();
              float c1 = cost[i - 1][j].item<float>();
              float c2 = cost[i][j - 1].item<float>();

              float c = 0;
              float t = 0;

              if (c0 < c1 && c0 < c2) {
                c = c0;
                t = 0;
              } else if (c1 < c0 && c1 < c2) {
                c = c1;
                t = 1;
              } else {
                c = c2;
                t = 2;
              }

              cost[i][j] = x[i - 1][j - 1].item<float>() + c;
              trace[i][j] = t;
            }
          }

          return Backtrace(trace);
        }

        // namespace is onnxruntime
        static auto registry = torch::RegisterOperators("onnxruntime::DynamicTimeWarping", &DynamicTimeWarping);
        """)

        torch.utils.cpp_extension.load_inline(
            name="DynamicTimeWarping",
            cpp_sources=dtw_op_source,
            is_python_module=False,
            verbose=True,
        )

        # Create symbolic mapping from torch ops to ORT contrib ops
        pytorch_export_contrib_ops.register()

    def export_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_fp16_inputs: bool = False,
        use_int32_inputs: bool = True,
    ):
        """Export word-level timestamps to ONNX

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): provider to use for verifying parity on ONNX model
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_fp16_inputs (bool, optional): use float16 inputs for the cross_qk_{i}. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs for the alignment_heads and sot_sequence_length. Defaults to True.
        """
        # Shape of timestamps's tensors:
        # Inputs:
        #    alignment_heads: (num_alignment_heads, 2)
        #    sot_sequence_length: (1)
        #    segment_length: (1)
        #    cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
        # Outputs:
        #    jump_times: (batch_size, max_length)

        # Definitions:
        # alignment_heads: the attention head indices where the Q*K values are highly correlated with word-level timestamps
        # (i.e. the alignment between audio and text tokens)
        # This is calculated as follows:
        #
        # ```
        # import base64
        # import gzip
        # import numpy as np
        # import torch
        #
        # # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
        # # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
        # _ALIGNMENT_HEADS = {
        #     "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
        #     "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
        #     "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
        #     "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
        #     "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
        #     "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
        #     "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
        #     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
        #     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
        #     "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
        #     "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
        #     "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
        #     "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
        #     "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
        # }
        #
        # model_name = "large-v3-turbo"
        # array = np.frombuffer(
        #     gzip.decompress(base64.b85decode(_ALIGNMENT_HEADS[model_name])), dtype=bool
        # ).copy()
        # mask = torch.from_numpy(array).reshape(
        #     self.dims.n_text_layer, self.dims.n_text_head
        # )
        # self.alignment_heads = mask.to_sparse().indices().T
        # ```
        #
        # sot_sequence_length: the length of the start-of-transcription sequence before the first token is generated
        # Typically the start-of-transcription sequence is [<|startoftranscription|>, <|language_token|>, <|task_token|>]
        # so its length is 3.
        #
        # segment_length: the length (in frames) of the audio segment that is being transcribed
        #
        # cross_qk_*: the Q*K values for the cross-attention blocks in the decoder
        # Every decoder layer has a self-attention block and a cross-attention block so there are `n` cross-attention blocks
        # where `n` is the number of decoder layers.
        #
        # jump_times: the timings where jumps occur in speech
        # This allows us to detect when a word began to be spoken by the speaker (start_times) and when a word was finished
        # being spoken by the speaker (end_times).

        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
        input_names = self.input_names()
        output_names = self.output_names()
        dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path

            # Create torch ops and map them to ORT contrib ops before export
            self.create_torch_ops()
            torch.onnx.export(
                self,
                args=inputs,
                f=out_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                verbose=verbose,
                custom_opsets={"com.microsoft": 1},
            )

            if use_external_data_format:
                model = onnx.load_model(out_path, load_external_data=use_external_data_format)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)

    def verify_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        use_fp16_inputs: bool,
        use_int32_inputs: bool,
    ):
        """Verify ONNX model outputs and PyTorch model outputs match

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): execution provider for ONNX model
            use_fp16_inputs (bool, optional): use float16 inputs for the cross_qk_{i}
            use_int32_inputs (bool, optional): use int32 inputs for the alignment_heads and sot_sequence_length
        """
        # Shape of jump times's tensors:
        # Inputs:
        #    alignment_heads: (num_alignment_heads, 2)
        #    sot_sequence_length: (1)
        #    segment_length: (1)
        #    cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
        # Outputs:
        #    jump_times: (batch_size, max_length)
        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)

        # Run PyTorch model
        pt_outputs = (
            self.forward(
                inputs["alignment_heads"], inputs["sot_sequence_length"], inputs["segment_length"], inputs["QKs"]
            )
            .detach()
            .cpu()
            .numpy()
        )

        # Run ONNX model
        sess = InferenceSession(onnx_model_path, providers=[provider])
        ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))

        # Calculate output difference
        diff = np.abs(pt_outputs - ort_outputs)
        print("Comparing batched jump_times...", flush=True)
        print(f"Max diff: {np.max(diff)}", flush=True)

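# Added reference sketch (not part of the original file): a NumPy mirror of the DynamicTimeWarping
# C++ op registered in create_torch_ops, useful for sanity-checking the compiled op on small cost
# matrices. It follows the same cost/trace recurrence and backtrace rules as the C++ source above.
def dtw_reference(x):
    import numpy as np  # local import keeps the sketch self-contained

    n, m = x.shape
    cost = np.full((n + 1, m + 1), np.inf, dtype=np.float32)
    trace = np.full((n + 1, m + 1), -1, dtype=np.float32)
    cost[0, 0] = 0.0
    for j in range(1, m + 1):
        for i in range(1, n + 1):
            c0, c1, c2 = cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1]
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    # Backtrace from the bottom-right corner; boundary rows/columns force moves back to the origin.
    trace[0, :] = 2
    trace[:, 0] = 1
    i, j = n, m
    text_indices, time_indices = [], []
    while i > 0 or j > 0:
        text_indices.append(i - 1)
        time_indices.append(j - 1)
        t = trace[i, j]
        if t == 0:
            i -= 1
            j -= 1
        elif t == 1:
            i -= 1
        else:
            j -= 1
    return np.array([text_indices[::-1], time_indices[::-1]], dtype=np.int32)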