
Commit b59f5cc

[ET-VK] New implementation of cat operator (#11623)
## Changes

* Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator
* Introduce `Concat.cpp` to replace `Cat.cpp`
* Fix a bug with channels-packed buffer tensors where input data would be copied incorrectly when multiple dims have a stride of 1

## Motivation

> * Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator
> * Introduce `Concat.cpp` to replace `Cat.cpp`

The existing implementation of `torch.cat` uses the `copy_channel_offset` shaders. However, these shaders have a critical bug: the output tensor is passed in separately with different access types, i.e.

```cpp
graph.execute_nodes().emplace_back(new DispatchNode(
    graph,
    VK_KERNEL_FROM_STR(kernel_name),
    global_size,
    local_size,
    // Inputs and Outputs
    {
        {out, vkapi::kWrite},
        {out, vkapi::kRead},
        {in, vkapi::kRead},
    },
```

This creates many validation layer errors because the memory barriers for the resource cannot be formed properly; the shader essentially relies on undefined behaviour to work correctly. As a result, the `cat` operator produces incorrect results on many platforms.

Rather than fix the `copy_offset` shaders, I decided to introduce new shaders to perform the concat operation. The new implementation handles both buffer and texture inputs and is agnostic to memory layout.

Differential Revision: [D76305343](https://our.internmc.facebook.com/intern/diff/D76305343/)
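By contrast, binding each resource exactly once, with a single access type, lets the graph form correct barriers. A minimal sketch, assuming the same `DispatchNode` API as the excerpt above; the argument names here are hypothetical, and the actual binding code lives in the new `Concat.cpp`, which is not rendered on this page:

```cpp
graph.execute_nodes().emplace_back(new DispatchNode(
    graph,
    VK_KERNEL_FROM_STR(kernel_name),
    global_size,
    local_size,
    // Inputs and Outputs: the output is bound once, write-only, so a correct
    // memory barrier can be inserted; every input is read-only.
    {
        {out, vkapi::kWrite},
        {in_value_refs, vkapi::kRead},  // hypothetical list of all cat inputs
    },
    // ... remaining DispatchNode arguments, elided as in the excerpt above
```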
1 parent e62a4ef commit b59f5cc

File tree

10 files changed: +448 −106 lines


backends/vulkan/op_registry.py

Lines changed: 42 additions & 3 deletions
```diff
@@ -538,8 +538,6 @@ def register_rotary_emb_op(features: OpFeatures):
         exir_ops.edge.aten.clone.default,
         exir_ops.edge.aten.permute.default,
         exir_ops.edge.aten.permute_copy.default,
-        exir_ops.edge.aten.select_copy.int,
-        exir_ops.edge.aten.slice_copy.Tensor,
         exir_ops.edge.aten.view_copy.default,
     ]
 )
@@ -551,6 +549,48 @@ def register_view_ops(features: OpFeatures):
     return features
 
 
+# Fully featured transfer operators (i.e. operators that copy data from the input
+# tensor(s) to the output tensor(s)), which have memory layout agnostic implementations
+# for both texture and buffer storage types.
+@update_features(exir_ops.edge.aten.cat.default)
+def register_cat_op(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims=all_packed_dims,
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+
+    def check_cat_node(node: torch.fx.Node) -> bool:
+        inputs = node.args[0]
+        if isinstance(inputs, (list, tuple)) and len(inputs) <= 3:
+            return True
+
+        return False
+
+    features.check_node_fn = check_cat_node
+
+    return features
+
+
+# Fully featured transfer operators (i.e. operators that copy data from the input
+# tensor(s) to the output tensor(s)), which have memory layout agnostic implementations
+# for both texture and buffer storage types.
+@update_features(
+    [
+        exir_ops.edge.aten.select_copy.int,
+        exir_ops.edge.aten.slice_copy.Tensor,
+    ]
+)
+def register_transfer_ops(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims=all_packed_dims,
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+
+    return features
+
+
 # Ops ported from PyTorch Vulkan backend. These ops commonly support channels
 # packed tensors only and do not have a resize function.
 @update_features(
@@ -588,7 +628,6 @@ def register_ported_op(features: OpFeatures):
         exir_ops.edge.aten.squeeze_copy.dims,
         exir_ops.edge.aten.unsqueeze_copy.default,
         # Tensor combination
-        exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.repeat.default,
         exir_ops.edge.aten.split_with_sizes_copy.default,
         exir_ops.edge.aten.split.Tensor,
```
backends/vulkan/runtime/graph/ops/glsl/concat_buffer.glsl

Lines changed: 69 additions & 0 deletions
This file was added.
```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}
#define T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type("buffer")}
${define_required_extensions(DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}

$for i in range(NUM_INPUTS):
  ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "buffer")}

${layout_declare_ubo(B, "int", "concat_dim")}

${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}

$for i in range(NUM_INPUTS):
  ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_sizes")}
  ${layout_declare_ubo(B, "ivec4", "in" + str(i+1) + "_strides")}

${layout_declare_ubo(B, "int", "out_numel")}

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}

const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
  const int out_bufi = ivec3(gl_GlobalInvocationID).x;
  if (out_bufi >= out_numel) {
    return;
  }

  // Convert buffer linear index to 4-D tensor index for output
  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order);

  // Determine which input tensor to read from
  ivec4 in_tidx = out_tidx;

  $for i in range(NUM_INPUTS):
    // Check if the index at the concat dim is within bounds of the input tensor
    // If so, read from that input tensor and write to output
    if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
      int in_bufi = tidx_to_bufi(in_tidx, in${i+1}_strides);
      t_out[out_bufi] = t_in${i+1}[in_bufi];
      return;
    }
    // otherwise, decrement the index at the concat dim
    else {
      in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
    }
}
```
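To make the shader's input-selection loop concrete, here is a small standalone C++ sketch of the same algorithm run on the host; the struct and function names are hypothetical, and the shader's actual index conversions come from `indexing_utils.h`:

```cpp
#include <array>
#include <vector>

// For a given output tensor index, walk the inputs in order: the first input
// whose extent along concat_dim still covers the index is the one to read
// from; otherwise subtract that extent and try the next input. This mirrors
// the generated $for loop in concat_buffer.glsl.
struct ConcatSrc {
  int input_idx;            // which input tensor supplies this element
  std::array<int, 4> tidx;  // tensor index within that input
};

ConcatSrc select_input(
    std::array<int, 4> out_tidx,
    int concat_dim,
    const std::vector<std::array<int, 4>>& in_sizes) {
  std::array<int, 4> in_tidx = out_tidx;
  for (int i = 0; i < static_cast<int>(in_sizes.size()); ++i) {
    if (in_tidx[concat_dim] < in_sizes[i][concat_dim]) {
      return {i, in_tidx};
    }
    in_tidx[concat_dim] -= in_sizes[i][concat_dim];
  }
  return {-1, in_tidx};  // unreachable for in-bounds output indices
}
```

Because each invocation writes exactly one output element and only ever reads the inputs, the output never needs to be bound with a read access type, which is what resolves the barrier problem described in the commit message.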
backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml

Lines changed: 14 additions & 0 deletions
This file was added.
```yaml
concat_buffer:
  parameter_names_with_default_values:
    DTYPE: float
    NUM_INPUTS: 2
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: concat_1_buffer
      NUM_INPUTS: 1
    - NAME: concat_2_buffer
    - NAME: concat_3_buffer
      NUM_INPUTS: 3
```
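Each entry above generates one shader variant per supported input count (one, two, or three inputs), which lines up with the `check_cat_node` limit of at most three `cat` inputs registered in `op_registry.py`. A hedged sketch of how a runtime might compose the variant name; the helper is hypothetical, and the real selection logic lives in the new `Concat.cpp`:

```cpp
#include <string>

// Hypothetical helper: build the generated variant name from the number of
// concat inputs (1..3) and the storage type of the output tensor, matching
// the NAME entries in concat_buffer.yaml / concat_texture.yaml.
std::string concat_kernel_name(size_t num_inputs, bool is_buffer_storage) {
  std::string name = "concat_" + std::to_string(num_inputs);
  name += is_buffer_storage ? "_buffer" : "_texture3d";
  return name;  // e.g. "concat_2_buffer" or "concat_3_texture3d"
}
```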
backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl

Lines changed: 129 additions & 0 deletions
This file was added.
```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}
#define T ${buffer_scalar_type(DTYPE)}

#define USING_TEXTURE3D

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}

$for i in range(NUM_INPUTS):
  ${layout_declare_tensor(B, "r", "t_in" + str(i + 1), DTYPE, "texture3d")}

${layout_declare_ubo(B, "int", "concat_dim")}

$in_metadata = ""
$for i in range(NUM_INPUTS):
  $in_metadata += "ivec4 in" + str(i + 1) + "_sizes;\n"

layout(push_constant) uniform restrict Block {
  ivec4 out_sizes;
  ${in_metadata}
};

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);

$for i in range(NUM_INPUTS):
  ${layout_declare_spec_const(C, "int", "in" + str(i+1) + "_layout", "DEFAULT_LAYOUT")}
  const lowp ivec4 in${i+1}_axis_map = unhash_axis_map(in${i+1}_layout);
  const lowp int in${i+1}_packed_dim = unhash_packed_dim(in${i+1}_layout);

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Check if we can use the fast path (no texel merging required)
bool can_use_fast_path() {
  // Fast path is possible when:
  // 1. The concat dimension is not the packed dimension, or
  // 2. The concat dimension is the packed dimension but all input tensors have
  //    sizes that are multiples of 4 along the packed dimension
  if (concat_dim != out_packed_dim) {
    return true;
  }

  // Check if all input tensors have sizes that are multiples of 4 along the
  // packed dimension
  bool all_concat_dim_size_multiple_of_4 = true;
  $for i in range(NUM_INPUTS):
    all_concat_dim_size_multiple_of_4 =
        all_concat_dim_size_multiple_of_4 &&
        (in${i+1}_sizes[concat_dim] % 4 == 0);

  return all_concat_dim_size_multiple_of_4;
}

void main() {
  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  ivec4 out_tidx = lpos_to_tidx(lpos, out_sizes, out_axis_map.w, out_packed_dim);

  if (any(greaterThanEqual(out_tidx, out_sizes))) {
    return;
  }

  if (can_use_fast_path()) {
    // Fast path: No texel merging required
    ivec4 in_tidx = out_tidx;

    $for i in range(NUM_INPUTS):
      // For each input tensor, check if the tensor index is within bounds. If
      // so, read the texel from the input tensor and write it to the output
      if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
        const ivec3 in_pos = tidx_to_pos(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
        const VEC4_T in_texel = load_texel(t_in${i+1}, in_pos);
        write_texel_lpos(t_out, lpos, in_texel, out_axis_map);
        return;
      }
      // Otherwise, adjust the index along the concat dimension and try the next
      // input tensor.
      else {
        in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
      }
  }
  else {
    // Slow path: Texel merging required
    VEC4_T out_texel = VEC4_T(0);

    // Process each element in the output texel individually
    for (int texel_i = 0; texel_i < 4; ++texel_i) {
      ivec4 curr_out_tidx = out_tidx;
      curr_out_tidx[out_packed_dim] += texel_i;

      // Skip if we're out of bounds
      if (curr_out_tidx[out_packed_dim] >= out_sizes[out_packed_dim]) {
        continue;
      }

      ivec4 in_tidx = curr_out_tidx;
      $for i in range(NUM_INPUTS):
        // For each input tensor, check if the tensor index is within bounds. If
        // so, read the corresponding texel element from the input tensor and
        // write it to the output texel.
        if (in_tidx[concat_dim] < in${i+1}_sizes[concat_dim]) {
          const ivec4 in_posi = tidx_to_posi(in_tidx, in${i+1}_sizes, in${i+1}_axis_map, in${i+1}_packed_dim);
          out_texel[texel_i] = load_texel(t_in${i+1}, in_posi.xyz)[in_posi.w];
          continue;
        }
        // Otherwise, adjust the index along the concat dimension and try the
        // next input tensor.
        else {
          in_tidx[concat_dim] -= in${i+1}_sizes[concat_dim];
        }
    }

    write_texel_lpos(t_out, lpos, out_texel, out_axis_map);
  }
}
```
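As a concrete restatement of the fast-path condition above, here is a standalone host-side C++ sketch; the function mirrors `can_use_fast_path()` but takes its inputs as plain parameters rather than shader uniforms:

```cpp
#include <vector>

// The fast path applies when the concat dim is not the packed dim, or when
// every input's extent along the concat dim is a multiple of 4, so whole
// input texels map onto whole output texels with no merging.
bool can_use_fast_path(
    int concat_dim,
    int out_packed_dim,
    const std::vector<int>& in_concat_dim_sizes) {
  if (concat_dim != out_packed_dim) {
    return true;
  }
  for (int size : in_concat_dim_sizes) {
    if (size % 4 != 0) {
      return false;
    }
  }
  return true;
}
```

For example, concatenating extents {8, 4} along the packed dim keeps every input texel aligned with an output texel (fast path), while {8, 3} makes the second input's elements straddle output texel boundaries, forcing the slow path to assemble each output texel element by element.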
backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml

Lines changed: 14 additions & 0 deletions
This file was added.
```yaml
concat_texture:
  parameter_names_with_default_values:
    DTYPE: float
    NUM_INPUTS: 2
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: concat_1_texture3d
      NUM_INPUTS: 1
    - NAME: concat_2_texture3d
    - NAME: concat_3_texture3d
      NUM_INPUTS: 3
```

backends/vulkan/runtime/graph/ops/impl/Cat.cpp

Lines changed: 0 additions & 98 deletions
This file was deleted.
