]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: General GEMV fusion (llama/16715)
authorAman Gupta <redacted>
Sun, 26 Oct 2025 11:28:04 +0000 (19:28 +0800)
committerGeorgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
src/ggml-cuda/common.cuh
src/ggml-cuda/convert.cuh
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/mmvf.cu
src/ggml-cuda/mmvf.cuh
src/ggml-cuda/mmvq.cu
src/ggml-cuda/mmvq.cuh
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh
tests/test-backend-ops.cpp

index 41ff89c4d69227c3030b752610d6fe89ab39e118..1af23588301dd4c6d5ccad352ebb7fa96d3be47e 100644 (file)
@@ -1005,3 +1005,16 @@ struct ggml_backend_cuda_context {
         return pool(device);
     }
 };
+
+struct ggml_cuda_mm_fusion_args_host {
+    const ggml_tensor * x_bias = nullptr;
+    const ggml_tensor * gate = nullptr;
+    const ggml_tensor * gate_bias = nullptr;
+    ggml_glu_op glu_op;
+};
+struct ggml_cuda_mm_fusion_args_device {
+    const void * x_bias = nullptr;
+    const void * gate = nullptr;
+    const void * gate_bias = nullptr;
+    ggml_glu_op glu_op;
+};
index ef9e129950c98f010f113ab4e45427cb55ca343e..8a5e08ef667e0ef0094951888baf24b969aa597d 100644 (file)
@@ -1,3 +1,4 @@
+#pragma once
 #include "common.cuh"
 
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
index bc396b521af0751aa27eacba3d01d888588b79aa..19f72975c0ee46605ca96e4a9d07c1184a1bb6f4 100644 (file)
@@ -2007,6 +2007,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     }
 }
 
+static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
+                                          const ggml_tensor * ffn_gate,
+                                          const ggml_tensor * glu,
+                                          const ggml_tensor * ffn_up_bias = nullptr,
+                                          const ggml_tensor * ffn_gate_bias = nullptr) {
+    const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
+
+    if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
+        return false;
+    }
+
+    const bool is_mul_mat     = ffn_up->op == GGML_OP_MUL_MAT     && ffn_gate->op == GGML_OP_MUL_MAT     && glu->op == GGML_OP_GLU;
+    const bool is_mul_mat_id  = ffn_up->op == GGML_OP_MUL_MAT_ID  && ffn_gate->op == GGML_OP_MUL_MAT_ID  && glu->op == GGML_OP_GLU;
+
+    GGML_ASSERT(ffn_up && ffn_gate && glu);
+
+    if (!is_mul_mat && !is_mul_mat_id) {
+        return false;
+    }
+
+    const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
+
+    if (has_bias) {
+        if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
+            return false;
+        }
+
+        if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
+            return false;
+        }
+
+        if (expected_bias_op == GGML_OP_ADD) {
+            const bool up_has_mul   = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
+            const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
+            if (!up_has_mul || !gate_has_mul) {
+                return false;
+            }
+        } else { // GGML_OP_ADD_ID
+            if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
+                return false;
+            }
+            if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
+                return false;
+            }
+        }
+    } else {
+        if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
+            return false;
+        }
+    }
+
+    if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
+        !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
+        return false;
+    }
+
+    if (ffn_up->src[1] != ffn_gate->src[1]) {
+        return false;
+    }
+
+    if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
+        return false;
+    }
+
+    static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
+
+    if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
+        return false;
+    }
+
+    if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
+        return false;
+    }
+
+    const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
+                       ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
+
+    //TODO: add support for fusion for split buffers
+    if (split) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
+    ggml_tensor *       src0 = tensor->src[0];
+    ggml_tensor *       src1 = tensor->src[1];
+    const ggml_tensor * dst  = tensor;
+
+    const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
+
+    bool use_mul_mat_vec_f =
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
+        src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    const int cc      = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
+
+    //we only support fusion for ncols_dst = 1
+    if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
+        return false;
+    }
+
+    if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
+        return false;
+    }
+
+
+    return use_mul_mat_vec_f;
+}
+
+static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
+    ggml_tensor *       src0 = tensor->src[0];
+    ggml_tensor *       src1 = tensor->src[1];
+    const ggml_tensor * dst  = tensor;
+
+    const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
+                                   ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
+                                   src0->view_src;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
+                             dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    // fusion is not universally faster on Pascal
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    if (cc <= GGML_CUDA_CC_PASCAL) {
+        return false;
+    }
+    //we only support fusion for ncols_dst = 1
+    if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
+        return false;
+    }
+
+    if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
+        return false;
+    }
+
+    return use_mul_mat_vec_q;
+}
+
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
@@ -2745,7 +2886,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
         }
     }
 
