@@ -528,12 +528,39 @@ struct State {
528
528
}
529
529
}
530
530
}
531
+
531
532
// Outputs must be vectorized over their innermost
532
533
// dimension, because we don't have control of the
533
- // storage. TODO: Go inspect to see which dimension has a
534
- // stride==1 constraint instead of assuming 0.
534
+ // storage. Infer which dimension(s) is(are) the innermost one(s) by
535
+ // looking at the stride. Note that there can be more than one in
536
+ // case some dimensions have an extent of 1.
537
+ if (node->is_output && !node->func .output_buffers ().empty ()) {
538
+ const Parameter &output = node->func .output_buffers ()[0 ];
539
+ int num_dims = output.dimensions ();
540
+ for (int i = 0 ; i < num_dims; ++i) {
541
+ const Expr stride = output.stride_constraint (i);
542
+ const int64_t *s = as_const_int (stride);
543
+ if (s && *s == 1 ) {
544
+ vector_dims.push_back (i);
545
+ }
546
+ }
547
+ }
548
+
535
549
if (vector_dims.empty ()) {
536
- vector_dims.push_back (0 );
550
+ // This can happen if the output strides aren't known, or if all
551
+ // the dimensions are smaller than the vector size.
552
+ // TBD: consider extending compute_in_tiles to support -1 as a
553
+ // vector dim to indicate no vectorization.
554
+ for (int v = 0 ; v < node->dimensions ; v++) {
555
+ vector_dims.push_back (v);
556
+ }
557
+ // Handle the case of full reductions that generate a scalar.
558
+ // We need at least one vector dimension to call compute_in_tiles
559
+ // below.
560
+ // TBD: figure out a better fallback strategy.
561
+ if (vector_dims.empty ()) {
562
+ vector_dims.push_back (0 );
563
+ }
537
564
}
538
565
539
566
// 2) Realize it somewhere
0 commit comments