git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
opencl: transposed gemm/gemv moe kernel with mxfp4,f32 (llama/16602)
author Shawn Gu <redacted>
Sat, 18 Oct 2025 00:55:32 +0000 (17:55 -0700)
committer Georgi Gerganov <redacted>
Tue, 21 Oct 2025 15:14:33 +0000 (18:14 +0300)
* opencl: transposed gemm/gemv moe kernel with mxfp4,f32

* add restore kernel for moe transpose

* fix trailing whitespaces

* resolve compilation warnings
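
The new path is compiled only under GGML_OPENCL_USE_ADRENO_KERNELS and is taken only for MXFP4 expert weights that pass use_adreno_moe_kernels(); a condensed sketch of that gating and of the GEMV/GEMM split, using the names the diff below introduces:

    #include <cstring>
    #include "ggml.h"   // for ggml_tensor

    // Condensed restatement of the gating predicate added in ggml-opencl.cpp
    // (same logic as use_adreno_moe_kernels in the diff; shown here for clarity).
    static bool use_adreno_moe_kernels(const ggml_tensor * t) {
        const int ne01 = t->ne[1];               // rows per expert
        return (strstr(t->name, "ffn") != NULL || strstr(t->name, "as") != NULL)
               && ne01 % 64 == 0;                // row count must be a multiple of 64
    }
    // Dispatch in ggml_cl_mul_mat_id, case GGML_TYPE_MXFP4:
    //   ne12 == 1 -> kernel_gemv_moe_mxfp4_f32 (single token)
    //   ne12 >  1 -> kernel_gemm_moe_mxfp4_f32 (320-row tiles, host-reordered router table)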

src/ggml-opencl/CMakeLists.txt
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/cvt.cl
src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl [new file with mode: 0644]
src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl [new file with mode: 0644]

index 6f6bba55e2805138b316d5fd396f5c391e505510..d3d97f375e8f334ffb8832315117426b020e9653 100644 (file)
@@ -91,6 +91,8 @@ set(GGML_OPENCL_KERNELS
     mul_mv_id_q8_0_f32_flat
     mul_mv_id_mxfp4_f32
     mul_mv_id_mxfp4_f32_flat
+    gemm_moe_mxfp4_f32
+    gemv_moe_mxfp4_f32
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
     mul_mm_q8_0_f32_l4_lm
index 2ec896fd0e896a55307b9284a226faff5cfe04f7..d9876e697aae70fc8667288c7f9f4f34dd32dff3 100644 (file)
@@ -402,6 +402,7 @@ struct ggml_backend_opencl_context {
     cl_program program_conv_2d_f32;
     cl_program program_conv_2d_f16_f32;
     cl_program program_tsembd;
+    cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
     cl_program program_mul_mv_id_q4_0_f32_8x_flat;
     cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
     cl_program program_mul_mv_id_mxfp4_f32;
@@ -452,7 +453,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mat_f16_f32_tiled;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
-    cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
+    cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
     cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
     cl_kernel kernel_convert_block_q4_0_noshuffle;
@@ -475,6 +476,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
     cl_kernel kernel_timestep_embedding;
+    cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
     cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
     cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
     cl_kernel kernel_mul_mv_id_mxfp4_f32;
@@ -559,14 +561,14 @@ struct ggml_backend_opencl_context {
 
         fprintf(ftrace, "[\n");
         for (const ProfilingInfo & info : profiling_info) {
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
                 info.kernel_name.c_str(), info.cmd_queued/1000);
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Host\"},\n",
                 info.kernel_name.c_str(), info.cmd_submit/1000);
 
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
                 info.kernel_name.c_str(), info.cmd_start/1000);
-            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %llu, \"pid\": \"\", \"tid\": \"Device\"},\n",
                 info.kernel_name.c_str(), info.cmd_end/1000);
         }
         fclose(ftrace);
@@ -777,6 +779,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_convert_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
@@ -1991,6 +1995,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
         GGML_LOG_CONT(".");
     }
+
+    std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
+            " -cl-mad-enable "
+            " -cl-fast-relaxed-math";
+
+    // gemv_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemv_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemv_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // gemm_moe_mxfp4_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemm_moe_mxfp4_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
+#endif
+        backend_ctx->program_gemm_moe_mxfp4_f32 =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
+        GGML_LOG_CONT(".");
+    }
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
     GGML_LOG_CONT("\n");
 }
@@ -3299,6 +3339,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
             tensor->ne[2] == 1 && tensor->ne[3] == 1;
 }
 
+inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
+    GGML_UNUSED(backend_ctx);
+    int ne01 = tensor->ne[1];
+    return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
+}
+
 static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
 
@@ -3601,14 +3647,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
             CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
         CL_CHECK(err);
 
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (use_adreno_moe_kernels(backend_ctx, tensor)) {
+            cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
+
+            int ne00 = tensor->ne[0];
+            int ne01 = tensor->ne[1];
+            int ne02 = tensor->ne[2];
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
+
+            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+            size_t local_work_size[3] = {64, 2, 1};
+
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+            CL_CHECK(clReleaseMemObject(data_device));
+            tensor->extra = extra;
+
+            return;
+        }
+#endif
         cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
 
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
         CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
 
-        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
-        size_t local_work_size[] = {64, 1, 1};
+        size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[3] = {64, 1, 1};
 
         cl_event evt;
         CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
@@ -3624,7 +3695,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
             { extra->q }
         };
         extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
-
         tensor->extra = extra;
 
         return;
@@ -3751,6 +3821,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
             ggml_nbytes(tensor), NULL, &err);
         CL_CHECK(err);
 
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (use_adreno_moe_kernels(backend_ctx, tensor)) {
+            cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
+
+            int ne00 = tensor->ne[0];
+            int ne01 = tensor->ne[1];
+            int ne02 = tensor->ne[2];
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
+
+            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
+            size_t local_work_size[3] = {64, 2, 1};
+
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+                global_work_size, local_work_size, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+            CL_CHECK(clEnqueueReadBuffer(
+                queue, data_device, CL_TRUE, offset,
+                size, data, 0, NULL, NULL));
+            CL_CHECK(clReleaseMemObject(data_device));
+            return;
+        }
+#endif
         cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
@@ -7553,6 +7650,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
     const int ne21 = src2->ne[1];
 
     const cl_ulong nb21 = src2->nb[1];
+    const cl_ulong nb20 = src2->nb[0];
 
     const int ne0 = dst->ne[0];
     const int ne1 = dst->ne[1];
@@ -7692,6 +7790,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
             break;
         }
         case GGML_TYPE_MXFP4: {
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+            if (use_adreno_moe_kernels(backend_ctx, src0)) {
+                cl_int status;
+
+                size_t local_size[3] = {64, 2, 1};
+                size_t global_size[3] = {64, 2, 1};
+
+                cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
+
+                int tile_size = 320;
+                if (ne12 == 1) { // for gemv
+                    kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
+
+                    // create a sub_buffer for src2
+                    cl_buffer_region region;
+                    region.origin = offset2;
+                    region.size = ne20 * ne21 * sizeof(int);
+                    buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+                    CL_CHECK(status);
+
+                    // set thread grid
+                    global_size[0] = static_cast<size_t>(ne01);
+                    global_size[1] = 4;
+                    global_size[2] = static_cast<size_t>(ne20);
+                    local_size[1] = 4;
+                } else { // for gemm
+                    kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
+
+                    // preprocess router table
+                    int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
+                    void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
+                    void * host_src2 = malloc(ne21 * nb21);
+                    CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
+                    int total_experts = nb21 / nb20;
+                    int out_idx = 0;
+                    for (int i_expert = 0; i_expert < ne02; i_expert++) {
+                        for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
+                            for (int j = 0; j < ne21; j++) {
+                                for (int i = 0; i < ne20; i++) {
+                                    int expert = ((int *)host_src2)[j * total_experts + i];
+                                    if (i_expert == expert) {
+                                        ((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
+                                        ((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
+                                        ((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
+                                        ((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
+                                        out_idx += 4;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
+                    CL_CHECK(status);
+
+                    // set thread grid
+                    global_size[0] = static_cast<size_t>(tile_size);
+                    global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
+                }
+
+                // create a sub_buffer for src1
+                cl_buffer_region region;
+                region.origin = offset1;
+                region.size = ne10 * ne11 * ne12 * sizeof(float);
+                src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+                CL_CHECK(status);
+
+                // create image for src1
+                cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
+                cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
+                buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
+                CL_CHECK(status);
+
+                // Set kernel args
+                int arg_idx = 0;
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &extra0_mxfp4->q));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &extra0_mxfp4->e));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &buf_src1_image));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &buf_src2));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem),    &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong),  &offsetd));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &ne00));
+                CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &ne01));
+                if (ne12 == 1) {
+                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &ne11));
+                } else {
+                    CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int),       &tile_size));
+                }
+
+                // launch kernel
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
+
+                // deallocate sub buffers and images
+                CL_CHECK(clReleaseMemObject(src1_sub_buffer));
+                CL_CHECK(clReleaseMemObject(buf_src1_image));
+                CL_CHECK(clReleaseMemObject(buf_src2));
+                return;
+            } // else fallback to generic kernel
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
 #ifdef GGML_OPENCL_SOA_Q
             kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
 
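For the GEMM path, the routing table (src2) is rebuilt on the host into one ushort4 per (match, tile) before launch; a sketch of what each entry means (the struct name is illustrative, the field meanings follow the packing loop above):

    #include <cstdint>

    // Illustrative view of one reordered router entry consumed as ushort4 by
    // kernel_gemm_moe_mxfp4_f32 (struct name is made up; packing matches the
    // host loop above: expert, j*ne11 + i%ne11, j*ne20 + i, i_tile).
    struct moe_router_entry {
        uint16_t expert_id; // which expert's weight slice to read
        uint16_t i11;       // row of src1 (activations) to multiply
        uint16_t i1;        // row of dst to write
        uint16_t tile_id;   // which tile_size (= 320) slice of ne01 this entry covers
    };
    // Launch geometry: global = {tile_size, 2, ne20*ne21*num_tiles_per_expert},
    // local = {64, 2, 1}; dimension 2 selects the router entry, dimension 0 the
    // output row inside the tile, dimension 1 the two subgroups splitting the ne00 sum.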
index 045300eb3a53778a45a890bd16462fdae7ddbce6..b26f9c5fb2a310ac89d548eee6d3b5f3c6cd8324 100644 (file)
@@ -147,6 +147,27 @@ kernel void kernel_convert_block_mxfp4(
     }
 }
 
+kernel void kernel_convert_block_mxfp4_trans(
+    global struct block_mxfp4 * src0,
+    __global uint4 * dst_q,
+    __global uchar * dst_e,
+    uint ne00,
+    uint ne01
+) {
+    int i00 = get_global_id(1);
+    uint i01 = get_global_id(0);
+    uint i02 = get_global_id(2);
+
+    uint ne00_blk = ne00 / QK_MXFP4;
+    uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
+    uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
+
+    global struct block_mxfp4 * b = src0 + src_blk_offset;
+
+    dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];
+    dst_e[dst_blk_offset] = b->e;
+}
+
 kernel void kernel_restore_block_mxfp4(
     global uchar * src_q,
     global half  * src_e,
@@ -162,6 +183,27 @@ kernel void kernel_restore_block_mxfp4(
     }
 }
 
+kernel void kernel_restore_block_mxfp4_trans(
+    __global uint4 * src_q,
+    __global uchar * src_e,
+    global struct block_mxfp4 * dst,
+    uint ne00,
+    uint ne01
+) {
+    int i00 = get_global_id(1);
+    uint i01 = get_global_id(0);
+    uint i02 = get_global_id(2);
+
+    uint ne00_blk = ne00 / QK_MXFP4;
+    uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
+    uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
+
+    global struct block_mxfp4 * b = dst + dst_blk_offset;
+
+    ((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
+    b->e = src_e[src_blk_offset];
+}
+
 //------------------------------------------------------------------------------
 // block_q8_0
 //------------------------------------------------------------------------------
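The *_trans kernels above store the q and e of each 32-value block in flat buffers with the block index transposed, so the 64 rows handled by one subgroup of the new GEMV/GEMM kernels read adjacent entries. A minimal CPU sketch of the same index remapping and its round trip (block payloads reduced to a single int for brevity):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // CPU check of the transpose round trip used by kernel_convert_block_mxfp4_trans
    // and kernel_restore_block_mxfp4_trans (illustrative only).
    void check_roundtrip(int ne00, int ne01, int ne02) {
        const int nblk = ne00 / 32;   // 32-value MXFP4 blocks per row
        std::vector<int> src(nblk * ne01 * ne02), t(src.size()), back(src.size());
        for (size_t i = 0; i < src.size(); ++i) src[i] = (int) i;
        for (int i02 = 0; i02 < ne02; ++i02)
        for (int i01 = 0; i01 < ne01; ++i01)
        for (int i00 = 0; i00 < nblk; ++i00) {
            const size_t s = i00 + i01 * nblk + i02 * nblk * ne01; // convert: read
            const size_t d = i01 + i00 * ne01 + i02 * nblk * ne01; // convert: write
            t[d] = src[s];
        }
        for (int i02 = 0; i02 < ne02; ++i02)
        for (int i01 = 0; i01 < ne01; ++i01)
        for (int i00 = 0; i00 < nblk; ++i00) {
            const size_t s = i01 + i00 * ne01 + i02 * nblk * ne01; // restore: read
            const size_t d = i00 + i01 * nblk + i02 * nblk * ne01; // restore: write
            back[d] = t[s];
        }
        assert(back == src);
    }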
diff --git a/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl b/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl
new file mode 100644 (file)
index 0000000..3917aa3
--- /dev/null
@@ -0,0 +1,162 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+#define QK_MXFP4 32
+#define N_SIMDGROUP 2
+#define SIMDGROUP_WIDTH 64
+
+static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
+    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
+    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
+    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
+    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
+    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
+
+    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
+    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
+    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
+    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
+
+    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
+    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
+    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
+    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
+
+    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
+    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
+    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
+    sign_b.hi = fp4x8.s0 & 0x8000;
+
+    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
+    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
+
+    ushort2 fp16_packed_a_1, fp16_packed_b_1;
+    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
+    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
+    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
+    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
+
+    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
+    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
+    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
+    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
+
+    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
+    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
+    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
+    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
+
+    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
+    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
+    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
+    sign_b.hi = fp4x8.s1 & 0x8000;
+
+    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
+    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
+
+    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
+}
+
+static inline float e8m0_to_fp32(uchar x) {
+    int bits;
+    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
+    return as_float(bits);
+}
+
+
+__attribute__((qcom_reqd_sub_group_size("half")))
+__kernel void kernel_gemm_moe_mxfp4_f32(
+    __global uint4 * src0_q,
+    __global uchar * src0_e,
+    __read_only image1d_buffer_t src1,
+    __global ushort4 * src2,
+    __global float * dst,
+    ulong         offsetd,
+    int           ne00,
+    int           ne01,
+    int           tile_size
+) {
+    uint i01  = get_global_id(0);
+    uint i20  = get_global_id(2);
+    uint sgid = get_local_id(1);
+    uint slid = get_sub_group_local_id();
+
+    ushort4 router = src2[i20];
+    ushort expert_id = router.x;
+    ushort i11 = router.y;
+    ushort i1 = router.z;
+    ushort tile_id = router.w;
+
+    if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size
+        return;
+    }
+
+    uint expert_offset = expert_id * ne00 * ne01 / 32;
+    uint tile_offset = expert_offset + tile_id * tile_size + i01;
+
+    __private float sum = 0.0f; // each thread calculates the partial sum of one output
+
+    // loop along ne00 in block granularity, skip N_SIMDGROUP (2) blocks every iter
+    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
+        // load one block of q
+        uint4 regQ = src0_q[tile_offset + ib00 * ne01];
+        // convert 8 fp4 to fp16
+        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
+
+        uint offset = i11 * ne00 / 4 + ib00 * 8;
+        float4 shared_y4;
+        shared_y4 = read_imagef(src1, (offset + 0));
+        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 4));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
+
+        shared_y4 = read_imagef(src1, (offset + 1));
+        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 5));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
+
+        shared_y4 = read_imagef(src1, (offset + 2));
+        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 6));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
+
+        shared_y4 = read_imagef(src1, (offset + 3));
+        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 7));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+        uchar regE = src0_e[tile_offset + ib00 * ne01];
+        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
+    }
+
+    // reduction in local memory, assumes #subgroups=2 (N_SIMDGROUP)
+    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
+    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
+    // if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
+    // if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+    // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+    // if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+    // 1 output per thread in subgroup 0
+    if (sgid == 0) {
+        dst = dst + (offsetd >> 2);
+        dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
+    }
+
+}
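e8m0_to_fp32() above decodes the per-block MX scale: the byte is a bare exponent with bias 127, so the kernel simply places it in the float exponent field; x == 0 would be 2^-127, which has no normal encoding, hence the denormal bit pattern 0x00400000. A scalar reference (an equivalent sketch, not kernel code):

    #include <cmath>
    #include <cstdint>

    // Scalar equivalent of e8m0_to_fp32() in the kernel above: value = 2^(x - 127).
    // std::ldexp yields the same denormal 2^-127 that the kernel encodes as the
    // raw bit pattern 0x00400000 for x == 0.
    static float e8m0_to_fp32_ref(uint8_t x) {
        return std::ldexp(1.0f, (int) x - 127);
    }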
diff --git a/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl b/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl
new file mode 100644 (file)
index 0000000..b4b1e51
--- /dev/null
@@ -0,0 +1,156 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+#define QK_MXFP4 32
+#define N_SIMDGROUP 4
+#define SIMDGROUP_WIDTH 64
+
+static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) {
+    ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
+    fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
+    fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
+    fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
+    fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
+
+    bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
+    bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
+    bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
+    bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
+
+    fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
+    fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
+    fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
+    fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
+
+    sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
+    sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
+    sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
+    sign_b.hi = fp4x8.s0 & 0x8000;
+
+    fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
+    fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
+
+    ushort2 fp16_packed_a_1, fp16_packed_b_1;
+    fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
+    fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
+    fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
+    fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
+
+    bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
+    bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
+    bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
+    bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
+
+    fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
+    fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
+    fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
+    fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
+
+    sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
+    sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
+    sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
+    sign_b.hi = fp4x8.s1 & 0x8000;
+
+    fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
+    fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
+
+    return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
+}
+
+static inline float e8m0_to_fp32(uchar x) {
+    int bits;
+    bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
+    return as_float(bits);
+}
+
+
+__attribute__((qcom_reqd_sub_group_size("half")))
+__kernel void kernel_gemv_moe_mxfp4_f32(
+    __global uint4 * src0_q,
+    __global uchar * src0_e,
+    __read_only image1d_buffer_t src1,
+    __global uint * src2,
+    __global float * dst,
+    ulong         offsetd,
+    int           ne00,
+    int           ne01,
+    int           ne11
+) {
+    uint i01  = get_global_id(0);
+    uint i20  = get_global_id(2);
+    uint sgid = get_local_id(1);
+    uint slid = get_sub_group_local_id();
+
+    uint i11 = i20 % ne11;
+
+    uint expert_id = src2[i20];
+    uint expert_offset = expert_id * ne00 * ne01 / 32;
+
+    __private float sum = 0.0f; // each thread calculates the partial sum of one output
+
+    // loop along ne00 in block granularity, skip 4 blocks every iter
+    for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
+
+        // load one block of q
+        uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];
+
+        uint offset = i11 * ne00 / 4 + ib00 * 8;
+
+        half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
+
+        float4 shared_y4;
+        shared_y4 = read_imagef(src1, (offset + 0));
+        float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 4));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
+
+        shared_y4 = read_imagef(src1, (offset + 1));
+        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 5));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
+
+        shared_y4 = read_imagef(src1, (offset + 2));
+        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 6));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+
+        fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
+
+        shared_y4 = read_imagef(src1, (offset + 3));
+        acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
+
+        shared_y4 = read_imagef(src1, (offset + 7));
+        acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
+
+        uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
+        sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
+    }
+
+    // reduction in local memory, assumes #subgroups=4
+    __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
+    if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
+    if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
+    if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+    if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+    // 1 output per thread in subgroup 0
+    if (sgid == 0) {
+        dst = dst + (offsetd >> 2);
+        dst[i01 + i20 * ne01] = sum;
+    }
+
+}
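
mxfp4_to_fp16_packed8() in both new kernels decodes eight FP4 (E2M1) values at once: the 3 magnitude bits are shifted into fp16 bits 9..11 (mask 0x0E00), the exponent-bias difference 0x3800 (14 << 10) is added for nonzero magnitudes, the single subnormal magnitude (which lands on 0x0200) is patched so it decodes to 0.5 rather than 0.75, and the sign is moved to bit 15. A scalar reference for one nibble (a sketch of the same mapping):

    #include <cstdint>

    // Scalar reference for a single FP4 (E2M1) nibble, mirroring the packed
    // ushort math in mxfp4_to_fp16_packed8() (illustrative; the kernel converts
    // eight values at once on ushort2 lanes).
    static float fp4_e2m1_to_fp32_ref(uint8_t nibble) {
        const int sign = (nibble >> 3) & 1;
        const int e    = (nibble >> 1) & 3;   // 2-bit exponent, bias 1
        const int m    =  nibble       & 1;   // 1-bit mantissa
        float v;
        if (e == 0) v = 0.5f * m;                               // subnormal: 0.0 or 0.5
        else        v = (1.0f + 0.5f * m) * (float)(1 << (e - 1));
        return sign ? -v : v;
    }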