-    if (node->op == GGML_OP_SCALE &&
+    if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
         memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
         return false;
     }
@@ -2854,6 +2995,38 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
+    std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
+
+    std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
+    std::initializer_list<enum ggml_op> mul_mat_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_MUL_MAT,    GGML_OP_GLU };
+
+    if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
+                            ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
+
+        const ggml_tensor * ffn_gate      = cgraph->nodes[node_idx];
+        const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * ffn_up        = cgraph->nodes[node_idx + 2];
+        const ggml_tensor * ffn_up_bias   = cgraph->nodes[node_idx + 3];
+        const ggml_tensor * glu           = cgraph->nodes[node_idx + 4];
+
+        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
+            return true;
+        }
+    }
+
+    if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
+                            ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
+
+        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
+        const ggml_tensor * ffn_up   = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * glu      = cgraph->nodes[node_idx + 2];
+
+        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3004,6 +3177,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                         }
                     }
 
+                    bool fused_mul_mat_vec = false;
+                    int fused_node_count = 0;
+
+                    for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
+                        const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
+
+                        if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
+                            ggml_tensor * glu         = cgraph->nodes[i + 4];
+                            ggml_tensor * gate_bias_n = glu->src[0];
+                            ggml_tensor * up_bias_n   = glu->src[1];
+
+                            //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
+                            ggml_tensor * gate_n      = nullptr;
+                            ggml_tensor * up_n        = nullptr;
+
+                            if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
+                                gate_n = cgraph->nodes[i];
+                                up_n   = cgraph->nodes[i + 2];
+                            } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
+                                gate_n = cgraph->nodes[i + 2];
+                                up_n   = cgraph->nodes[i];
+                            } else {
+                                continue;
+                            }
+
+                            auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
+                                if (op_bias == GGML_OP_ADD) {
+                                    if (bias_node->src[0] == mul_node) {
+                                        return bias_node->src[1];
+                                    }
+                                    if (bias_node->src[1] == mul_node) {
+                                        return bias_node->src[0];
+                                    }
+                                    return (ggml_tensor *) nullptr;
+                                }
+                                GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
+                                GGML_ASSERT(bias_node->src[0] == mul_node);
+                                return bias_node->src[1];
+                            };
+
+                            ggml_tensor * up_bias_tensor   = get_bias_tensor(up_bias_n, up_n, bias_op);
+                            ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
+
+                            if (!up_bias_tensor || !gate_bias_tensor) {
+                                continue;
+                            }
+
+                            const ggml_tensor * src0 = up_n->src[0];
+                            const ggml_tensor * src1 = up_n->src[1];
+                            const ggml_tensor * ids  = up_n->src[2];
+
+                            if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
+                                ggml_cuda_mm_fusion_args_host fusion_data{};
+                                fusion_data.gate      = gate_n->src[0];
+                                fusion_data.x_bias    = up_bias_tensor;
+                                fusion_data.gate_bias = gate_bias_tensor;
+                                fusion_data.glu_op    = ggml_get_glu_op(glu);
+
+                                ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                                fused_mul_mat_vec = true;
+                                fused_node_count = 5;
+                                break;
+                            }
+
+                            if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
+                                ggml_cuda_mm_fusion_args_host fusion_data{};
+                                fusion_data.gate      = gate_n->src[0];
+                                fusion_data.x_bias    = up_bias_tensor;
+                                fusion_data.gate_bias = gate_bias_tensor;
+                                fusion_data.glu_op    = ggml_get_glu_op(glu);
+
+                                ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                                fused_mul_mat_vec = true;
+                                fused_node_count = 5;
+                                break;
+                            }
+                        } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
+                            ggml_tensor * glu  = cgraph->nodes[i + 2];
+                            ggml_tensor * gate = glu->src[0];
+                            ggml_tensor * up   = glu->src[1];
+
+                            bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
+                                || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
+
+                            if (!ok) continue;
+
+                            const ggml_tensor * src0 = up->src[0];
+                            const ggml_tensor * src1 = up->src[1];
+                            const ggml_tensor * ids  = up->src[2];
+
+                            if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
+                                ggml_cuda_mm_fusion_args_host fusion_data{};
+                                fusion_data.gate   = gate->src[0];
+                                fusion_data.glu_op = ggml_get_glu_op(glu);
+
+                                ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                                fused_mul_mat_vec = true;
+                                fused_node_count = 3;
+                                break;
+                            }
+
+                            if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
+                                ggml_cuda_mm_fusion_args_host fusion_data{};
+                                fusion_data.gate   = gate->src[0];
+                                fusion_data.glu_op = ggml_get_glu_op(glu);
+
+                                ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+                                fused_mul_mat_vec = true;
+                                fused_node_count = 3;
+                                break;
+                            }
+                        }
+                    }
+
+                    if (fused_mul_mat_vec) {
+                        i += fused_node_count - 1;
+                        continue;
+                    }
+
+                    fused_mul_mat_vec = false;
+                    fused_node_count = 0;
+
+                    for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
+                        const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
+
+                        if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
+                            continue;
+                        }
+
+                        ggml_tensor * mm_node   = cgraph->nodes[i];
+                        ggml_tensor * bias_node = cgraph->nodes[i + 1];
+
+                        ggml_tensor * bias_tensor = nullptr;
+                        if (bias_op == GGML_OP_ADD) {
+                            if (bias_node->src[0] == mm_node) {
+                                bias_tensor = bias_node->src[1];
+                            } else if (bias_node->src[1] == mm_node) {
+                                bias_tensor = bias_node->src[0];
+                            } else {
+                                continue;
+                            }
+                        } else {
+                            if (bias_node->src[0] != mm_node) {
+                                continue;
+                            }
+                            bias_tensor = bias_node->src[1];
+                        }
+
+                        const ggml_tensor * src0 = mm_node->src[0];
+                        const ggml_tensor * src1 = mm_node->src[1];
+                        const ggml_tensor * ids  = mm_node->src[2];
+
+                        if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
+                            continue;
+                        }
+
+                        ggml_cuda_mm_fusion_args_host fusion_data{};
+                        fusion_data.x_bias = bias_tensor;
+
+                        if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
+                            ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
+                            fused_mul_mat_vec = true;
+                            fused_node_count = 2;
+                            break;
+                        }
+
+                        if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
+                            ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
+                            fused_mul_mat_vec = true;
+                            fused_node_count = 2;
+                            break;
+                        }
+                    }
+
+                    if (fused_mul_mat_vec) {
+                        i += fused_node_count - 1;
+                        continue;
+                    }
 
                     if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
                         ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
