
XNNPACK sub ignores alpha value #11684

@GregoryComer

Description

🐛 Describe the bug

The XNNPACK delegate ignores the alpha value when delegating sub ops, so it generates incorrect output. We should add a partitioner constraint so that sub with alpha != 1 is not partitioned. Alternatively, we could decompose it into a multiply and a sub. Given that this case is rare, I'd probably just add the partitioner constraint.
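The two options above can be sketched in plain Python. This is not ExecuTorch code: `is_sub_partitionable`, `sub_with_alpha` and `decomposed_sub` are hypothetical helpers, and the check assumes alpha is visible as a keyword argument on the sub node (it is keyword-only in `aten.sub.Tensor`). The sketch shows the gating rule (accept only alpha == 1) and that `x - alpha * y` equals a multiply followed by an alpha-free sub.

```python
from typing import Any, List, Mapping, Sequence


def is_sub_partitionable(kwargs: Mapping[str, Any]) -> bool:
    """Hypothetical partitioner check: accept sub only when alpha is the default 1."""
    # Assumption: alpha appears in the node's kwargs and defaults to 1 when absent.
    return kwargs.get("alpha", 1) == 1


def sub_with_alpha(x: Sequence[float], y: Sequence[float], alpha: float = 1.0) -> List[float]:
    """Reference semantics of torch.sub, elementwise: out = x - alpha * y."""
    return [a - alpha * b for a, b in zip(x, y)]


def decomposed_sub(x: Sequence[float], y: Sequence[float], alpha: float = 1.0) -> List[float]:
    """Hypothetical decomposition: scale y by alpha, then do a plain sub (alpha = 1)."""
    scaled = [alpha * b for b in y]            # the multiply
    return [a - b for a, b in zip(x, scaled)]  # the alpha-free sub


if __name__ == "__main__":
    x, y = [1.0, 2.0], [0.5, 0.25]
    print(is_sub_partitionable({"alpha": 10}))  # False: leave this sub to the portable kernels
    print(sub_with_alpha(x, y, alpha=10))       # [-4.0, -0.5]
    print(decomposed_sub(x, y, alpha=10))       # [-4.0, -0.5], same as the reference
    print(sub_with_alpha(x, y))                 # [0.5, 1.75], what dropping alpha produces
```

The constraint is the smaller change: an unsupported sub simply stays outside the delegate. The decomposition keeps the op delegated at the cost of an extra multiply.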

To fix:

Repro:

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer

class Model(torch.nn.Module):
    def forward(self, x, y):
        # torch.sub computes x - alpha * y, so this returns x - 10 * y
        return torch.sub(x, y, alpha=10)

inputs = (
    torch.randn(10),
    torch.randn(10),
)
model = Model()

ep = torch.export.export(model, inputs)
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

et_model = _load_for_executorch_from_buffer(lowered.buffer)

eager_output = model(*inputs)
et_output = et_model([*inputs])[0]

print(f"Eager: {eager_output}")
print(f"ET:    {et_output}")
print(f"Error: {et_output-eager_output}")

Output:

Eager: tensor([  3.5325, -14.6426,  11.7786,  15.1562,  28.2268,  -8.4315,  -6.4738,
         -5.6904,  10.8814,  22.7102])
ET:    tensor([ 0.7991, -2.1287,  1.0652,  1.6636,  2.4694, -2.6314, -0.7396, -0.1184,
         0.1409,  1.0677])
Error: tensor([ -2.7334,  12.5139, -10.7135, -13.4926, -25.7574,   5.8001,   5.7342,
          5.5720, -10.7405, -21.6425])
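Assuming the delegate drops alpha and computes a plain x - y, the printed error should be exactly (alpha - 1) times y. The printed values are consistent with this: for the first element, 0.7991 - 3.5325 = -2.7334, which is 9y with y ≈ -0.3037.

```latex
\text{eager} = x - \alpha y, \qquad
\text{ET} = x - y, \qquad
\text{ET} - \text{eager} = (\alpha - 1)\,y = 9y \quad (\alpha = 10)
```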

Versions

N/A

cc @digantdesai @mcr229 @cbilgin

Metadata

Labels

backend tester (This bug was found by the backend test suite.)
good first issue (Good for newcomers)
module: xnnpack (Issues related to xnnpack delegation and the code under backends/xnnpack/)

Status: No status
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
