switching to high quality piper tts and added label translations
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
from .fusion import Fusion # noqa: F401
|
||||
from .fusion_gelu import FusionGelu # noqa: F401
|
||||
from .fusion_layernorm import FusionLayerNormalization # noqa: F401
|
||||
from .replace_upsample_with_resize import ReplaceUpsampleWithResize # noqa: F401
|
||||
@@ -0,0 +1,311 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
import onnx
|
||||
|
||||
from ..onnx_model import ONNXModel
|
||||
|
||||
|
||||
class Fusion:
|
||||
"""
|
||||
Base class for fusions.
|
||||
"""
|
||||
|
||||
def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
|
||||
self.search_op_type: str = search_op_type
|
||||
self.fused_op_type: str = fused_op_type
|
||||
self.model: ONNXModel = model
|
||||
self.nodes_to_remove: list = []
|
||||
self.nodes_to_add: list = []
|
||||
|
||||
self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
|
||||
self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops.
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
node: onnx.NodeProto,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
):
|
||||
"""
|
||||
Interface function for derived fusion classes. Tries to fuse a node sequence containing
|
||||
the specified node.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def apply(self) -> bool:
|
||||
"""
|
||||
Apply graph fusion on the entire model graph.
|
||||
"""
|
||||
input_name_to_nodes = self.model.input_name_to_nodes()
|
||||
output_name_to_node = self.model.output_name_to_node()
|
||||
|
||||
for node in self.model.nodes():
|
||||
if node.op_type == self.search_op_type:
|
||||
self.fuse(node, input_name_to_nodes, output_name_to_node)
|
||||
|
||||
self.model.remove_nodes(self.nodes_to_remove)
|
||||
self.model.add_nodes(self.nodes_to_add)
|
||||
|
||||
graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)
|
||||
|
||||
if graph_updated:
|
||||
self.model.remove_unused_constant()
|
||||
|
||||
return graph_updated
|
||||
|
||||
def create_unique_node_name(self):
|
||||
prefix = self._new_node_name_prefix
|
||||
|
||||
if self._new_node_name_suffix is None:
|
||||
largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
|
||||
self._new_node_name_suffix = largest_suffix + 1
|
||||
|
||||
new_name = f"{prefix}{self._new_node_name_suffix!s}"
|
||||
self._new_node_name_suffix += 1
|
||||
|
||||
return new_name
|
||||
|
||||
@staticmethod
|
||||
def is_safe_to_fuse_nodes(
|
||||
nodes_to_remove: list[onnx.NodeProto],
|
||||
keep_outputs: list[str],
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
) -> bool:
|
||||
for node_to_remove in nodes_to_remove:
|
||||
for output_to_remove in node_to_remove.output:
|
||||
if output_to_remove in keep_outputs:
|
||||
continue
|
||||
|
||||
if output_to_remove in input_name_to_nodes:
|
||||
for impacted_node in input_name_to_nodes[output_to_remove]:
|
||||
if impacted_node not in nodes_to_remove:
|
||||
# Not safe to remove nodes since output is used by impacted_node
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
|
||||
for attr in node.attribute:
|
||||
if attr.name == attribute_name:
|
||||
value = onnx.helper.get_attribute_value(attr)
|
||||
return value
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
|
||||
for index, input_name in enumerate(child_node.input):
|
||||
if input_name == node_output:
|
||||
return index
|
||||
return -1
|
||||
|
||||
@staticmethod
|
||||
def tensor_shape_to_list(tensor_type) -> list[int]:
|
||||
shape_list = []
|
||||
for d in tensor_type.shape.dim:
|
||||
if d.HasField("dim_value"):
|
||||
shape_list.append(d.dim_value) # known dimension
|
||||
elif d.HasField("dim_param"):
|
||||
shape_list.append(d.dim_param) # unknown dimension with symbolic name
|
||||
else:
|
||||
shape_list.append("?") # shall not happen
|
||||
return shape_list
|
||||
|
||||
def get_constant_input(self, node: onnx.NodeProto):
|
||||
for i, inp in enumerate(node.input):
|
||||
value = self.model.get_constant_value(inp)
|
||||
if value is not None:
|
||||
return i, value
|
||||
|
||||
return None, None
|
||||
|
||||
def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
|
||||
i, value = self.get_constant_input(node)
|
||||
if value is not None and value.size == 1 and abs(value - expected_value) < delta:
|
||||
return i
|
||||
|
||||
return -1
|
||||
|
||||
def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
|
||||
return self.find_constant_input(node, expected_value, delta) >= 0
|
||||
|
||||
def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
|
||||
value = self.model.get_constant_value(output_name)
|
||||
if value is None:
|
||||
return False # Not an initializer
|
||||
|
||||
if len(value.shape) != rank:
|
||||
return False # Wrong dimensions
|
||||
|
||||
return True
|
||||
|
||||
def match_first_parent(
|
||||
self,
|
||||
node: onnx.NodeProto,
|
||||
parent_op_type: str,
|
||||
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
|
||||
exclude: list[onnx.NodeProto] = [], # noqa: B006
|
||||
) -> tuple[onnx.NodeProto | None, int | None]:
|
||||
"""
|
||||
Find parent node based on constraints on op_type.
|
||||
|
||||
Args:
|
||||
node: current node.
|
||||
parent_op_type (str): constraint of parent node op_type.
|
||||
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
||||
exclude (list): list of nodes that are excluded (not allowed to match as parent).
|
||||
|
||||
Returns:
|
||||
parent: The matched parent node. None if not found.
|
||||
index: The input index of matched parent node. None if not found.
|
||||
"""
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.model.output_name_to_node()
|
||||
|
||||
for i, inp in enumerate(node.input):
|
||||
if inp in output_name_to_node:
|
||||
parent = output_name_to_node[inp]
|
||||
if parent.op_type == parent_op_type and parent not in exclude:
|
||||
return parent, i
|
||||
|
||||
return None, None
|
||||
|
||||
def match_parent(
|
||||
self,
|
||||
node: onnx.NodeProto,
|
||||
parent_op_type: str,
|
||||
input_index: int | None = None,
|
||||
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
|
||||
exclude: list[onnx.NodeProto] = [], # noqa: B006
|
||||
return_indice: list[int] | None = None,
|
||||
) -> onnx.NodeProto | None:
|
||||
"""
|
||||
Find parent node based on constraints on op_type and index.
|
||||
When input_index is None, we will find the first parent node based on constraints,
|
||||
and return_indice will be appended the corresponding input index.
|
||||
|
||||
Args:
|
||||
node (str): current node name.
|
||||
parent_op_type (str): constraint of parent node op_type.
|
||||
input_index (int or None): only check the parent given input index of current node.
|
||||
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
||||
exclude (list): list of nodes that are excluded (not allowed to match as parent).
|
||||
return_indice (list): a list to append the input index when input_index is None.
|
||||
|
||||
Returns:
|
||||
parent: The matched parent node.
|
||||
"""
|
||||
assert node is not None
|
||||
assert input_index is None or input_index >= 0
|
||||
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.model.output_name_to_node()
|
||||
|
||||
if input_index is None:
|
||||
parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
|
||||
if return_indice is not None:
|
||||
return_indice.append(index)
|
||||
return parent
|
||||
|
||||
if input_index >= len(node.input):
|
||||
# Input index out of bounds.
|
||||
return None
|
||||
|
||||
parent = self.model.get_parent(node, input_index, output_name_to_node)
|
||||
if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
|
||||
return parent
|
||||
|
||||
return None
|
||||
|
||||
def match_parent_path(
|
||||
self,
|
||||
node: onnx.NodeProto,
|
||||
parent_op_types: list[str],
|
||||
parent_input_index: list[int] | None = None,
|
||||
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
|
||||
return_indice: list[int] | None = None,
|
||||
) -> list[onnx.NodeProto] | None:
|
||||
"""
|
||||
Find a sequence of input edges based on constraints on parent op_type and index.
|
||||
When input_index is None, we will find the first parent node based on constraints,
|
||||
and return_indice will be appended the corresponding input index.
|
||||
|
||||
Args:
|
||||
node (str): current node name.
|
||||
parent_op_types (str): constraint of parent node op_type of each input edge.
|
||||
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
|
||||
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
||||
return_indice (list): a list to append the input index
|
||||
When there is no constraint on input index of an edge.
|
||||
|
||||
Returns:
|
||||
parents: a list of matched parent node.
|
||||
"""
|
||||
if parent_input_index is not None:
|
||||
assert len(parent_input_index) == len(parent_op_types)
|
||||
|
||||
if output_name_to_node is None:
|
||||
output_name_to_node = self.model.output_name_to_node()
|
||||
|
||||
current_node = node
|
||||
matched_parents = []
|
||||
for i, op_type in enumerate(parent_op_types):
|
||||
matched_parent = self.match_parent(
|
||||
current_node,
|
||||
op_type,
|
||||
parent_input_index[i] if parent_input_index is not None else None,
|
||||
output_name_to_node,
|
||||
exclude=[],
|
||||
return_indice=return_indice,
|
||||
)
|
||||
if matched_parent is None:
|
||||
return None
|
||||
|
||||
matched_parents.append(matched_parent)
|
||||
current_node = matched_parent
|
||||
|
||||
return matched_parents
|
||||
|
||||
def match_parent_paths(
|
||||
self,
|
||||
node: onnx.NodeProto,
|
||||
paths: list[tuple[list[str], list[int]]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
|
||||
"""
|
||||
Find a matching parent path to the given node.
|
||||
"""
|
||||
for i, path in enumerate(paths):
|
||||
return_indice = []
|
||||
matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
|
||||
if matched:
|
||||
return i, matched, return_indice
|
||||
return -1, None, None
|
||||
|
||||
def find_first_child_by_type(
|
||||
self,
|
||||
node: onnx.NodeProto,
|
||||
child_type: str,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
|
||||
recursive: bool = True,
|
||||
) -> onnx.NodeProto | None:
|
||||
children = self.model.get_children(node, input_name_to_nodes)
|
||||
dq = deque(children)
|
||||
while len(dq) > 0:
|
||||
current_node = dq.pop()
|
||||
if current_node.op_type == child_type:
|
||||
return current_node
|
||||
|
||||
if recursive:
|
||||
children = self.model.get_children(current_node, input_name_to_nodes)
|
||||
for child in children:
|
||||
dq.appendleft(child)
|
||||
|
||||
return None
|
||||
+272
@@ -0,0 +1,272 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import onnx
|
||||
|
||||
from ..onnx_model import ONNXModel
|
||||
from .fusion import Fusion
|
||||
|
||||
|
||||
class FusionGelu(Fusion):
|
||||
def __init__(self, model: ONNXModel):
|
||||
super().__init__(model, "Gelu", "Erf")
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
erf_node: onnx.NodeProto,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
):
|
||||
"""
|
||||
Interface function that tries to fuse a node sequence containing an Erf node into a single
|
||||
Gelu node.
|
||||
"""
|
||||
if (
|
||||
self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
|
||||
or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
|
||||
or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
|
||||
):
|
||||
self.model.set_opset_import("com.microsoft", 1)
|
||||
|
||||
def fuse_1(
|
||||
self,
|
||||
erf_node: onnx.NodeProto,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
) -> bool:
|
||||
"""
|
||||
This pattern is from PyTorch model
|
||||
Fuse Gelu with Erf into one node:
|
||||
Pattern 1:
|
||||
+-------Mul(0.5)---------------------+
|
||||
| |
|
||||
| v
|
||||
[root] --> Div -----> Erf --> Add --> Mul -->
|
||||
(B=1.4142...) (1)
|
||||
|
||||
Pattern 2:
|
||||
+------------------------------------+
|
||||
| |
|
||||
| v
|
||||
[root] --> Div -----> Erf --> Add --> Mul -->Mul -->
|
||||
(B=1.4142...) (1) (0.5)
|
||||
|
||||
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
|
||||
"""
|
||||
if erf_node.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[erf_node.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Add":
|
||||
return False
|
||||
add_after_erf = children[0]
|
||||
|
||||
if not self.has_constant_input(add_after_erf, 1):
|
||||
return False
|
||||
|
||||
if add_after_erf.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
|
||||
children = input_name_to_nodes[add_after_erf.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Mul":
|
||||
return False
|
||||
|
||||
mul_after_erf = children[0]
|
||||
|
||||
div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
|
||||
if div is None:
|
||||
return False
|
||||
|
||||
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
|
||||
return False
|
||||
|
||||
subgraph_input = div.input[0]
|
||||
|
||||
another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
|
||||
if subgraph_input == mul_after_erf.input[another]: # pattern 2
|
||||
children = input_name_to_nodes[mul_after_erf.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Mul":
|
||||
return False
|
||||
mul_half = children[0]
|
||||
if not self.has_constant_input(mul_half, 0.5):
|
||||
return False
|
||||
subgraph_output = mul_half.output[0]
|
||||
else: # pattern 1
|
||||
mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
|
||||
if mul_half is None:
|
||||
return False
|
||||
|
||||
if not self.has_constant_input(mul_half, 0.5):
|
||||
return False
|
||||
|
||||
if subgraph_input not in mul_half.input:
|
||||
return False
|
||||
|
||||
subgraph_output = mul_after_erf.output[0]
|
||||
|
||||
subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
|
||||
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
|
||||
return False
|
||||
|
||||
self.nodes_to_remove.extend(subgraph_nodes)
|
||||
fused_node = onnx.helper.make_node(
|
||||
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
|
||||
)
|
||||
fused_node.domain = "com.microsoft"
|
||||
self.nodes_to_add.append(fused_node)
|
||||
return True
|
||||
|
||||
def fuse_2(
|
||||
self,
|
||||
erf_node: onnx.NodeProto,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
) -> bool:
|
||||
"""
|
||||
This pattern is from Keras model
|
||||
Fuse Gelu with Erf into one node:
|
||||
+------------------------------------------+
|
||||
| |
|
||||
| v
|
||||
[root] --> Div -----> Erf --> Add --> Mul -->Mul
|
||||
(B=1.4142...) (A=1) (A=0.5)
|
||||
|
||||
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
|
||||
"""
|
||||
if erf_node.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[erf_node.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Add":
|
||||
return False
|
||||
add_after_erf = children[0]
|
||||
|
||||
if not self.has_constant_input(add_after_erf, 1):
|
||||
return False
|
||||
|
||||
if add_after_erf.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[add_after_erf.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Mul":
|
||||
return False
|
||||
mul_after_erf = children[0]
|
||||
|
||||
if not self.has_constant_input(mul_after_erf, 0.5):
|
||||
return False
|
||||
|
||||
if mul_after_erf.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[mul_after_erf.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Mul":
|
||||
return False
|
||||
mul = children[0]
|
||||
|
||||
div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
|
||||
if div is None:
|
||||
return False
|
||||
|
||||
sqrt_node = None
|
||||
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
|
||||
sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
|
||||
if sqrt_node is None:
|
||||
return False
|
||||
if not self.has_constant_input(sqrt_node, 2.0):
|
||||
return False
|
||||
|
||||
subgraph_input = div.input[0]
|
||||
|
||||
if subgraph_input not in mul.input:
|
||||
return False
|
||||
|
||||
subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
|
||||
if sqrt_node:
|
||||
subgraph_nodes.append(sqrt_node)
|
||||
|
||||
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node):
|
||||
return False
|
||||
|
||||
self.nodes_to_remove.extend(subgraph_nodes)
|
||||
fused_node = onnx.helper.make_node(
|
||||
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]]
|
||||
)
|
||||
fused_node.domain = "com.microsoft"
|
||||
self.nodes_to_add.append(fused_node)
|
||||
return True
|
||||
|
||||
def fuse_3(
|
||||
self,
|
||||
erf_node: onnx.NodeProto,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
) -> bool:
|
||||
"""
|
||||
This pattern is from TensorFlow model
|
||||
Fuse Gelu with Erf into one node:
|
||||
+----------------------------------------------+
|
||||
| |
|
||||
| v
|
||||
[root] --> Mul -----> Erf --> Add --> Mul -->Mul
|
||||
(A=0.7071067690849304) (B=1) (B=0.5)
|
||||
|
||||
Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
|
||||
"""
|
||||
|
||||
if erf_node.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[erf_node.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Add":
|
||||
return False
|
||||
add_after_erf = children[0]
|
||||
|
||||
if not self.has_constant_input(add_after_erf, 1):
|
||||
return False
|
||||
|
||||
if add_after_erf.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[add_after_erf.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Mul":
|
||||
return False
|
||||
mul_half = children[0]
|
||||
|
||||
if not self.has_constant_input(mul_half, 0.5):
|
||||
return False
|
||||
|
||||
first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
|
||||
if first_mul is None:
|
||||
return False
|
||||
|
||||
i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
|
||||
if i < 0:
|
||||
return False
|
||||
|
||||
root_input_index = 1 - i
|
||||
subgraph_input = first_mul.input[root_input_index]
|
||||
|
||||
if mul_half.output[0] not in input_name_to_nodes:
|
||||
return False
|
||||
children = input_name_to_nodes[mul_half.output[0]]
|
||||
if len(children) != 1 or children[0].op_type != "Mul":
|
||||
return False
|
||||
last_mul = children[0]
|
||||
|
||||
if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
|
||||
return False
|
||||
|
||||
subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
|
||||
if not self.is_safe_to_fuse_nodes(
|
||||
subgraph_nodes,
|
||||
[last_mul.output[0]],
|
||||
input_name_to_nodes,
|
||||
output_name_to_node,
|
||||
):
|
||||
return False
|
||||
|
||||
self.nodes_to_remove.extend(subgraph_nodes)
|
||||
fused_node = onnx.helper.make_node(
|
||||
"Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]]
|
||||
)
|
||||
fused_node.domain = "com.microsoft"
|
||||
self.nodes_to_add.append(fused_node)
|
||||
return True
|
||||
+135
@@ -0,0 +1,135 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import onnx
|
||||
|
||||
from ..onnx_model import ONNXModel
|
||||
from .fusion import Fusion
|
||||
|
||||
|
||||
class FusionLayerNormalization(Fusion):
|
||||
def __init__(self, model: ONNXModel):
|
||||
super().__init__(model, "LayerNormalization", "ReduceMean")
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
reduce_mean_node: onnx.NodeProto,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
):
|
||||
"""
|
||||
Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
|
||||
LayerNormalization node.
|
||||
|
||||
+----------------------+
|
||||
| |
|
||||
| v
|
||||
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
|
||||
(axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^
|
||||
| |
|
||||
+-------------------------------------------------+
|
||||
|
||||
It also handles cases of duplicated sub nodes exported from older version of PyTorch:
|
||||
|
||||
+----------------------+
|
||||
| v
|
||||
| +-------> Sub-----------------------------------------------+
|
||||
| | |
|
||||
| | v
|
||||
[Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
|
||||
| ^
|
||||
| |
|
||||
+----------------------+
|
||||
"""
|
||||
children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
|
||||
if len(children) == 0 or len(children) > 2:
|
||||
return
|
||||
|
||||
root_input = reduce_mean_node.input[0]
|
||||
|
||||
if children[0].op_type != "Sub" or children[0].input[0] != root_input:
|
||||
return
|
||||
|
||||
if len(children) == 2:
|
||||
if children[1].op_type != "Sub" or children[1].input[0] != root_input:
|
||||
return
|
||||
|
||||
div_node = None
|
||||
for child in children:
|
||||
div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
|
||||
if div_node is not None:
|
||||
break
|
||||
if div_node is None:
|
||||
return
|
||||
|
||||
path_id, parent_nodes, _ = self.match_parent_paths(
|
||||
div_node,
|
||||
[
|
||||
(["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
|
||||
(
|
||||
["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"],
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
),
|
||||
],
|
||||
output_name_to_node,
|
||||
)
|
||||
if path_id < 0:
|
||||
return
|
||||
|
||||
sub_node = parent_nodes[-1]
|
||||
if sub_node not in children:
|
||||
return
|
||||
|
||||
second_add_node = parent_nodes[1]
|
||||
i, add_weight = self.get_constant_input(second_add_node)
|
||||
if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
|
||||
# Skip fusion since epsilon value is not expected.
|
||||
return
|
||||
|
||||
pow_node = parent_nodes[3]
|
||||
if self.find_constant_input(pow_node, 2.0) != 1:
|
||||
return
|
||||
|
||||
mul_node = input_name_to_nodes[div_node.output[0]][0]
|
||||
if mul_node.op_type != "Mul":
|
||||
return
|
||||
|
||||
last_add_node = input_name_to_nodes[mul_node.output[0]][0]
|
||||
if last_add_node.op_type != "Add":
|
||||
return
|
||||
|
||||
subgraph_nodes = [reduce_mean_node]
|
||||
subgraph_nodes.extend(children)
|
||||
subgraph_nodes.extend(parent_nodes[:-1])
|
||||
|
||||
subgraph_nodes.extend([last_add_node, mul_node, div_node])
|
||||
if not self.is_safe_to_fuse_nodes(
|
||||
subgraph_nodes,
|
||||
last_add_node.output,
|
||||
input_name_to_nodes,
|
||||
output_name_to_node,
|
||||
):
|
||||
return
|
||||
|
||||
weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
|
||||
if not self.is_constant_with_specified_rank(weight_input, 1):
|
||||
return
|
||||
|
||||
bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
|
||||
if not self.is_constant_with_specified_rank(bias_input, 1):
|
||||
return
|
||||
|
||||
self.nodes_to_remove.extend(subgraph_nodes)
|
||||
|
||||
normalize_node = onnx.helper.make_node(
|
||||
"LayerNormalization",
|
||||
name=self.create_unique_node_name(),
|
||||
inputs=[reduce_mean_node.input[0], weight_input, bias_input],
|
||||
outputs=[last_add_node.output[0]],
|
||||
)
|
||||
normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
|
||||
self.nodes_to_add.append(normalize_node)
|
||||
+96
@@ -0,0 +1,96 @@
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
|
||||
from ..onnx_model import ONNXModel
|
||||
from .fusion import Fusion
|
||||
|
||||
|
||||
class ReplaceUpsampleWithResize(Fusion):
|
||||
"""Replace Upsample with Resize."""
|
||||
|
||||
def __init__(self, model: ONNXModel, opset):
|
||||
"""Initialize."""
|
||||
super().__init__(model, "Resize", "Upsample")
|
||||
self.opset = opset
|
||||
|
||||
def fuse(
|
||||
self,
|
||||
node: onnx.NodeProto,
|
||||
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
||||
output_name_to_node: dict[str, onnx.NodeProto],
|
||||
):
|
||||
"""Replace Upsample with Resize."""
|
||||
mode = None
|
||||
for attr in node.attribute:
|
||||
if attr.name == "mode":
|
||||
mode = attr.s.decode("utf-8")
|
||||
break
|
||||
|
||||
scales_input = None
|
||||
if self.opset > 7:
|
||||
scales_input = node.input[1] if len(node.input) > 1 else ""
|
||||
resize_inputs = [node.input[0], node.name + "_roi", scales_input]
|
||||
else:
|
||||
if self.opset == 7:
|
||||
for attr in node.attribute:
|
||||
if attr.name == "scales":
|
||||
scales_input = attr.floats
|
||||
break
|
||||
|
||||
scales_input = np.array(list(scales_input), np.float32)
|
||||
else:
|
||||
h_scale = 1
|
||||
w_scale = 1
|
||||
for attr in node.attribute:
|
||||
if attr.name == "height_scale":
|
||||
h_scale = attr.float
|
||||
elif attr.name == "width_scale":
|
||||
w_scale = attr.float
|
||||
|
||||
scales_input = np.array([1, 1, h_scale, w_scale], np.float32)
|
||||
|
||||
scales_tensor = onnx.helper.make_tensor(
|
||||
name=node.name + "_scales",
|
||||
data_type=onnx.TensorProto.FLOAT,
|
||||
dims=scales_input.shape,
|
||||
vals=scales_input.flatten().tolist(),
|
||||
)
|
||||
|
||||
scales_node = onnx.helper.make_node(
|
||||
"Constant", inputs=[], outputs=[node.name + "_scales"], value=scales_tensor
|
||||
)
|
||||
|
||||
self.nodes_to_add.append(scales_node)
|
||||
|
||||
resize_inputs = [node.input[0], node.name + "_roi", node.name + "_scales"]
|
||||
|
||||
roi_tensor = onnx.helper.make_tensor(
|
||||
name=node.name + "_roi",
|
||||
data_type=onnx.TensorProto.FLOAT,
|
||||
dims=(len(scales_input) * 2,),
|
||||
vals=[0] * len(scales_input) + [1] * len(scales_input),
|
||||
)
|
||||
|
||||
roi_node = onnx.helper.make_node("Constant", inputs=[], outputs=[node.name + "_roi"], value=roi_tensor)
|
||||
|
||||
resize_node = onnx.helper.make_node(
|
||||
op_type="Resize", inputs=resize_inputs, outputs=node.output, mode=mode, nearest_mode="floor"
|
||||
)
|
||||
|
||||
self.nodes_to_remove.append(node)
|
||||
self.nodes_to_add.append(roi_node)
|
||||
self.nodes_to_add.append(resize_node)
|
||||
|
||||
def apply(self) -> bool:
|
||||
"""Apply."""
|
||||
if super().apply():
|
||||
self.model.topological_sort()
|
||||
return True
|
||||
return False
|
||||
Reference in New Issue
Block a user