index 57ab839393aa00746779cc311d719ff812e2c5ac..c2c31cdaf231bc7f1a72426f88af6df4a74ccb07 100644 (file)
@@ -1,11 +1,12 @@
 #include "ggml.h"
 #include "common.cuh"
-#include "convert.cuh"
+#include "unary.cuh"
 #include "mmvf.cuh"
+#include "convert.cuh"
 
-template <typename T, typename type_acc, int ncols_dst, int block_size>
+template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
 static __global__ void mul_mat_vec_f(
-        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
         const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
         const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
     y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
     dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
 
+    bool use_gate = false;
+    bool use_bias = false;
+    bool use_gate_bias = false;
+    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
+    const T * gate_x = nullptr;
+    const float * x_bias = nullptr;
+    const float * gate_bias = nullptr;
+
+    if constexpr (has_fusion) {
+        use_gate = fusion.gate != nullptr;
+        use_bias = fusion.x_bias != nullptr;
+        use_gate_bias = fusion.gate_bias != nullptr;
+        glu_op = fusion.glu_op;
+
+        if (use_gate) {
+            gate_x = static_cast<const T *>(fusion.gate);
+        }
+        if (use_bias) {
+            x_bias = static_cast<const float *>(fusion.x_bias);
+        }
+        if (use_gate_bias) {
+            gate_bias = static_cast<const float *>(fusion.gate_bias);
+            use_gate_bias = use_gate;
+        } else {
+            use_gate_bias = false;
+        }
+    }
+
+    if (use_gate) {
+        gate_x += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    }
+    if constexpr (has_fusion) {
+        const int channel_bias = ids ? channel_x : channel_dst;
+        if (use_bias) {
+            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+        }
+        if (use_gate_bias) {
+            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+        }
+    }
+
     const float2 * y2 = (const float2 *) y;
 
     extern __shared__ char data_mmv[];
     float * buf_iw = (float *) data_mmv;
+    float * buf_iw_gate = nullptr;
+    if constexpr (has_fusion) {
+        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
+    }
 
     if (block_size > warp_size) {
         if (tid < warp_size) {
             buf_iw[tid] = 0.0f;
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    buf_iw_gate[tid] = 0.0f;
+                }
+            }
         }
         __syncthreads();
     }
 
     float sumf[ncols_dst] = {0.0f};
+    float sumf_gate[ncols_dst];
+    if constexpr (has_fusion) {
+#pragma unroll
+        for (int j = 0; j < ncols_dst; ++j) {
+            sumf_gate[j] = 0.0f;
+        }
+    }
 
     if constexpr (std::is_same_v<T, float>) {
         const float2 * x2 = (const float2 *) x;
+        const float2 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const float2 *) gate_x;
+            }
+        }
 
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const float2 tmpx = x2[col2];
+            float2 tmpx_gate = make_float2(0.0f, 0.0f);
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmpx_gate = gate_x2[col2];
+                }
+            }
 
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
                 ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                 ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+                    }
+                }
             }
         }
     } else if constexpr (std::is_same_v<T, half>) {
         const half2 * x2 = (const half2 *) x;
+        const half2 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const half2 *) gate_x;
+            }
+        }
 
         if (std::is_same_v<type_acc, float>) {
             for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
-
+                float2 tmpx_gate = make_float2(0.0f, 0.0f);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmpx_gate = __half22float2(gate_x2[col2]);
+                    }
+                }
 #pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
                     const float2 tmpy = y2[j*stride_col_y2 + col2];
                     ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                     ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+                    if constexpr (has_fusion) {
+                        if (use_gate) {
+                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+                        }
+                    }
                 }
             }
         } else {
 #ifdef FP16_AVAILABLE
             half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+            half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
 
             for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const half2 tmpx = x2[col2];
-
+                half2 tmpx_gate = make_half2(0.0f, 0.0f);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmpx_gate = gate_x2[col2];
+                    }
+                }
 #pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
                     const float2 tmpy = y2[j*stride_col_y2 + col2];
                     sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+
+                    if constexpr (has_fusion) {
+                        if (use_gate) {
+                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
+                        }
+                    }
                 }
             }
 
@@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f(
             for (int j = 0; j < ncols_dst; ++j) {
                 sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
             }
+
+            if constexpr (has_fusion) {
+                if (use_gate) {
+#pragma unroll
+                    for (int j = 0; j < ncols_dst; ++j) {
+                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
+                    }
+                }
+            }
 #else
             NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
@@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f(
 //TODO: add support for ggml_cuda_mad for hip_bfloat162
 #if defined(GGML_USE_HIP)
         const int * x2 = (const int *) x;
+        const int * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const int *) gate_x;
+            }
+        }
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const int tmpx = x2[col2];
+            int tmpx_gate = 0;
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmpx_gate = gate_x2[col2];
+                }
+            }
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
@@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f(
                 const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
                 ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
                 ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
+
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
+                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
+                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
+                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
+                    }
+                }
             }
         }
 #else
         const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
+        const nv_bfloat162 * gate_x2 = nullptr;
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                gate_x2 = (const nv_bfloat162 *) gate_x;
+            }
+        }
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const nv_bfloat162 tmpx = x2[col2];
+            nv_bfloat162 tmpx_gate;
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmpx_gate = gate_x2[col2];
+                }
+            }
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
                 ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                 ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+                    }
+                }
             }
         }
 #endif
@@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f(
     for (int j = 0; j < ncols_dst; ++j) {
         sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
 
+        if constexpr (has_fusion) {
+            if (use_gate) {
+                sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
+            }
+        }
+
         if (block_size > warp_size) {
             buf_iw[tid/warp_size] = sumf[j];
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
+                }
+            }
             __syncthreads();
             if (tid < warp_size) {
                 sumf[j] = buf_iw[tid];
                 sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        sumf_gate[j] = buf_iw_gate[tid];
+                        sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
+                    }
+                }
             }
+
             if (j < ncols_dst) {
                 __syncthreads();
             }
@@ -139,12 +313,70 @@ static __global__ void mul_mat_vec_f(
         return;
     }
 
-    dst[tid*stride_col_dst + row] = sumf[tid];
+    float value = sumf[tid];
+
+    if constexpr (has_fusion) {
+        if (use_bias) {
+            value += x_bias[tid*stride_col_dst + row];
+        }
+
+        if (use_gate) {
+            float gate_value = sumf_gate[tid];
+            if (use_gate_bias) {
+                gate_value += gate_bias[tid*stride_col_dst + row];
+            }
+            switch (glu_op) {
+                case GGML_GLU_OP_SWIGLU:
+                    value *= ggml_cuda_op_silu_single(gate_value);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    value *= ggml_cuda_op_gelu_single(gate_value);
+                    break;
+                case GGML_GLU_OP_SWIGLU_OAI: {
+                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
+                    break;
+                }
+                default:
+                    break;
+            }
+        }
+    }
+
+    dst[tid*stride_col_dst + row] = value;
+}
+
+template<typename T, typename type_acc, int ncols_dst, int block_size>
+static void mul_mat_vec_f_switch_fusion(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
+
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+    if constexpr (ncols_dst == 1) {
+        if (has_fusion) {
+            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            return;
+       }
+    }
+
+    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
+
+    mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
+        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+
 }
 
 template <typename T, typename type_acc, int ncols_dst>
-static void launch_mul_mat_vec_f_cuda(
-        const T * x, const float * y, const int32_t * ids, float * dst,
+void launch_mul_mat_vec_f_cuda(
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int64_t ncols, const int64_t nrows,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -176,57 +408,59 @@ static void launch_mul_mat_vec_f_cuda(
         }
     }
 
-    const int nbytes_shared = warp_size*sizeof(float);
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+
+    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
     const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec_f<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         case   64: {
-            mul_mat_vec_f<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         case   96: {
-            mul_mat_vec_f<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         case  128: {
-            mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         case  160: {
-            mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         case  192: {
-            mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         case  224: {
-            mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         case  256: {
-            mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -236,7 +470,7 @@ static void launch_mul_mat_vec_f_cuda(
 
 template <typename T, typename type_acc>
 static void mul_mat_vec_f_cuda_switch_ncols_dst(
-        const T * x, const float * y, const int32_t * ids, float * dst,
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -246,49 +480,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
     switch (ncols_dst) {
         case 1:
             launch_mul_mat_vec_f_cuda<T, type_acc, 1>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 2:
             launch_mul_mat_vec_f_cuda<T, type_acc, 2>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 3:
             launch_mul_mat_vec_f_cuda<T, type_acc, 3>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 4:
             launch_mul_mat_vec_f_cuda<T, type_acc, 4>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 5:
             launch_mul_mat_vec_f_cuda<T, type_acc, 5>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 6:
             launch_mul_mat_vec_f_cuda<T, type_acc, 6>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 7:
             launch_mul_mat_vec_f_cuda<T, type_acc, 7>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case 8:
             launch_mul_mat_vec_f_cuda<T, type_acc, 8>
-                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
@@ -300,29 +534,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
 
 template<typename T>
 static void mul_mat_vec_f_cuda(
-        const T * x, const float * y, const int32_t * ids, float * dst,
+        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
         const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
+
     if constexpr(std::is_same_v<T, half>) {
         if (prec == GGML_PREC_DEFAULT) {
             mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
-                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
-                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             return;
         }
     }
     mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
-        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
-         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+        nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
-void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+    const ggml_cuda_mm_fusion_args_host * fusion) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
     GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
     GGML_ASSERT(         dst->type == GGML_TYPE_F32);
@@ -348,6 +584,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
     float         *  dst_d =       (float         *)  dst->data;
 
+    ggml_cuda_mm_fusion_args_device fusion_local{};
+
+    if (fusion) {
+        GGML_ASSERT( !ids || dst->ne[2] == 1);
+        GGML_ASSERT(  ids || dst->ne[1] == 1);
+        if (fusion->x_bias) {
+            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
+            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
+            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
+            fusion_local.x_bias = fusion->x_bias->data;
+        }
+        if (fusion->gate) {
+            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
+            fusion_local.gate = fusion->gate->data;
+        }
+        if (fusion->gate_bias) {
+            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
+            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
+            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
+            fusion_local.gate_bias = fusion->gate_bias->data;
+        }
+        fusion_local.glu_op = fusion->glu_op;
+    }
+
     const int64_t s01 = src0->nb[1] / ts_src0;
     const int64_t s11 = src1->nb[1] / ts_src1;
     const int64_t s1  =  dst->nb[1] / ts_dst;
@@ -370,19 +630,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
@@ -409,7 +669,6 @@ void ggml_cuda_op_mul_mat_vec_f(
     const int cc = ggml_cuda_info().devices[id].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
-
     // ggml_cuda_op provides single, contiguous matrices
     const int64_t stride_row         = ne00;
     const int64_t stride_col_y       = ne10;
@@ -426,22 +685,23 @@ void ggml_cuda_op_mul_mat_vec_f(
     const int64_t stride_sample_y    = 0;
     const int64_t stride_sample_dst  = 0;
 
+    ggml_cuda_mm_fusion_args_device empty{};
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
index 1da460992e7841b5c698ea44d3893cd860e07db6..a205aa8e4c5380377a1f43d635563bf4a8ae116b 100644 (file)
@@ -1,6 +1,7 @@
 #include "common.cuh"
 
-void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+    const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
 
 void ggml_cuda_op_mul_mat_vec_f(
     ggml_backend_cuda_context & ctx,
index 3bf0c9ed25038785dfc16c859d3e1318f88de095..7a783e4fcf9b45cbf206e17f9d9dcc274881892e 100644 (file)
@@ -1,5 +1,6 @@
 #include "mmvq.cuh"
 #include "quantize.cuh"
+#include "unary.cuh"
 #include "vecdotq.cuh"
 
 #include <cstdint>
@@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
     return MMVQ_PARAMETERS_GENERIC;
 }
 
-static constexpr __host__ __device__ int calc_nwarps(int ncols_dst,  mmvq_parameter_table_id table_id) {
+static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC) {
         switch (ncols_dst) {
             case 1:
@@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
     return 1;
 }
 
-template <ggml_type type, int ncols_dst>
 // tell the compiler to use as many registers as it wants, see nwarps definition below
+template <ggml_type type, int ncols_dst, bool has_fusion>
 __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
-        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
@@ -169,8 +170,38 @@ static __global__ void mul_mat_vec_q(
     const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
     const uint32_t sample_y    = sample_dst;
 
+    bool use_gate = false;
+    bool use_bias = false;
+    bool use_gate_bias = false;
+    const void * vgate = nullptr;
+    const float * x_bias = nullptr;
+    const float * gate_bias = nullptr;
+    ggml_glu_op active_glu;
+
+    if constexpr (has_fusion) {
+        use_gate      = fusion.gate      != nullptr;
+        use_bias      = fusion.x_bias    != nullptr;
+        use_gate_bias = fusion.gate_bias != nullptr && use_gate;
+        vgate         = fusion.gate;
+        x_bias        = (const float *) fusion.x_bias;
+        gate_bias     = (const float *) fusion.gate_bias;
+        active_glu    = fusion.glu_op;
+    }
+
+    const uint32_t channel_bias = ids ? channel_x : channel_dst;
+
+    if constexpr (has_fusion) {
+        if (use_bias) {
+            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+        }
+        if (use_gate_bias) {
+            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+        }
+    }
+
     // partial sum for each thread
     float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
+    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
     const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
@@ -187,17 +218,35 @@ static __global__ void mul_mat_vec_q(
             for (int i = 0; i < rows_per_cuda_block; ++i) {
                 tmp[j][i] += vec_dot_q_cuda(
                     vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmp_gate[j][i] += vec_dot_q_cuda(
+                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
+                    }
+                }
             }
         }
     }
 
     __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
+    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
+    if constexpr (!has_fusion) {
+        (void) tmp_shared_gate;
+    } else if (!use_gate) {
+        (void) tmp_shared_gate;
+    }
+
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
                 tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
+                    }
+                }
             }
         }
     }
@@ -216,12 +265,49 @@ static __global__ void mul_mat_vec_q(
 #pragma unroll
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+                if constexpr (has_fusion) {
+                    if (use_gate) {
+                        tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
+                    }
+                }
             }
             tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
+            if constexpr (has_fusion) {
+                if (use_gate) {
+                    tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
+                }
+            }
         }
 
         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
-            dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
+            float result = tmp[j][threadIdx.x];
+            if constexpr (has_fusion) {
+                if (use_bias) {
+                    result += x_bias[j*stride_col_dst + threadIdx.x];
+                }
+                if (use_gate) {
+                    float gate_value = tmp_gate[j][threadIdx.x];
+                    if (use_gate_bias) {
+                        gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
+                    }
+                    switch (active_glu) {
+                        case GGML_GLU_OP_SWIGLU:
+                            result *= ggml_cuda_op_silu_single(gate_value);
+                            break;
+                        case GGML_GLU_OP_GEGLU:
+                            result *= ggml_cuda_op_gelu_single(gate_value);
+                            break;
+                        case GGML_GLU_OP_SWIGLU_OAI: {
+                            result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
+                            break;
+                        }
+                        default:
+                            result = result * gate_value;
+                            break;
+                    }
+                }
+            }
+            dst[j*stride_col_dst + threadIdx.x] = result;
         }
     }
 }
@@ -235,9 +321,37 @@ static std::pair<dim3, dim3> calc_launch_params(
     return {block_nums, block_dims};
 }
 
+template<ggml_type type, int c_ncols_dst>
+static void mul_mat_vec_q_switch_fusion(
+        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
+        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
+        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
+        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
+
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+    if constexpr (c_ncols_dst == 1) {
+        if (has_fusion) {
+            mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            return;
+        }
+    }
+
+    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
+
+    mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
+        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_switch_ncols_dst(
-        const void * vx, const void * vy, const int32_t * ids, float * dst,
+        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int ncols_x, const int nrows_x, const int ncols_dst,
         const int stride_row_x, const int stride_col_y, const int stride_col_dst,
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
@@ -256,80 +370,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
+    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+
     GGML_ASSERT(!ids || ncols_dst == 1);
     switch (ncols_dst) {
         case 1: {
             constexpr int c_ncols_dst = 1;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         case 2: {
             constexpr int c_ncols_dst = 2;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         case 3: {
             constexpr int c_ncols_dst = 3;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         case 4: {
             constexpr int c_ncols_dst = 4;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         case 5: {
             constexpr int c_ncols_dst = 5;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         case 6: {
             constexpr int c_ncols_dst = 6;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         case 7: {
             constexpr int c_ncols_dst = 7;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         case 8: {
             constexpr int c_ncols_dst = 8;
             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
-                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 dims.first, dims.second, 0, stream);
         } break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
-}
 
+    GGML_UNUSED(has_fusion);
+}
 static void mul_mat_vec_q_switch_type(
-        const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
+        const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int ncols_x, const int nrows_x, const int ncols_dst,
         const int stride_row_x, const int stride_col_y, const int stride_col_dst,
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
@@ -339,143 +456,123 @@ static void mul_mat_vec_q_switch_type(
     switch (type_x) {
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q4_1:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q5_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q5_1:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q8_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_MXFP4:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q3_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q4_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q5_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ2_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ1_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ1_M:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         case GGML_TYPE_IQ3_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
-                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -484,7 +581,8 @@ static void mul_mat_vec_q_switch_type(
 }
 
 void ggml_cuda_mul_mat_vec_q(
-        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+        const ggml_cuda_mm_fusion_args_host * fusion) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
     GGML_ASSERT(        dst->type  == GGML_TYPE_F32);
     GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
@@ -508,6 +606,31 @@ void ggml_cuda_mul_mat_vec_q(
     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
     float         *  dst_d =       (float         *)  dst->data;
 
+    ggml_cuda_mm_fusion_args_device fusion_local{};
+
+    if (fusion) {
+        GGML_ASSERT( !ids || dst->ne[2] == 1);
+        GGML_ASSERT(  ids || dst->ne[1] == 1);
+
+        if (fusion->x_bias) {
+            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
+            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
+            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
+            fusion_local.x_bias = fusion->x_bias->data;
+        }
+        if (fusion->gate) {
+            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
+            fusion_local.gate = fusion->gate->data;
+        }
+        if (fusion->gate_bias) {
+            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
+            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
+            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
+            fusion_local.gate_bias = fusion->gate_bias->data;
+        }
+        fusion_local.glu_op = fusion->glu_op;
+    }
+
     // If src0 is a temporary compute buffer, clear any potential padding.
     if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
         const size_t size_data  = ggml_nbytes(src0);
@@ -549,10 +672,10 @@ void ggml_cuda_mul_mat_vec_q(
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
     mul_mat_vec_q_switch_type(
-        src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
+        src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
         ne01,              ncols_dst,     s01, stride_col_y,     stride_col_dst,
         ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-        ne03,              ne3,           s03, s13,              s3,                 stream);
+        ne03,              ne3,           s03, s13,              s3,               stream);
 }
 
 void ggml_cuda_op_mul_mat_vec_q(
@@ -578,8 +701,9 @@ void ggml_cuda_op_mul_mat_vec_q(
     const int stride_row_x = ne00 / ggml_blck_size(src0->type);
     const int stride_col_y = src1_padded_row_size / QK8_1;
 
+    ggml_cuda_mm_fusion_args_device fusion_local{};
     mul_mat_vec_q_switch_type(
-        src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
+        src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
 
     GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
index 39dc7d33eb5ac68aadd31ada9d02639fe5a13372..4bb10cfaec2b6113e7f1c79c268fc014b66419cd 100644 (file)
@@ -3,7 +3,7 @@
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
 void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
 
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
index 3c564566a51ff8a76ba66a16d90ecba7dd606037..5f0d3a6726aefd4806ca6cf54ceb63a6c9e49ac7 100644 (file)
@@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) {
 }
 
 static __device__ __forceinline__ float op_gelu(float x) {
-    const float GELU_COEF_A    = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+    return ggml_cuda_op_gelu_single(x);
 }
 
 static __device__ __forceinline__ float op_gelu_erf(float x) {
@@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) {
 }
 
 static __device__ __forceinline__ float op_silu(float x) {
-    return x / (1.0f + expf(-x));
+    return ggml_cuda_op_silu_single(x);
 }
 
 static __device__ __forceinline__ float op_tanh(float x) {
@@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons
 
     float xi = x[j0];
     float gi = g[j1];
-    xi = fminf(xi, limit);
-    gi = fmaxf(fminf(gi, limit), -limit);
-
-    float out_glu = xi / (1.0f + expf(-xi * alpha));
-    out_glu = out_glu * (1.0f + gi);
 
-    dst[i] = out_glu;
+    dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
 }
 
 template <typename T>
index 8e7644fcd9a484492283ffaa50113b54e09d61dd..6c738cefecfd21cc0113228044792f813a576c3c 100644 (file)
@@ -1,3 +1,4 @@
+#pragma once
 #include "common.cuh"
 
 #define CUDA_NEG_BLOCK_SIZE 256
@@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
+    return x / (1.0f + expf(-x));
+}
+
+__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
+    const float GELU_COEF_A    = 0.044715f;
+    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+    return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
+}
+
+__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
+    x = fminf(x, limit);
+    g = fmaxf(fminf(g, limit), -limit);
+
+    float out_glu = x / (1.0f + expf(-x * alpha));
+    out_glu = out_glu * (1.0f + g);
+    return out_glu;
+}
index 9eb2b66879c0bf7d4b15d3993d774e5ac766863a..33ac27ff5ca00df96147686516608e8e4e5e1924 100644 (file)
@@ -4721,6 +4721,140 @@ struct test_topk_moe: public test_case {
     }
 };
 
+struct test_mul_mat_vec_fusion : public test_case {
+    const ggml_type type;
+    const ggml_glu_op glu_op;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const bool use_id;
+    const int n_mats;
+    const int n_used;
+    const bool b;        // broadcast b matrix (only for use_id)
+    const bool with_bias;
+    const bool with_gate;
+
+    test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
+                        bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
+    : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
+        if (use_id) {
+            GGML_ASSERT(n_used <= n_mats);
+        }
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "MUL_MAT_VEC_FUSION";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) {
+        ggml_tensor * out = nullptr;
+        if (with_gate) {
+            if (glu_op == GGML_GLU_OP_SWIGLU_OAI) {
+                constexpr float alpha = 1.702f;
+                constexpr float limit = 7.0f;
+                out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit);
+            } else {
+                out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
+            }
+        }
+        return out;
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        if (!use_id) {
+            std::array<int64_t, 4> ne = {k, m, 1, 1};
+            std::array<int64_t, 4> ne0 = {k, n, 1, 1};
+
+            ggml_tensor * cur  = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+            ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;
+            ggml_tensor * up   = ggml_new_tensor(ctx, type, 4, ne0.data());
+
+            ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
+            if (with_bias) {
+                std::array<int64_t, 4> bias_ne = {ffn_up->ne[0], 1, 1, 1};
+                ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
+                ffn_up = ggml_add(ctx, ffn_up, up_bias);
+            }
+
+            ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;
+            if (with_bias && with_gate) {
+                std::array<int64_t, 4> bias_ne = {ffn_gate->ne[0], 1, 1, 1};
+                ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
+                ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);
+            }
+
+            ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+            ggml_set_name(out, "out");
+            return out;
+        } else {
+            ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
+            ggml_tensor * ups   = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
+            ggml_tensor * ids   = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m);
+
+            if (n_used != n_mats) {
+                ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0);
+            }
+
+            ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
+            ggml_set_name(cur, "cur");
+
+            ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
+            if (with_bias) {
+                ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats);
+                ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids);
+            }
+
+            ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr;
+            if (with_bias && with_gate) {
+                ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats);
+                ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids);
+            }
+
+            ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+            ggml_set_name(out, "out");
+            return out;
+        }
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        if (!use_id) {
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+                init_tensor_uniform(t);
+            }
+        } else {
+            std::random_device rd;
+            std::default_random_engine rng(rd());
+            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+                if (t->type == GGML_TYPE_I32) {
+                    if (ggml_is_view_op(t->op)) { continue; }
+                    // ids
+                    for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                        std::vector<int32_t> data(t->ne[0]);
+                        for (int i = 0; i < t->ne[0]; i++) {
+                            data[i] = i % n_mats;
+                        }
+                        std::shuffle(data.begin(), data.end(), rng);
+                        ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                    }
+                } else {
+                    init_tensor_uniform(t);
+                }
+            }
+        }
+    }
+
+    double max_nmse_err() override {
+        return 5e-3;
+    }
+};
+
 // GGML_OP_SUM
 struct test_sum : public test_case {
     const ggml_type type;
@@ -6983,6 +7117,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
     test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
 
+    for (ggml_type type : base_types) {
+        for (bool with_gate : {false, true}) {
+            for (bool use_id : {false, true}) {
+                for (bool b : {false, true}) {
+                    if (!use_id && b) {
+                        continue;
+                    }
+                    for (bool with_bias : {false, true}) {
+                        if (!with_gate && !with_bias) {
+                            continue;
+                        }
+                        for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) {
+                            if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) {
+                                continue;
+                            }
+                            if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) {
+                                continue;
+                            }
+                            test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
+                                use_id, 16, 8, b, with_bias, with_gate));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     for (bool with_norm : {false, true}) {
         test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
         test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));