Skip to content

[codegen][opencl] issue with codegen on complex control-flow in opencl #648

Open
@mikepapadim

Description

@mikepapadim

Describe the bug

The OpenCL code generator produces invalid code for kernels that combine complex control flow with the KernelContext API, as in the following method:

    /**
     * Minimal reproducer for issue #648. Computes one attention score per timestep
     * t in [0, position] for head h = context.groupIdx, where position is read from
     * positionNlayer.get(0). The threads of a work group stride over t by the
     * work-group size (context.localGroupSizeX).
     *
     * @param context            supplies groupIdx (head index h), localIdx (first t for
     *                           this thread) and localGroupSizeX (stride over t)
     * @param positionNlayer     element 0 holds the last timestep to score (inclusive)
     * @param seqLen             distance between consecutive heads' rows in attScores
     * @param query              read at h * headSize + i
     * @param keyCache           read at loff + t * kvDim + (h / kvMul) * headSize + i
     * @param attScores          written at h * seqLen + t
     * @param kvDim              stride between timesteps in keyCache
     * @param kvMul              number of query heads sharing one key head (h / kvMul)
     * @param headSize           per-head vector length; the score is divided by sqrt(headSize)
     * @param loff               base offset into keyCache
     * @param localWorkgroupSize not referenced by the body
     *
     * NOTE(review): the dot-product loop bound is hard-coded to 8192 instead of
     * headSize. With the headSize implied by the generated code (32, since the kernel
     * divides by 5.656854 = sqrt(32)), this reads far past each head's slice of
     * query/keyCache. Presumably this is a simplification for the reproducer; confirm
     * before reusing this kernel.
     */
    public static void calculateAttentionScores(KernelContext context, IntArray positionNlayer, int seqLen, FloatArray query, FloatArray keyCache, FloatArray attScores, int kvDim, int kvMul,
            int headSize, int loff, int localWorkgroupSize) {
        int h = context.groupIdx;         // Head index
        int threadId = context.localIdx;  // Thread ID within work group
        int blockDim = context.localGroupSizeX;  // Work group size

        // Get the query vector offset for this head
        int queryOffset = h * headSize;

        // Attention scores offset for this head
        int attOffset = h * seqLen;
        // position is read from the positionNlayer buffer at run time, not a
        // compile-time constant, so the trip count of the t-loop below is data-dependent.
        int position = positionNlayer.get(0);                            <------------------------- This is the root cause

        // Each thread handles timesteps threadId, threadId + blockDim, ... while t <= position.
        for (int t = threadId; t <= position; t += blockDim) {  <------------------------- This is the root cause
            // Get the key vector for this head and at this timestep
            int keyOffset = loff + t * kvDim + (h / kvMul) * headSize;

            // Calculate the attention score as the dot product of query and key
            // NOTE(review): bound is 8192, not headSize (see the method comment).
            float score = 0.0f;
            for (int i = 0; i < 8192; i++) {
                score += query.get(queryOffset + i) * keyCache.get(keyOffset + i);
            }

            // Scale by sqrt(head_size)
            score /= TornadoMath.sqrt(headSize);

            // Save the score to the attention buffer
            attScores.set(attOffset + t, score);
        }
    }

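This produces the following OpenCL code. The `t` loop is emitted incorrectly. The header `for(;i_5 >= i_13;)` (i.e. `t <= position`) wraps BLOCK 6, which contains only `return;`. The actual loop body (BLOCKs 2–5) is emitted after the loop instead. As a result, the brace intended to close the loop (`}  // B5`) closes the kernel function, and the final `}  //  kernel` is left unmatched: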

#pragma OPENCL EXTENSION cl_khr_fp64 : enable  
#pragma OPENCL EXTENSION cl_khr_fp16 : enable  
#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable  
__kernel void calculateAttentionScores(__global long *_kernel_context, __constant uchar *_constant_region, __local uchar *_local_region, __global int *_atomics, __global uchar *positionNlayer, __private int seqLen, __global uchar *query, __global uchar *keyCache, __global uchar *attScores, __private int kvDim, __private int kvMul, __private int headSize, __private int loff, __private int localWorkgroupSize)
{  // NOTE(review): kernel as emitted; NOTE(review) comments are trailing annotations only, so line numbers still match the build log
  long l_25, l_24, l_31, l_32, l_19, l_20; 
  int i_30, i_5, i_35, i_23, i_29, i_17, i_14, i_15, i_18, i_8, i_9, i_6, i_7, i_12, i_13, i_10, i_11; 
  float f_22, f_27, f_28, f_16, f_34; 
  ulong ul_26, ul_21, ul_4, ul_3, ul_2, ul_1, ul_33, ul_0; 

  // BLOCK 0
  ul_0  =  (ulong) positionNlayer;
  ul_1  =  (ulong) query;
  ul_2  =  (ulong) keyCache;
  ul_3  =  (ulong) attScores;
  ul_4  =  ul_0 + 24L;  // NOTE(review): element 0 sits 24 bytes in, presumably past an array header; TODO confirm
  i_5  =  *((__global int *) ul_4);  // appears to be position = positionNlayer.get(0)
  i_6  =  get_local_size(0);  // blockDim
  i_7  =  get_group_id(0);  // h
  i_8  =  i_7 << 4;
  i_9  =  i_8 + 6;  // h * 16 + 6: attOffset plus a 6-int (24-byte) offset; seqLen appears folded to 16
  i_10  =  i_7 << 5;
  i_11  =  i_10 + 6;  // h * 32 + 6: queryOffset plus a 6-int offset; headSize appears folded to 32
  i_12  =  get_local_id(0);  // threadId
  // BLOCK 1 MERGES [0 5 ] -- t-loop header
  i_13  =  i_12;  // t = threadId
  for(;i_5 >= i_13;)  // NOTE(review): condition t <= position is right, but the body below is the exit block
  {
    // BLOCK 6
    return;  // NOTE(review): BUG - loop exit emitted as the loop body
  }  // B6 -- NOTE(review): loop closes here, before the real body
  // BLOCK 2 -- NOTE(review): real t-loop body (BLOCKs 2-5) starts here, outside the for
  i_14  =  i_13 << 5;
  i_15  =  i_14 + i_11;  // keyOffset plus 6-int offset (loff/kvDim/kvMul presumably constant-folded)
  // BLOCK 3 MERGES [2 4 ]
  f_16  =  0.0F;
  i_17  =  0;
  for(;i_17 < 8192;)  // inner dot-product loop; this one is emitted with its body inside
  {
    // BLOCK 4
    i_18  =  i_17 + i_11;
    l_19  =  (long) i_18;
    l_20  =  l_19 << 2;
    ul_21  =  ul_1 + l_20;
    f_22  =  *((__global float *) ul_21);
    i_23  =  i_15 + i_17;
    l_24  =  (long) i_23;
    l_25  =  l_24 << 2;
    ul_26  =  ul_2 + l_25;
    f_27  =  *((__global float *) ul_26);
    f_28  =  fma(f_22, f_27, f_16);
    i_29  =  i_17 + 1;
    f_16  =  f_28;
    i_17  =  i_29;
  }  // B4
  // BLOCK 5
  i_30  =  i_9 + i_13;
  l_31  =  (long) i_30;
  l_32  =  l_31 << 2;
  ul_33  =  ul_3 + l_32;
  f_34  =  f_16 / 5.656854F;  // 5.656854 = sqrt(32): TornadoMath.sqrt(headSize) constant-folded
  *((__global float *) ul_33)  =  f_34;
  i_35  =  i_6 + i_13;  // t += blockDim
  i_13  =  i_35;
}  // B5 -- NOTE(review): meant to close the t-loop; actually closes the kernel body
}  //  kernel -- NOTE(review): line 66 counting from the first #pragma; the extraneous '}' clBuildProgram reports

[TornadoVM-OCL-JNI] ERROR : clBuildProgram -> Returned: -11
lculateAttentionScores.log
<kernel>:66:1: error: extraneous closing brace ('}')
}  //  kernel
^

How To Reproduce

Check out the following branch, then build TornadoVM and run the test:
https://github.com/mikepapadim/TornadoVM/tree/bug/code_gen

make jdk21 
tornado-test --debug --printKernel  -V uk.ac.manchester.tornado.unittests.llm.TestCodeGenBug

Expected behavior

Expected: the kernel compiles and runs. Actual: the OpenCL build fails with the following error:

[TornadoVM-OCL-JNI] ERROR : clBuildProgram -> Returned: -11
lculateAttentionScores.log
<kernel>:66:1: error: extraneous closing brace ('}')
}  //  kernel
^

Computing system setup (please complete the following information):

  • OpenCL and Driver versions

Metadata

Metadata

Labels

bugSomething isn't working

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions