
Commit f19892b

faster block sort

1 parent 18207fa commit f19892b
2 files changed, +312 -1 lines

sort_bitonic_block_v2.ptx

Lines changed: 294 additions & 0 deletions
@@ -0,0 +1,294 @@
.version 7.0
.target sm_50 // enough for my Titan X
.address_size 64

// Like sort_bitonic_block.ptx, but using the unrolled reduction
// from sort_bitonic_warp_v3.ptx.
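
// High-level structure of this kernel:
//   1. Each 32x32 block loads 1024 floats, one per thread.
//   2. Each warp sorts its 32 values in registers with shfl.sync.bfly,
//      with odd warps reversed so that every pair of adjacent 32-float
//      rows forms a 64-float bitonic sequence once stored to shared memory.
//   3. Five merge passes (i = 0..4) then double the sorted run length each
//      time: exchanges between rows go through shared memory, and the
//      remaining in-row exchanges use warp shuffles again.
//   4. The fully sorted 1024 floats are written back to global memory.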

.visible .entry sortBitonicBlockV2 (
    .param .u64 ptr
) {
    .reg .pred %p0;
    .reg .pred %reverse;

    // Arguments
    .reg .u64 %ptr;

    // Cached thread properties
    .reg .u32 %tidX;
    .reg .u32 %tidY;

    // Other variables.
    .reg .u64 %dtmp<2>;
    .reg .u32 %stmp<4>;
    .reg .u32 %i;
    .reg .u32 %j;
    .reg .f32 %val<3>;
    .reg .u32 %rank;
    .reg .u32 %rankAnd1;
    .reg .u32 %rankAnd2;
    .reg .u32 %rankAnd4;
    .reg .u32 %rankAnd8;
    .reg .u32 %rankAnd16;
    .reg .u32 %rankAnd32;

    .shared .align 4 .f32 sortBuffer[1024];

    // Load arguments and thread properties.
    ld.param.u64 %ptr, [ptr];
    mov.u32 %tidX, %tid.x;
    mov.u32 %tidY, %tid.y;

    shl.b32 %rank, %tidY, 5;
    add.u32 %rank, %rank, %tidX;
    and.b32 %rankAnd1, %rank, 1;
    and.b32 %rankAnd2, %rank, 2;
    and.b32 %rankAnd4, %rank, 4;
    and.b32 %rankAnd8, %rank, 8;
    and.b32 %rankAnd16, %rank, 16;
    and.b32 %rankAnd32, %rank, 32;

    cvt.u64.u32 %dtmp0, %ctaid.x;
    shl.b64 %dtmp0, %dtmp0, 10;
    cvt.u64.u32 %dtmp1, %tidY;
    shl.b64 %dtmp1, %dtmp1, 5;
    add.u64 %dtmp0, %dtmp0, %dtmp1; // (ctaid.x*1024 + tid.y*32)
    cvt.u64.u32 %dtmp1, %tidX;
    add.u64 %dtmp0, %dtmp0, %dtmp1;
    shl.b64 %dtmp0, %dtmp0, 2; // 4*(ctaid.x*1024 + tid.y*32 + tid.x)
    add.u64 %ptr, %ptr, %dtmp0;
    ld.global.f32 %val0, [%ptr];

    // Sort within each warp.
    // i=0
    setp.ne.u32 %reverse, %rankAnd2, 0;
    // j=0
    setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;

    // i=1
    setp.ne.u32 %reverse, %rankAnd4, 0;
    // j=1
    setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=0
    setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;

    // i=2
    setp.ne.u32 %reverse, %rankAnd8, 0;
    // j=2
    setp.eq.xor.u32 %p0, %rankAnd4, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=1
    setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=0
    setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;

    // i=3
    setp.ne.u32 %reverse, %rankAnd16, 0;
    // j=3
    setp.eq.xor.u32 %p0, %rankAnd8, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=2
    setp.eq.xor.u32 %p0, %rankAnd4, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=1
    setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=0
    setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;

    // i=4
    setp.ne.u32 %reverse, %rankAnd32, 0;
    // j=4
    setp.eq.xor.u32 %p0, %rankAnd16, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 16, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=3
    setp.eq.xor.u32 %p0, %rankAnd8, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=2
    setp.eq.xor.u32 %p0, %rankAnd4, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=1
    setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=0
    setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;

    // Our index in shared will be stmp0 = 4*(tidX + tidY*32).
    // We start by storing the values so that each 32-float block
    // is sorted (or reversed, to make 64-float bitonic chunks).
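    // Concretely, warp w owns sortBuffer[32*w .. 32*w + 31]; even rows end
    // up ascending and odd rows descending, so each adjacent row pair
    // (2k, 2k+1) forms one 64-float bitonic sequence.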
    shl.b32 %stmp0, %rank, 2;
    mov.u32 %stmp1, sortBuffer;
    add.u32 %stmp1, %stmp1, %stmp0;
    st.shared.f32 [%stmp1], %val0;

    // Sort from stride 32 to stride 512, by striding
    // tid.y and doing the resulting indexing logic.
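    // Pass i merges bitonic sequences of 64<<i floats into sorted runs of
    // the same length, alternating run direction so that the next pass again
    // sees bitonic input; after the i=4 pass the whole 1024-float block is
    // ascending.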
    mov.u32 %i, 0;
block_loop_start:

    // We merge pairs from shared memory, so we only use half
    // of the block for this inner loop.
    // We still need to make sure all ranks have finished
    // writing to shared memory first.
    bar.sync 0;
    setp.gt.u32 %p0, %tidY, 15;
    @%p0 bra inner_block_loop_end;

    // Merge across warps, avoiding bank conflicts by reading runs of
    // at least 32 consecutive floats.
    // This logic reads and writes two values per thread in each warp.
    mov.u32 %j, %i;
inner_block_loop_start:
    // We will store a "virtual" tid.y in stmp2 by effectively
    // moving bit %j to the 16 position (most significant bit).
    //
    // If tid.y & (1<<j) == 0, then we actually read the warp rows
    // for tid.y and tid.y ^ (1<<j).
    // Otherwise, we are doing the second half and we start at
    // (tid.y^(1<<j))+16 and also do (tid.y + 16).
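    // Example for j=1: tid.y=5 has bit 1 clear, so stmp2 stays 5 and this
    // warp merges rows 5 and 7; tid.y=6 has bit 1 set, so stmp2 becomes
    // (6^2)|16 = 20 and it merges rows 20 and 22. Across tid.y=0..15 every
    // row is touched exactly once per iteration.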
    shl.b32 %stmp0, 1, %j;
    and.b32 %stmp1, %tidY, %stmp0;
    xor.b32 %stmp2, %tidY, %stmp1;
    setp.ne.u32 %p0, %stmp1, 0;
    @%p0 or.b32 %stmp2, %stmp2, 16;

    // Decide if we are in the second or first half of the next
    // stage's bitonic input.
    // We do this inside the loop because, when i=3, the relevant bit
    // (2<<i = 16) is exactly the extra bit we added to stmp2 above.
    shl.b32 %stmp3, 2, %i;
    and.b32 %stmp3, %stmp2, %stmp3;
    setp.ne.u32 %reverse, %stmp3, 0;

    // Compute two addresses for shared memory and store them into
    // stmp0 and stmp3.
    shl.b32 %stmp1, %stmp2, 5;
    add.u32 %stmp1, %stmp1, %tidX;
    shl.b32 %stmp1, %stmp1, 2;
    mov.u32 %stmp3, sortBuffer;
    add.u32 %stmp0, %stmp3, %stmp1;
    // XOR the effective tid.y with (1 << j) to get the second address;
    // note that we overwrite %stmp2 here.
    shl.b32 %stmp2, 128, %j;
    xor.b32 %stmp1, %stmp1, %stmp2;
    add.u32 %stmp3, %stmp3, %stmp1;

    bar.sync 1, 512; // only half the block is participating

    ld.shared.f32 %val0, [%stmp0];
    ld.shared.f32 %val1, [%stmp3];

    // Swap based on comparison, possibly reversing.
    setp.gt.xor.f32 %p0, %val0, %val1, %reverse;
    @%p0 mov.f32 %val2, %val0;
    @%p0 mov.f32 %val0, %val1;
    @%p0 mov.f32 %val1, %val2;

    st.shared.f32 [%stmp0], %val0;
    st.shared.f32 [%stmp3], %val1;

    setp.ne.u32 %p0, %j, 0;
    sub.u32 %j, %j, 1;
    @%p0 bra inner_block_loop_start;
inner_block_loop_end:

    // We are back to working from all of the block, not just half.
    bar.sync 0;

    // We now must merge within each 32-float block, which we do per-warp
    // to once again avoid bank conflicts.
    // This looks very similar to the per-warp sorting logic we started with.
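    // The shuffles below cover distances 16 down to 1; combined with the
    // shared-memory exchanges above (float distances 32<<i down to 32),
    // this completes one full bitonic merge step for the current pass.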
    shl.b32 %stmp0, %tidY, 5;
    add.u32 %stmp0, %stmp0, %tidX;
    shl.b32 %stmp0, %stmp0, 2;
    mov.u32 %stmp1, sortBuffer;
    add.u32 %stmp3, %stmp0, %stmp1;
    ld.shared.f32 %val0, [%stmp3];

    // We are in the second half based on the full tid.y.
    shl.b32 %stmp0, 2, %i;
    and.b32 %stmp0, %tidY, %stmp0;
    setp.ne.u32 %reverse, %stmp0, 0;

    // j=4
    setp.eq.xor.u32 %p0, %rankAnd16, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 16, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=3
    setp.eq.xor.u32 %p0, %rankAnd8, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 8, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=2
    setp.eq.xor.u32 %p0, %rankAnd4, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 4, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=1
    setp.eq.xor.u32 %p0, %rankAnd2, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 2, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;
    // j=0
    setp.eq.xor.u32 %p0, %rankAnd1, 0, %reverse;
    shfl.sync.bfly.b32 %val1, %val0, 1, 0x1f, 0xffffffff;
    setp.lt.xor.f32 %p0, %val0, %val1, %p0;
    selp.f32 %val0, %val1, %val0, %p0;

    st.shared.f32 [%stmp3], %val0;

    add.u32 %i, %i, 1;
    setp.lt.u32 %p0, %i, 5;
    @%p0 bra block_loop_start;
block_loop_end:

    // Store values back from shared memory.
    bar.sync 0;
    shl.b32 %stmp0, %tidY, 5;
    add.u32 %stmp0, %stmp0, %tidX;
    shl.b32 %stmp0, %stmp0, 2;
    mov.u32 %stmp1, sortBuffer;
    add.u32 %stmp0, %stmp0, %stmp1;
    ld.shared.f32 %val0, [%stmp0];
    st.global.f32 [%ptr], %val0;

    ret;
}
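
For reference, the compare/exchange pattern that the kernel unrolls is a standard
bitonic sorting network. The sketch below (a hypothetical helper, not part of this
commit) models it on the CPU in plain Python: exchanges at distances below 32
correspond to the shfl.sync.bfly steps within a warp, and exchanges at distances
of 32 and above correspond to the shared-memory steps between warps.

    def bitonic_sort(values):
        """CPU sketch of a power-of-two bitonic sort, ascending order."""
        x = list(values)
        n = len(x)
        k = 2  # length of the sorted runs being built this pass
        while k <= n:
            j = k // 2  # compare/exchange distance, halved each step
            while j >= 1:
                for i in range(n):
                    partner = i ^ j
                    if partner > i:
                        ascending = (i & k) == 0
                        if (x[i] > x[partner]) == ascending:
                            x[i], x[partner] = x[partner], x[i]
                j //= 2
            k *= 2
        return x

For a random 1024-element list, bitonic_sort(row) should agree with sorted(row),
mirroring the np.sort check in learn_ptx/sort.py below.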

learn_ptx/sort.py

Lines changed: 18 additions & 1 deletion
@@ -77,5 +77,22 @@ def sort_bitonic_block():
     assert np.allclose(results, expected), f"\n{results=}\n{expected=}"
 
 
+def sort_bitonic_block_v2():
+    fn = compile_function("sort_bitonic_block_v2.ptx", "sortBitonicBlockV2")
+    inputs = np.random.normal(size=[16384, 1024]).astype(np.float32)
+    input_buf = numpy_to_gpu(inputs)
+    with measure_time() as timer:
+        fn(
+            input_buf,
+            grid=(inputs.shape[0], 1, 1),
+            block=(32, 32, 1),
+        )
+        sync()
+    results = gpu_to_numpy(input_buf, inputs.shape, inputs.dtype)
+    expected = np.sort(inputs, axis=-1)
+    print(f"took {timer()} seconds")
+    assert np.allclose(results, expected), f"\n{results=}\n{expected=}"
+
+
 if __name__ == "__main__":
-    sort_bitonic_warp_v3()
+    sort_bitonic_block_v2()
