Skip to content

Fix bug in sub op to ignore alpha != 1 (fixes: #11684) #11796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

abhinaykukkadapu
Copy link
Contributor

Summary

Sub node of XNNPack backend doesn't consider alpha value, this fixes the bug by falling back to portable ops and avoid partitioning the node.

(fixes: #11684)

Test plan

$ python -m unittest backends/xnnpack/test/ops/test_sub.py

Ran 3 tests in 7.262s

OK

Model run

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer

class Model(torch.nn.Module):
    def forward(self, x, y):
        return torch.sub(x, y, alpha=10)

inputs = (
    torch.randn(10),
    torch.randn(10),
)
model = Model()

ep = torch.export.export(model, inputs)
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

et_model = _load_for_executorch_from_buffer(lowered.buffer)

eager_output = model(*inputs)
et_output = et_model([*inputs])[0]

tolerance=1e-5
if torch.allclose(eager_output, et_output, atol=tolerance):
    print("Outputs are within the tolerance level.")
else:
    print("Outputs differ beyond the tolerance level.")

output

Outputs are within the tolerance level.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 18, 2025
Copy link

pytorch-bot bot commented Jun 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/11796

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Cancelled Jobs, 2 Unrelated Failures

As of commit f020907 with merge base daebcde (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@abhinaykukkadapu abhinaykukkadapu force-pushed the fix_bug_sub_ignore_alpha branch from d5e02ac to 452c96e Compare June 18, 2025 19:21
@abhinaykukkadapu abhinaykukkadapu marked this pull request as ready for review June 18, 2025 19:21
@abhinaykukkadapu abhinaykukkadapu changed the title Fix bug add ignore alpha (fixes: #11684) Fix bug in sub op to ignore alpha != 1 (fixes: #11684) Jun 18, 2025
@manuelcandales manuelcandales added the release notes: xnnpack Changes to the XNNPack backend delegate label Jun 18, 2025
@abhinaykukkadapu abhinaykukkadapu force-pushed the fix_bug_sub_ignore_alpha branch from 452c96e to f020907 Compare June 18, 2025 23:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. release notes: xnnpack Changes to the XNNPack backend delegate
Projects
None yet
Development

Successfully merging this pull request may close these issues.

XNNPACK sub ignores alpha value
5 participants