]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: support pad_ext (#15888)
authorlhez <redacted>
Tue, 30 Sep 2025 17:45:45 +0000 (10:45 -0700)
committerGitHub <redacted>
Tue, 30 Sep 2025 17:45:45 +0000 (10:45 -0700)
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/pad.cl

index a9405ab012dc179b63854198d1631e0b69306a06..79d2148744f904874836245ef7d4f46184b138c9 100644 (file)
@@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_REPEAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
-            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
-                   op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
-                   (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
-                   (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
+            return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D:
@@ -5881,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
     GGML_ASSERT(dst->extra);
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1);
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -5899,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t
     const int s_ne0 = src0->ne[0];
     const int s_ne1 = src0->ne[1];
     const int s_ne2 = src0->ne[2];
+    const int s_ne3 = src0->ne[3];
+
+    const int s_nb0 = src0->nb[0];
+    const int s_nb1 = src0->nb[1];
+    const int s_nb2 = src0->nb[2];
+    const int s_nb3 = src0->nb[3];
 
     const int d_ne0 = dst->ne[0];
     const int d_ne1 = dst->ne[1];
     const int d_ne2 = dst->ne[2];
+    const int d_ne3 = dst->ne[3];
+
+    const int d_nb0 = dst->nb[0];
+    const int d_nb1 = dst->nb[1];
+    const int d_nb2 = dst->nb[2];
+    const int d_nb3 = dst->nb[3];
+
+    const int lp0 = ((const int*)(dst->op_params))[0];
+    const int rp0 = ((const int*)(dst->op_params))[1];
+    const int lp1 = ((const int*)(dst->op_params))[2];
+    const int rp1 = ((const int*)(dst->op_params))[3];
+    const int lp2 = ((const int*)(dst->op_params))[4];
+    const int rp2 = ((const int*)(dst->op_params))[5];
+    const int lp3 = ((const int*)(dst->op_params))[6];
+    const int rp3 = ((const int*)(dst->op_params))[7];
 
     cl_kernel kernel = backend_ctx->kernel_pad;
 
-    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra_src0->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &off_src0));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extra_dst->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &off_dst));
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),       &s_ne0));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),       &s_ne1));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),       &s_ne2));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),       &d_ne0));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int),       &d_ne1));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int),       &d_ne2));
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),    &extra_src0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),  &off_src0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),    &extra_dst->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),  &off_dst));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),       &s_ne0));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),       &s_ne1));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),       &s_ne2));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),       &s_ne3));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),  &s_nb0));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),  &s_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),  &s_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),  &s_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),       &d_ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),       &d_ne1));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),       &d_ne2));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),       &d_ne3));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),  &d_nb0));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),  &d_nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),  &d_nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),  &d_nb3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),       &lp0));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),       &rp0));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),       &lp1));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),       &rp1));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),       &lp2));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),       &rp2));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int),       &lp3));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int),       &rp3));
 
     size_t lws0 = 64;
     size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0;
 
-    size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 };
+    size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 };
     size_t local_work_size[]  = { lws0, 1, 1 };
 
     size_t * local_work_size_ptr = local_work_size;
index 747fa7febcc74c54f665c7da6c09884c3a21d3b0..31fb7ccd3b0817dd3a63cc20f3937684b2e49bcd 100644 (file)
@@ -1,30 +1,39 @@
 kernel void kernel_pad(
-        global const void * src0_ptr,
-        ulong src0_offset,
-        global void * dst_ptr,
-        ulong dst_offset,
-        int s_ne0, int s_ne1, int s_ne2,
-        int d_ne0, int d_ne1, int d_ne2
+        global void * src0,
+        ulong offset0,
+        global void * dst,
+        ulong offsetd,
+        int ne00, int ne01, int ne02, int ne03,
+        ulong nb00, ulong nb01, ulong nb02, ulong nb03,
+        int ne0, int ne1, int ne2, int ne3,
+        ulong nb0, ulong nb1, ulong nb2, ulong nb3,
+        int lp0, int rp0,
+        int lp1, int rp1,
+        int lp2, int rp2,
+        int lp3, int rp3
 ) {
-    global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset);
-    global float * dst = (global float *)((global char *)dst_ptr + dst_offset);
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst  + offsetd);
 
-    int nidx   = get_global_id(0);
-    int idx_d1 = get_group_id(1);
-    int idx_d2 = get_group_id(2);
+    int i0 = get_global_id(0);
+    int i1 = get_group_id(1);
+    int i2 = get_group_id(2) % ne2;
+    int i3 = get_group_id(2) / ne2;
 
-    if (nidx >= d_ne0) {
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }
 
-    int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1;
+    uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+    uint dst_idx  =         i3*nb3  +         i2*nb2  +         i1*nb1  +         i0*nb0;
 
-    bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2);
+    global float * src0_ptr = (global float *)((global char *)src0 + src0_idx);
+    global float * dst_ptr  = (global float *)((global char *)dst  + dst_idx);
 
-    if (in_src_bounds) {
-        int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1;
-        dst[dst_el_offset] = src0[src_el_offset];
-    } else {
-        dst[dst_el_offset] = 0.0f;
-    }
+    bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) &&
+                         (i1 >= lp1 && i1 < ne1 - rp1) &&
+                         (i2 >= lp2 && i2 < ne2 - rp2) &&
+                         (i3 >= lp3 && i3 < ne3 - rp3);
+
+    *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f;
 }