Skip to content

Commit 7565342

Browse files
Arm backend: Prevent illegal fusion in FuseEqualPlaceholdersPass (#11781)
- Constant placeholders with same values but different data types, such as int32 and fp32, shouldn't be fused into a single placeholder. Otherwise, some operators will have operands with mismatched dtypes. - Fix the bug by adding a dtype check to fuse only constants with matching types and same values. Signed-off-by: Yufeng Shi <[email protected]>
1 parent 44d2643 commit 7565342

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4949
if tensor2 is None:
5050
continue
5151

52-
if torch.equal(tensor1, tensor2):
52+
if (
53+
tensor1.dtype == tensor2.dtype
54+
and tensor1.shape == tensor2.shape
55+
and torch.allclose(tensor1, tensor2, atol=1e-08)
56+
):
5357
eq_nodes.append(node2)
5458

5559
if len(eq_nodes) > 1:

backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
1111
FuseEqualPlaceholdersPass,
1212
)
13-
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
13+
from executorch.backends.arm.test.tester.test_pipeline import (
14+
PassPipeline,
15+
TosaPipelineMI,
16+
)
1417

1518
input_t = Tuple[torch.Tensor] # Input x
1619

@@ -54,6 +57,25 @@ def forward(self, x):
5457
return self.fc1(x) + self.fc2(x)
5558

5659

60+
class NotFuseTensorWithDifferentType(torch.nn.Module):
61+
62+
ops_before_pass = {}
63+
ops_after_pass = {}
64+
ops_not_after_pass = []
65+
66+
def forward(self, x: torch.Tensor, y: torch.Tensor):
67+
"""
68+
Args:
69+
x: A float tensor (dtype=torch.float32)
70+
y: An int tensor (dtype=torch.int32)
71+
"""
72+
a = torch.tensor(1.0, dtype=torch.float32)
73+
b = torch.tensor(1, dtype=torch.int32)
74+
m = x < a
75+
n = y > b
76+
return m, n
77+
78+
5779
def test_fuse_equal_placeholders_constants_tosa_MI():
5880
module = FuseWeightsConstants()
5981
data = (torch.rand(1, 2, 8),)
@@ -94,3 +116,24 @@ def test_fuse_equal_placeholders_state_dict_tosa_MI():
94116
assert len(state_dict_keys) == 2, "FuseEqualPlaceholders state_dict failed"
95117
assert "_common" in state_dict_keys[0], "FuseEqualPlaceholders state_dict failed"
96118
assert "_common" in state_dict_keys[1], "FuseEqualPlaceholders state_dict failed"
119+
120+
121+
def test_not_fuse_tensor_with_different_type_MI():
122+
module = NotFuseTensorWithDifferentType()
123+
data = (
124+
torch.rand(
125+
1,
126+
),
127+
torch.randint(
128+
0,
129+
10,
130+
(1,),
131+
dtype=torch.int,
132+
),
133+
)
134+
pipeline = TosaPipelineMI[input_t](
135+
module,
136+
data,
137+
aten_op=[],
138+
)
139+
pipeline.run()

0 commit comments

Comments
 (0)