@@ -42,6 +42,50 @@ def setUp(self):
42
42
"executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
43
43
)
44
44
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
45
def run_tester(self, module, inputs):
    """Lower *module* through the XNNPACK flow and check numerical parity.

    Exports the eval-mode module, lowers it to ExecuTorch, serializes the
    program, and runs it, comparing outputs against eager execution.
    """
    tester = Tester(module.eval(), inputs)
    (
        tester.export()
        .to_edge_transform_and_lower()
        .to_executorch()
        .serialize()
        .run_method_and_compare_outputs()
    )
52
class LinearConv(torch.nn.Module):
    """Linear followed by Conv2d — exercises an NCHW/NHWC layout swap."""

    def __init__(self):
        super().__init__()
        # Registration order kept: conv1 first, then linear1.
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 3)

    def forward(self, x):
        # Linear runs on the last dim, then the result feeds the conv.
        return self.conv1(self.linear1(x))
60
class ConvLinearConv(torch.nn.Module):
    """Conv2d followed by Linear — exercises the reverse layout swap."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear1(self.conv1(x))
68
class Bilinear(torch.nn.Module):
    """Stateless module wrapping 2x bilinear upsampling (align_corners)."""

    # No parameters, so the inherited nn.Module.__init__ suffices.

    def forward(self, x):
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True
        )
75
+
76
def test_conv_linear_dim_order_swaps(self):
    """LinearConv lowers correctly for contiguous and channels_last inputs."""
    contiguous = torch.randn(1, 3, 6, 4)
    self.run_tester(self.LinearConv(), (contiguous,))
    nhwc = torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last)
    self.run_tester(self.LinearConv(), (nhwc,))
79
+
80
def test_linear_conv_dim_order_swaps(self):
    """ConvLinearConv lowers correctly for contiguous and channels_last inputs."""
    contiguous = torch.randn(1, 3, 6, 6)
    self.run_tester(self.ConvLinearConv(), (contiguous,))
    nhwc = torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last)
    self.run_tester(self.ConvLinearConv(), (nhwc,))
83
+
84
def test_nhwc_input_on_nhwc_op(self):
    """channels_last input into an NHWC-native op (bilinear upsample).

    NOTE(review): a method with this same name is re-defined later in the
    class; Python keeps only the later definition, so this body is silently
    shadowed and never runs — keep exactly one copy.
    """
    x = (
        torch.arange(8)
        .reshape(1, 2, 2, 2)
        .to(torch.float32)
        .to(memory_format=torch.channels_last)
    )
    self.run_tester(self.Bilinear(), (x,))
86
+
87
def test_nchw_input_on_nhwc_op(self):
    """Contiguous (NCHW) input into an NHWC-native op (bilinear upsample).

    NOTE(review): a method with this same name is re-defined later in the
    class and shadows this one — keep exactly one copy.
    """
    x = torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32)
    self.run_tester(self.Bilinear(), (x,))
45
89
46
90
def test_fp32_channels_last_tagged_reshape_pass (self ):
47
91
for module , num_reshape in self .modules .items ():
@@ -58,6 +102,88 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
58
102
.run_method_and_compare_outputs ()
59
103
)
60
104
105
class LinearConv(torch.nn.Module):
    """Linear then Conv2d test module.

    NOTE(review): this duplicates the LinearConv defined earlier in the
    class; this later definition shadows the earlier, identical one —
    the duplicate should be removed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 3)

    def forward(self, x):
        return self.conv1(self.linear1(x))
114
+
115
def test_conv_linear_dim_order_swaps_on_nhwc_input(self):
    """Linear->Conv model lowered from a channels_last (NHWC) input.

    Uses the class's run_tester helper instead of repeating the
    export/lower/serialize/compare boilerplate inline.
    """
    self.run_tester(
        self.LinearConv(),
        (torch.randn(1, 3, 6, 4).to(memory_format=torch.channels_last),),
    )
122
+
123
def test_conv_linear_dim_order_swaps_on_nchw_input(self):
    """Linear->Conv model lowered from a contiguous (NCHW) input.

    Uses the class's run_tester helper instead of repeating the
    export/lower/serialize/compare boilerplate inline.
    """
    self.run_tester(self.LinearConv(), (torch.randn(1, 3, 6, 4),))
130
+
131
class ConvLinearConv(torch.nn.Module):
    """Conv2d then Linear test module.

    NOTE(review): this duplicates the ConvLinearConv defined earlier in
    the class; this later definition shadows the earlier, identical one —
    the duplicate should be removed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3)
        self.linear1 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear1(self.conv1(x))
140
+
141
def test_linear_conv_dim_order_swaps_on_nhwc_input(self):
    """Conv->Linear model lowered from a channels_last (NHWC) input.

    Uses the class's run_tester helper instead of repeating the
    export/lower/serialize/compare boilerplate inline.
    """
    self.run_tester(
        self.ConvLinearConv(),
        (torch.randn(1, 3, 6, 6).to(memory_format=torch.channels_last),),
    )
148
+
149
def test_linear_conv_dim_order_swaps_on_nchw_input(self):
    """Conv->Linear model lowered from a contiguous (NCHW) input.

    Uses the class's run_tester helper instead of repeating the
    export/lower/serialize/compare boilerplate inline.
    """
    self.run_tester(self.ConvLinearConv(), (torch.randn(1, 3, 6, 6),))
156
+
157
class Bilinear(torch.nn.Module):
    """2x bilinear upsampling with align_corners=True.

    NOTE(review): this duplicates the Bilinear defined earlier in the
    class; this later definition shadows the earlier, identical one —
    the duplicate should be removed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="bilinear", align_corners=True
        )
165
+
166
def test_nhwc_input_on_nhwc_op(self):
    """channels_last input into an NHWC-native op (bilinear upsample).

    NOTE(review): this re-defines a method of the same name earlier in the
    class, silently shadowing it — one of the two copies should be deleted.
    Also switched to the class's run_tester helper for consistency with the
    other dim-order tests.
    """
    x = (
        torch.arange(8)
        .reshape(1, 2, 2, 2)
        .to(torch.float32)
        .to(memory_format=torch.channels_last)
    )
    self.run_tester(self.Bilinear(), (x,))
178
+
179
def test_nchw_input_on_nhwc_op(self):
    """Contiguous (NCHW) input into an NHWC-native op (bilinear upsample).

    NOTE(review): this re-defines a method of the same name earlier in the
    class, silently shadowing it — one of the two copies should be deleted.
    Also switched to the class's run_tester helper for consistency with the
    other dim-order tests.
    """
    x = torch.arange(8).reshape(1, 2, 2, 2).to(torch.float32)
    self.run_tester(self.Bilinear(), (x,))
186
+
61
187
def test_qs8_channels_last_tagged_reshape_pass (self ):
62
188
for module , num_reshape in self .modules .items ():
63
189
(
0 commit comments