]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: add general Q6_K mm and Q4_K mv (#19347)
authorlhez <redacted>
Wed, 11 Feb 2026 18:33:13 +0000 (10:33 -0800)
committerGitHub <redacted>
Wed, 11 Feb 2026 18:33:13 +0000 (10:33 -0800)
* opencl: add general q6_k mm

* opencl: refine condition for q6_K mm

* opencl: add general q4_K mv

* opencl: fix whitespace

ggml/src/ggml-opencl/CMakeLists.txt
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl [new file with mode: 0644]

index fa5fadd112b756bbca4b94e1adf3ca3953d09898..b6094fb68b0a322e71de577d9214b0e8c25aa7ab 100644 (file)
@@ -85,6 +85,7 @@ set(GGML_OPENCL_KERNELS
     mul_mv_q4_0_f32_8x_flat
     mul_mv_q4_0_f32_1d_8x_flat
     mul_mv_q4_0_f32_1d_16x_flat
+    mul_mv_q4_k_f32
     mul_mv_q6_k_f32
     mul_mv_q6_k_f32_flat
     mul_mv_q8_0_f32
@@ -101,6 +102,7 @@ set(GGML_OPENCL_KERNELS
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
     mul_mm_q8_0_f32_l4_lm
+    mul_mm_q6_k_f32_l4_lm
     mul_mm_q8_0_f32_8x4
     gemv_noshuffle_general_q8_0_f32
     mul
index 508b2b8f037c32ea96f0d4b64f4d0c80638de61f..40474c193bb06cf44b5103e4801dd7b43d71a906 100644 (file)
@@ -532,6 +532,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_restore_block_q4_0_noshuffle;
     cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
+    cl_kernel kernel_mul_mv_q4_K_f32;
     cl_kernel kernel_mul_mv_q6_K_f32;
     cl_kernel kernel_mul_mv_q6_K_f32_flat;
     cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
@@ -564,6 +565,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
     cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -1117,6 +1119,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mv_q4_k_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_k_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mv_q6_k_f32
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1358,6 +1377,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_q6_k_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q6_k_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mm_f16_f32_kq_kqv
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3364,6 +3400,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             } else if (op->src[0]->type == GGML_TYPE_F32) {
                 return op->src[1]->type == GGML_TYPE_F32;
             } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
+                       op->src[0]->type == GGML_TYPE_Q4_K ||
                        op->src[0]->type == GGML_TYPE_Q6_K) {
                 return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
             } else if (op->src[0]->type == GGML_TYPE_Q8_0) {
@@ -8927,6 +8964,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
                 backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
                 return;
             }
+            case GGML_TYPE_Q6_K: {
+                if (ne11 < 32) {
+                    break;
+                }
+                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+                    break;
+                }
+
+                kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q6_K->ql));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q6_K->qh));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q6_K->s));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra0_q6_K->d));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &r3));
+
+                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
             default:
                 break;
         }
@@ -9262,7 +9343,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
         }
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q4_K: {
+            kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 1;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),     &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(int),        &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),     &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(int),        &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),     &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),        &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),        &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),        &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),   &nb01));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),   &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),   &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),        &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),   &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),   &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),   &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),        &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),        &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),        &r2));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),        &r3));
+            break;
+        }
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
 #ifdef GGML_OPENCL_SOA_Q
@@ -9424,7 +9540,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
         backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     } else if (src0t == GGML_TYPE_Q4_K) {
-        GGML_ASSERT(false && "not implemented");
+        size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     } else if (src0t == GGML_TYPE_Q3_K) {
         GGML_ASSERT(false && "not implemented");
     } else if (src0t == GGML_TYPE_Q5_K) {
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl
new file mode 100644 (file)
index 0000000..3602c92
--- /dev/null
@@ -0,0 +1,158 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 2
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q6_k_f32_l4_lm(
+    global uchar * src0_ql,
+    global uchar * src0_qh,
+    global char  * src0_s,
+    global half  * src0_d,
+    global float4 * src1,
+    ulong offset1,
+    global float  * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (ir*BM + loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+
+                int ib = idx / 128;                  // 2 values per idx
+                int iqs = idx % 128;                 // 0..127
+
+                int n = iqs / 64;                    // 0,1
+                int b = (iqs % 64) / 32;             // 0,1
+                int is_b = (iqs % 16) / 8;           // 0,1
+                int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+                int is = 8 * n + qhshift + is_b;     // 0..15
+                int qsi = n * 64 + (iqs % 32) * 2;   // 0,2,4..126
+                int qhi = n * 32 + (iqs % 16) * 2;   // 0,2,4..62
+
+                float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
+
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (ic*BN + loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl
new file mode 100644 (file)
index 0000000..71ab989
--- /dev/null
@@ -0,0 +1,180 @@
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+//------------------------------------------------------------------------------
+// block_q4_K
+//------------------------------------------------------------------------------
+#define QK_K            256
+#define K_SCALE_SIZE    12
+
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+typedef struct {
+    half d;    // super-block scale for quantized scales
+    half dmin; // super-block scale for quantized mins
+
+    uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uchar qs[QK_K/2];           // 4-bit quants
+} block_q4_K;
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // number of rows each SIMD group works on
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // SIMD group size
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+#undef  BLOCK_STRIDE
+// number of (super) blocks each subgroup processes
+// each thread in a subgroup processes a block (32 weights)
+#define BLOCK_STRIDE (N_SIMDWIDTH/8)
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q4_K_f32(
+        global char * src0,
+        int offset0,
+        global char * src1,
+        int offset1,
+        global char * dst,
+        int offsetd,
+        int ne00,
+        int ne01,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    ushort kmask1 = 0x3f3f;
+    ushort kmask2 = 0x0f0f;
+    ushort kmask3 = 0xc0c0;
+
+    int ix = get_sub_group_local_id()/8;  // super block index
+    int it = get_sub_group_local_id()%8;  // block index (inside super block)
+    int iq = it/4;     // 0 or 1 - first or second half of the super block
+    int ir = it%4;     // 0...3 - block index in the half super block
+
+    int nb = ne00/QK_K;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    int offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
+    global float      * y = (global float      *) (src1 + offset_src1);
+
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST] = {0.f};
+    float all_sum;
+
+    global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+    ushort  sc16[4];
+    uchar * sc8 = (uchar *)sc16;
+
+    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+0];
+            sumy.s0 += yl[i+0];
+
+            yl[i+8] = y4[i+32];
+            sumy.s1 += yl[i+8];
+
+            yh[i+0] = y4[i+128];
+            sumy.s2 += yh[i+0];
+
+            yh[i+8] = y4[i+160];
+            sumy.s3 += yh[i+8];
+        }
+
+        global ushort * sc = (global ushort *)x[ib].scales + iq;
+        global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
+        global half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            global ushort * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
+                                 (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
+                                 (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
+                                 (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
+
+            q1 += nb01/2;
+            sc += nb01/2;
+            dh += nb01/2;
+        }
+
+        y4 += BLOCK_STRIDE * QK_K;
+    }
+
+    global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = sub_group_reduce_add(sumf[row]);
+        if (first_row + row < ne01) {
+            if (get_sub_group_local_id() == 0) {
+                dst_f32[first_row + row] = all_sum;
+            }
+        }
+    }
+}