🐛 Describe the bug
When attempting to lower a MaxPool2d module with a single-element kernel_size list, the XNNPACK partitioner errors out because it assumes that if kernel_size is a list, it has two elements. When the list contains a single element, that value should be used for both height and width. The partitioner handles this for scalar values, but not for single-element lists.
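For context (not part of the original report), eager-mode PyTorch already treats a one-element kernel_size list as shorthand for both spatial dimensions, which is why export succeeds and only the partitioner check fails. A quick sanity check of that equivalence, under that assumption:
import torch

x = torch.randn(1, 3, 16, 16)
# A one-element kernel_size list and the equivalent scalar produce identical pooling.
out_list = torch.nn.MaxPool2d([3])(x)
out_scalar = torch.nn.MaxPool2d(3)(x)
assert torch.equal(out_list, out_scalar)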
Repro:
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Single-element kernel_size list; the equivalent scalar form partitions fine.
        self.pool = torch.nn.MaxPool2d([3])

    def forward(self, x):
        return self.pool(x)

inputs = (torch.randn(1, 3, 16, 16),)

ep = torch.export.export(Model(), inputs)
print(ep)

et_program = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner()],
).to_executorch()

Output:
File ~/miniconda3/envs/pytorch/lib/python3.11/site-packages/executorch/backends/xnnpack/partition/config/generic_node_configs.py:301, in MaxPool2dConfig.check_constraints(self, node, ep)
298 why(node, reason="ceil mode is not supported for dynamic shapes")
299 return False
--> 301 if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]: # pyre-ignore[16]
302 why(
303 node,
304 reason=f"stride ({stride}) must be less than or equal to kernel size ({kernel_size})",
305 )
306 return False
IndexError: list index out of range
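A possible direction for a fix (a sketch only; _normalize_2d is a hypothetical helper, not an existing executorch function) would be to broadcast scalars and single-element lists to a (height, width) pair before the stride/kernel comparison in MaxPool2dConfig.check_constraints:
def _normalize_2d(value):
    # Expand an int or a one-element list/tuple into a [height, width] pair,
    # matching how PyTorch pooling ops interpret these arguments.
    if isinstance(value, int):
        return [value, value]
    if isinstance(value, (list, tuple)) and len(value) == 1:
        return [value[0], value[0]]
    return list(value)

# The constraint check could then compare normalized values, e.g.:
# kernel_size = _normalize_2d(kernel_size)
# stride = _normalize_2d(stride)
# if stride[0] > kernel_size[0] or stride[1] > kernel_size[1]:
#     ...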
Versions
N/A