Commit 84d4a76

support channels last dim order in xnnpack

1 parent d9503e6 commit 84d4a76

This commit lets the XNNPACK backend accept channels-last (NHWC) tensors at its boundaries instead of requiring the default contiguous (NCHW) dim order: the ChannelsLastTaggedReshapePass now inserts layout conversions for graph inputs and outputs based on their actual dim order, XNNExecutor reshapes and resizes external tensors using each tensor's dim order at runtime, the preprocess-time assertion that rejected non-default dim orders is removed, and new tests cover mixed linear/conv graphs and NHWC inputs.
File tree (5 files changed: +170 −21 lines)

  backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py
  backends/xnnpack/runtime/XNNExecutor.cpp
  backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py
  backends/xnnpack/test/tester/tester.py
  backends/xnnpack/xnnpack_preprocess.py

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 24 additions & 6 deletions
```diff
@@ -56,9 +56,9 @@ class ChannelsLastTaggedReshapePass(XNNPACKPass):
 
     # Set of ops that require memory format to be NCHW
     memory_sensitive_ops_nchw = {
-        "output",
         exir_ops.edge.aten.squeeze_copy.dim,
         exir_ops.edge.aten.unsqueeze_copy.default,
+        exir_ops.edge.aten.linear.default,
     }
 
     # Tag which is added to a node's meta to indicate that it uses NHWC format.
@@ -91,10 +91,18 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
         return not self.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
-        return node.target in self.memory_sensitive_ops_nhwc
+        return (
+            node.target in self.memory_sensitive_ops_nhwc
+            or node.name == "output"
+            and not node.args[0][0].meta["val"].is_contiguous()
+        )
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
-        return node.target in self.memory_sensitive_ops_nchw
+        return (
+            node.target in self.memory_sensitive_ops_nchw
+            or node.name == "output"
+            and node.args[0][0].meta["val"].is_contiguous()
+        )
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         # There are two conditions that must be met for a node to be able to
@@ -269,7 +277,10 @@ def input_to_nhwc(
             # serializing graph, but don't do anything else here
             self.mark_as_nhwc_node(input_node)
 
-        if self.is_nhwc_node(input_node):
+        if input_node.name == "x":
+            if not input_node.meta["val"][0].is_contiguous():
+                return
+        elif self.is_nhwc_node(input_node):
             return
 
         if not self.can_be_converted_to_nhwc(input_node):
@@ -333,7 +344,10 @@ def input_to_nchw(
             # do anything else here
             self.mark_as_nchw_node(input_node)
 
-        if self.is_nchw_node(input_node):
+        if input_node.name == "x":
+            if input_node.meta["val"].is_contiguous():
+                return
+        elif self.is_nchw_node(input_node):
             return
 
         if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
@@ -371,7 +385,11 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # first input to be nhwc. This makes this node's output nhwc too
                 # Currently, all nodes like this should have all of their other
                 # inputs as nchw, so fail if this is not true
-                self.input_to_nhwc(graph_module, node.args[0], node)
+                if node.name == "output":
+                    self.input_to_nhwc(graph_module, node.args[0][0], node)
+                else:
+                    self.input_to_nhwc(graph_module, node.args[0], node)
+
                 for input_node in node.all_input_nodes[1:]:
                     if self.is_nhwc_node(input_node):
                         raise AssertionError(
```
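The key mechanism above is that the pass now decides the output boundary's layout from the contiguity of the output node's first argument (`node.args[0][0].meta["val"]`): a contiguous fake tensor keeps the boundary NCHW, a non-contiguous one keeps it NHWC. A minimal standalone sketch of that check on eager tensors (illustration only, not code from this commit):

```python
import torch

# Default (NCHW) 4-D tensors are contiguous; channels-last (NHWC) ones are not.
nchw = torch.randn(1, 3, 6, 4)
nhwc = nchw.to(memory_format=torch.channels_last)

print(nchw.is_contiguous())  # True  -> boundary stays NCHW
print(nhwc.is_contiguous())  # False -> boundary stays NHWC
print(nhwc.dim_order())      # (0, 2, 3, 1), i.e. channels last
```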

backends/xnnpack/runtime/XNNExecutor.cpp

Lines changed: 13 additions & 9 deletions
```diff
@@ -106,20 +106,16 @@ ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
         err == Error::Ok,
         Internal,
         "Failed to retrieve dim order from tensor!");
-    ET_CHECK_OR_RETURN_ERROR(
-        is_contiguous_dim_order(dim_order, tensor->dim()),
-        Internal,
-        "Expecting default dim_order but got a non default dim_order tensor for external input %u",
-        i);
     size_t dims[XNN_MAX_TENSOR_DIMS];
     ET_CHECK_OR_RETURN_ERROR(
         num_dims <= XNN_MAX_TENSOR_DIMS,
         InvalidArgument,
         "XNNPACK backend accepts tensors with at most %d dims, but got %zu",
         XNN_MAX_TENSOR_DIMS,
         num_dims);
-    for (int d = 0; d < num_dims; ++d) {
-      dims[d] = tensor->size(d);
+
+    for (int i = 0; i < num_dims; ++i) {
+      dims[i] = tensor->size(static_cast<int>(dim_order[i]));
     }
     status =
         xnn_reshape_external_value(runtime_.get(), ext_id, num_dims, dims);
@@ -220,8 +216,16 @@ ET_NODISCARD Error XNNExecutor::resize_outputs(EValue** args) const {
 
     // Convert new output shape into SizesType
     SizesType expected_output_size[kTensorDimensionLimit];
-    for (size_t d = 0; d < num_dim; ++d) {
-      expected_output_size[d] = static_cast<SizesType>(dims[d]);
+    executorch::aten::DimOrderType dim_order[kTensorDimensionLimit];
+    Error errr =
+        ET_RUNTIME_NAMESPACE::get_dim_order(*out_tensor, dim_order, num_dim);
+    ET_CHECK_OR_RETURN_ERROR(
+        errr == Error::Ok,
+        Internal,
+        "Failed to retrieve dim order from tensor!");
+
+    for (int i = 0; i < num_dim; ++i) {
+      expected_output_size[static_cast<int>(dim_order[i])] = static_cast<SizesType>(dims[i]);
     }
 
     executorch::aten::ArrayRef<SizesType> output_size{
```
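The two loops above are inverse index mappings: `prepare_args` converts logical sizes into the physical extents XNNPACK expects, and `resize_outputs` scatters the physical extents XNNPACK reports back into logical positions. A small Python sketch of the arithmetic, with values chosen only for illustration:

```python
logical_sizes = [1, 3, 6, 4]  # N, C, H, W
dim_order = [0, 2, 3, 1]      # channels last: memory order is N, H, W, C

# prepare_args: dims[i] = size(dim_order[i]) yields physical extents
physical_dims = [logical_sizes[d] for d in dim_order]
assert physical_dims == [1, 6, 4, 3]

# resize_outputs: expected_output_size[dim_order[i]] = dims[i] inverts it
recovered = [0] * len(physical_dims)
for i, d in enumerate(dim_order):
    recovered[d] = physical_dims[i]
assert recovered == logical_sizes
```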

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 126 additions & 0 deletions
```diff
@@ -42,6 +42,50 @@ def setUp(self):
             "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
         )
         dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
+    def run_tester(self, module, inputs):
+        tester = Tester(
+            module.eval(),
+            inputs,
+        )
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class LinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+        def forward(self, x):
+            y = self.linear1(x)
+            return self.conv1(y)
+    class ConvLinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 4)
+        def forward(self, x):
+            y = self.conv1(x)
+            return self.linear1(y)
+    class Bilinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x):
+            return torch.nn.functional.interpolate(
+                x, scale_factor=2, mode="bilinear", align_corners=True
+            )
+
+    def test_conv_linear_dim_order_swaps(self):
+        self.run_tester(self.LinearConv(), (torch.randn(1, 3, 6, 4),))
+        self.run_tester(self.LinearConv(), (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),))
+
+    def test_linear_conv_dim_order_swaps(self):
+        self.run_tester(self.ConvLinearConv(), (torch.randn(1, 3, 6, 6),))
+        self.run_tester(self.ConvLinearConv(), (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),))
+
+    def test_nhwc_input_on_nhwc_op(self):
+        self.run_tester(self.Bilinear(), (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32).to(memory_format=torch.channels_last),))
+
+    def test_nchw_input_on_nhwc_op(self):
+        self.run_tester(self.Bilinear(), (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),))
 
     def test_fp32_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
@@ -58,6 +102,88 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
                 .run_method_and_compare_outputs()
             )
 
+    class LinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 3)
+
+        def forward(self, x):
+            y = self.linear1(x)
+            return self.conv1(y)
+
+    def test_conv_linear_dim_order_swaps_on_nhwc_input(self):
+        tester = Tester(
+            self.LinearConv().eval(),
+            (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_conv_linear_dim_order_swaps_on_nchw_input(self):
+        tester = Tester(
+            self.LinearConv().eval(),
+            (torch.randn(1, 3, 6, 4),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class ConvLinearConv(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.linear1 = torch.nn.Linear(4, 4)
+
+        def forward(self, x):
+            y = self.conv1(x)
+            return self.linear1(y)
+
+    def test_linear_conv_dim_order_swaps_on_nhwc_input(self):
+        tester = Tester(
+            self.ConvLinearConv().eval(),
+            (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_linear_conv_dim_order_swaps_on_nchw_input(self):
+        tester = Tester(
+            self.ConvLinearConv().eval(),
+            (torch.randn(1, 3, 6, 6),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    class Bilinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.nn.functional.interpolate(
+                x, scale_factor=2, mode="bilinear", align_corners=True
+            )
+
+    def test_nhwc_input_on_nhwc_op(self):
+        tester = Tester(
+            self.Bilinear().eval(),
+            (
+                torch.arange(8)
+                .reshape(1, 2, 2, 2)
+                .to(torch.float32)
+                .to(memory_format=torch.channels_last),
+            ),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
+    def test_nchw_input_on_nhwc_op(self):
+        tester = Tester(
+            self.Bilinear().eval(),
+            (torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32),),
+        )
+
+        tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
+
     def test_qs8_channels_last_tagged_reshape_pass(self):
         for module, num_reshape in self.modules.items():
             (
```
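These cases exercise the layout boundary between ops: `aten.linear` is now in `memory_sensitive_ops_nchw` while XNNPACK convolutions run channels-last, so `LinearConv` forces an NCHW-to-NHWC swap between the two. A rough eager-mode analogue of the conversion the pass inserts (illustration only, with shapes taken from the test):

```python
import torch

x = torch.randn(1, 3, 6, 4)                       # contiguous NCHW input
y = torch.nn.Linear(4, 3)(x)                      # linear stays NCHW
y_nhwc = y.to(memory_format=torch.channels_last)  # conversion the pass inserts
out = torch.nn.Conv2d(3, 3, 3)(y_nhwc)            # conv runs channels-last

print(y.is_contiguous())       # True  - NCHW side of the boundary
print(y_nhwc.is_contiguous())  # False - NHWC side of the boundary
```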

backends/xnnpack/test/tester/tester.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -31,6 +31,7 @@
 )
 from executorch.exir.backend.backend_api import validation_disabled
 from executorch.exir.backend.partitioner import Partitioner
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.exir.print_program import pretty_print, print_program
@@ -533,10 +534,13 @@ def fn(x):
         # create random tensor inputs with the shapes given above:
         random_inputs = []
         for arg_idx in range(len(self.example_inputs)):
+            memFormat = get_memory_format(
+                list(self.example_inputs[arg_idx].dim_order())
+            )
             random_inputs.append(
-                torch.randn(input_shapes[arg_idx]).to(
-                    dtype=self.example_inputs[arg_idx].dtype
-                )
+                torch.randn(input_shapes[arg_idx])
+                .to(dtype=self.example_inputs[arg_idx].dtype)
+                .to(memory_format=memFormat)
             )
 
         yield tuple(random_inputs)
```
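`get_memory_format` maps a dim order back to a `torch.memory_format`, so the dynamic-shape tester now generates random inputs in the same layout as the example inputs rather than always contiguous. A minimal sketch of the new behavior (assuming a standard ExecuTorch install):

```python
import torch
from executorch.exir.dim_order_utils import get_memory_format

example = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)

# Recover the memory format from the example input's dim order...
mem_format = get_memory_format(list(example.dim_order()))
assert mem_format == torch.channels_last

# ...and generate a random input matching both dtype and layout.
random_input = (
    torch.randn(example.shape)
    .to(dtype=example.dtype)
    .to(memory_format=mem_format)
)
assert random_input.dim_order() == example.dim_order()
```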

backends/xnnpack/xnnpack_preprocess.py

Lines changed: 0 additions & 3 deletions
```diff
@@ -145,9 +145,6 @@ def preprocess(
 
         node_to_external_map = generate_node_to_external_map(ep, graph_module)
 
-        # Make sure all inputs are contiguous_format or NCHW or default dim order
-        assert_default_dim_order(graph_module)
-
         # TODO retrace the graph module to lift the new params may have
         # been added to the graph in passes
```