]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: add q8_0 mm support (llama/16469)
authorlhez <redacted>
Wed, 15 Oct 2025 17:51:04 +0000 (10:51 -0700)
committerGeorgi Gerganov <redacted>
Tue, 21 Oct 2025 15:14:33 +0000 (18:14 +0300)
* opencl: add mm_q8_0_f32

* opencl: fix data loading for incomplete tile

* opencl: use q8_0 mm for larger matrix

* opencl: add some tests to cover the path

src/ggml-opencl/CMakeLists.txt
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl [new file with mode: 0644]
tests/test-backend-ops.cpp

index 7e6c843846708dae0fb8d82a3720d64e1642cb8c..6f6bba55e2805138b316d5fd396f5c391e505510 100644 (file)
@@ -93,6 +93,7 @@ set(GGML_OPENCL_KERNELS
     mul_mv_id_mxfp4_f32_flat
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
+    mul_mm_q8_0_f32_l4_lm
     mul
     norm
     relu
index 0693d38d80af66fd27651b9a421be7fec4541d09..2ec896fd0e896a55307b9284a226faff5cfe04f7 100644 (file)
@@ -408,6 +408,7 @@ struct ggml_backend_opencl_context {
     cl_program program_mul_mv_id_mxfp4_f32_flat;
     cl_program program_mul_mm_f32_f32_l4_lm;
     cl_program program_mul_mm_f16_f32_l4_lm;
+    cl_program program_mul_mm_q8_0_f32_l4_lm;
 
     cl_kernel kernel_add, kernel_add_row, kernel_add_f16, kernel_add_row_f16;
     cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
@@ -480,6 +481,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -1191,6 +1193,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_q8_0_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q8_0_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl");
+#endif
+        backend_ctx->program_mul_mm_q8_0_f32_l4_lm =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -6961,6 +6979,44 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
                 backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
                 return;
             }
+            case GGML_TYPE_Q8_0: {
+                if (ne11 < 32) {
+                    break;
+                }
+                kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
             default:
                 break;
         }
index 9599a0e1572629ad3510b14cd1360f418d59eb63..1a1bfe144f6109c4bcf62f6cf56d436786565f65 100644 (file)
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
+            if (loadc_a + l < ne01) {
             const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
-            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
-            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
-            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
-            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0h;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0h;
+            }
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
-            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
-            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
-            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
-            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            if (loadc_b + l < ne11) {
+                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0h;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0h;
+            }
         }
 
         barrier(CLK_LOCAL_MEM_FENCE);
index 58c5178e39cc8c5f8cd488d7d0e4d9181b0ce117..39a5d4868ffaaa2e33bea15e26209c3dc0ee9435 100644 (file)
@@ -79,19 +79,33 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
-            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
-            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
-            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
-            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            if (loadc_a + l < ne01) {
+                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
+            }
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
-            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
-            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
-            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
-            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            if (loadc_b + l < ne11) {
+                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
         }
 
         barrier(CLK_LOCAL_MEM_FENCE);
diff --git a/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl b/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl
new file mode 100644 (file)
index 0000000..fd47e8a
--- /dev/null
@@ -0,0 +1,154 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 4
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q8_0_f32_l4_lm(
+    global char4  * src0_q,
+    global half   * src0_d,
+    global float4 * src1,
+    ulong offset1,
+    global float  * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                int ib  = idx / 8;
+                int iqs = idx % 8;
+
+                float d = (float)src0_d[ib];
+                global char4 * qs = src0_q + ib*8 + iqs;
+                char4 q = *qs;
+                float4 v = convert_float4(q)*d;
+
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3;
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
index 3a2876c808ffc2e57aa875189702e4159252f5c8..d5c5a2a6656eec87b309f6825517828d1d2a80a4 100644 (file)
@@ -6365,6 +6365,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+#if 0
+    {
+        // Test paths in OpenCL
+        std::vector<int> ns = {32, 64, 128, 256, 512, 1024, 4096};
+        std::vector<int> ks = {896, 1536, 4096};
+        for (auto n : ns) {
+            for (auto k : ks) {
+                test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 1024, n, k, {1, 1}, {1, 1}));
+            }
+        }
+    }
+#endif
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {