Open
Description
Describe the bug
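There is a bug in the OpenCL code generation for kernels that combine the KernelContext API with complex control flow. The method below, whose strided loop bound is read from an array at runtime, is compiled into a kernel with unbalanced braces: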
/**
 * Reproducer kernel: computes scaled dot-product attention scores for one head per work-group.
 *
 * The head index is taken from {@code context.groupIdx}. Each thread starts at timestep
 * {@code context.localIdx} and advances by {@code context.localGroupSizeX} while the timestep is
 * less than or equal to the position read from {@code positionNlayer.get(0)}. For every timestep it
 * accumulates {@code query[queryOffset + i] * keyCache[keyOffset + i]}, divides the sum by
 * {@code sqrt(headSize)} and stores the result at {@code attScores[h * seqLen + t]}.
 *
 * NOTE(review): the inner loop runs a fixed 8192 iterations rather than {@code headSize} --
 * presumably deliberate for this reproducer; confirm before reusing the code elsewhere.
 * NOTE(review): {@code localWorkgroupSize} is not read anywhere in this body.
 *
 * @param context            kernel context supplying the group index, local index and local group size
 * @param positionNlayer     array whose element 0 is used as the inclusive upper bound of the timestep loop
 * @param seqLen             per-head stride, in elements, of the attScores buffer
 * @param query              query values, read starting at {@code h * headSize}
 * @param keyCache           key values, read starting at {@code loff + t * kvDim + (h / kvMul) * headSize}
 * @param attScores          output buffer receiving one score per (head, timestep)
 * @param kvDim              per-timestep stride, in elements, of the key cache
 * @param kvMul              divisor applied to the head index when selecting the key slice
 * @param headSize           per-head stride of the query and the value whose square root scales the score
 * @param loff               base offset added to every key-cache index
 * @param localWorkgroupSize unused in this method
 */
public static void calculateAttentionScores(KernelContext context, IntArray positionNlayer, int seqLen, FloatArray query, FloatArray keyCache, FloatArray attScores, int kvDim, int kvMul,
int headSize, int loff, int localWorkgroupSize) {
int h = context.groupIdx; // Head index
int threadId = context.localIdx; // Thread ID within work group
int blockDim = context.localGroupSizeX; // Work group size
// Get the query vector offset for this head
int queryOffset = h * headSize;
// Attention scores offset for this head
int attOffset = h * seqLen;
int position = positionNlayer.get(0); <------------------------- This is the root cause
for (int t = threadId; t <= position; t += blockDim) { <------------------------- This is the root cause
// Get the key vector for this head and at this timestep
int keyOffset = loff + t * kvDim + (h / kvMul) * headSize;
// Calculate the attention score as the dot product of query and key
float score = 0.0f;
for (int i = 0; i < 8192; i++) {
score += query.get(queryOffset + i) * keyCache.get(keyOffset + i);
}
// Scale by sqrt(head_size)
score /= TornadoMath.sqrt(headSize);
// Save the score to the attention buffer
attScores.set(attOffset + t, score);
}
}
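This produces the following OpenCL kernel. Note that the outer loop (BLOCK 1) is closed immediately after a `return` (BLOCK 6), while its actual body (BLOCKs 2-5) is emitted after the loop; the `} // B5` brace therefore closes the kernel function itself and the final `} // kernel` is left unmatched: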
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
__kernel void calculateAttentionScores(__global long *_kernel_context, __constant uchar *_constant_region, __local uchar *_local_region, __global int *_atomics, __global uchar *positionNlayer, __private int seqLen, __global uchar *query, __global uchar *keyCache, __global uchar *attScores, __private int kvDim, __private int kvMul, __private int headSize, __private int loff, __private int localWorkgroupSize)
{
long l_25, l_24, l_31, l_32, l_19, l_20;
int i_30, i_5, i_35, i_23, i_29, i_17, i_14, i_15, i_18, i_8, i_9, i_6, i_7, i_12, i_13, i_10, i_11;
float f_22, f_27, f_28, f_16, f_34;
ulong ul_26, ul_21, ul_4, ul_3, ul_2, ul_1, ul_33, ul_0;
// BLOCK 0
ul_0 = (ulong) positionNlayer;
ul_1 = (ulong) query;
ul_2 = (ulong) keyCache;
ul_3 = (ulong) attScores;
ul_4 = ul_0 + 24L;
i_5 = *((__global int *) ul_4);
i_6 = get_local_size(0);
i_7 = get_group_id(0);
i_8 = i_7 << 4;
i_9 = i_8 + 6;
i_10 = i_7 << 5;
i_11 = i_10 + 6;
i_12 = get_local_id(0);
// BLOCK 1 MERGES [0 5 ]
i_13 = i_12;
for(;i_5 >= i_13;)
{
// BLOCK 6
return;
} // B6
// BLOCK 2
i_14 = i_13 << 5;
i_15 = i_14 + i_11;
// BLOCK 3 MERGES [2 4 ]
f_16 = 0.0F;
i_17 = 0;
for(;i_17 < 8192;)
{
// BLOCK 4
i_18 = i_17 + i_11;
l_19 = (long) i_18;
l_20 = l_19 << 2;
ul_21 = ul_1 + l_20;
f_22 = *((__global float *) ul_21);
i_23 = i_15 + i_17;
l_24 = (long) i_23;
l_25 = l_24 << 2;
ul_26 = ul_2 + l_25;
f_27 = *((__global float *) ul_26);
f_28 = fma(f_22, f_27, f_16);
i_29 = i_17 + 1;
f_16 = f_28;
i_17 = i_29;
} // B4
// BLOCK 5
i_30 = i_9 + i_13;
l_31 = (long) i_30;
l_32 = l_31 << 2;
ul_33 = ul_3 + l_32;
f_34 = f_16 / 5.656854F;
*((__global float *) ul_33) = f_34;
i_35 = i_6 + i_13;
i_13 = i_35;
} // B5
} // kernel
[TornadoVM-OCL-JNI] ERROR : clBuildProgram -> Returned: -11
lculateAttentionScores.log
<kernel>:66:1: error: extraneous closing brace ('}')
} // kernel
^
How To Reproduce
From the following branch:
https://github.com/mikepapadim/TornadoVM/tree/bug/code_gen
make jdk21
tornado-test --debug --printKernel -V uk.ac.manchester.tornado.unittests.llm.TestCodeGenBug
Expected behavior
The generated kernel should compile and run successfully. Instead, the build currently fails with an OpenCL build error:
[TornadoVM-OCL-JNI] ERROR : clBuildProgram -> Returned: -11
lculateAttentionScores.log
<kernel>:66:1: error: extraneous closing brace ('}')
} // kernel
^
Computing system setup (please complete the following information):
- OpenCL and Driver versions