]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: support ne3 in get_rows (#15866)
authorlhez <redacted>
Tue, 30 Sep 2025 16:55:13 +0000 (09:55 -0700)
committerGitHub <redacted>
Tue, 30 Sep 2025 16:55:13 +0000 (09:55 -0700)
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/get_rows.cl

index 0cf3b92464c6e077d8adb26c8c01acc03ea9276b..a9405ab012dc179b63854198d1631e0b69306a06 100644 (file)
@@ -4222,15 +4222,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
 
-    const int      ne00 = src0 ? src0->ne[0] : 0;
-    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
-    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
-    const int      ne10 = src1 ? src1->ne[0] : 0;
-    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
-    const int      ne11 = src1 ? src1->ne[1] : 0;
-    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
-    const cl_ulong nb1  = dst  ?  dst->nb[1] : 0;
-    const cl_ulong nb2  = dst  ?  dst->nb[2] : 0;
+    const int      ne00 = src0->ne[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+    const int      ne10 = src1->ne[0];
+    const cl_ulong nb10 = src1->nb[0];
+    const int      ne11 = src1->ne[1];
+    const int      ne12 = src1->ne[2];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb1  = dst->nb[1];
+    const cl_ulong nb2  = dst->nb[2];
+    const cl_ulong nb3  = dst->nb[3];
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -4267,14 +4271,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
     CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
     CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
-
-    size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
-    size_t local_work_size[] = {1, 1, 1};
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
+
+    size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
+    size_t local_work_size[] = {64, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
index b3fea2923df8f91f357249bc1ebbaabea229742d..c2962edc983723e91944307be210b804c9220e78 100644 (file)
@@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32(
         int ne00,
         ulong nb01,
         ulong nb02,
+        ulong nb03,
         int ne10,
         ulong nb10,
         ulong nb11,
+        ulong nb12,
         ulong nb1,
-        ulong nb2
+        ulong nb2,
+        ulong nb3
 ) {
     src0 = (global void*)((global char*)src0 + offset0);
     src1 = (global int*)((global char*)src1 + offset1);
@@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32(
 
     int i10 = get_group_id(0);
     int i11 = get_group_id(1);
+    int i12 = get_group_id(2);
 
-    int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+    int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
 
     int i02 = i11;
+    int i03 = i12;
 
     for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
-        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+        if (ind >= ne00) {
+            return;
+        }
+        ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
+            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
     }
 }
 
@@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16(
         int ne00,
         ulong nb01,
         ulong nb02,
+        ulong nb03,
         int ne10,
         ulong nb10,
         ulong nb11,
+        ulong nb12,
         ulong nb1,
-        ulong nb2
+        ulong nb2,
+        ulong nb3
 ) {
     src0 = (global void*)((global char*)src0 + offset0);
     src1 = (global int*)((global char*)src1 + offset1);
@@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16(
 
     int i10 = get_group_id(0);
     int i11 = get_group_id(1);
+    int i12 = get_group_id(2);
 
-    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
 
     int i02 = i11;
+    int i03 = i12;
 
     for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
-        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+        if (ind >= ne00) {
+            return;
+        }
+        ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
+            ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
     }
 }
 
@@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0(
         int ne00,
         ulong nb01,
         ulong nb02,
+        ulong nb03,
         int ne10,
         ulong nb10,
         ulong nb11,
+        ulong nb12,
         ulong nb1,
-        ulong nb2
+        ulong nb2,
+        ulong nb3
 ) {
     src0 = (global void*)((global char*)src0 + offset0);
     src1 = (global int*)((global char*)src1 + offset1);
@@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0(
 
     int i10 = get_group_id(0);
     int i11 = get_group_id(1);
+    int i12 = get_group_id(2);
 
-    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
 
     int i02 = i11;
+    int i03 = i12;
 
     for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
         float16 temp;
+        if (ind >= ne00) {
+            return;
+        }
         dequantize_q4_0_f32(
-            ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
-        *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+            ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
+        *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
     }
 }