From: Jeff Bolz Date: Sun, 9 Nov 2025 08:48:42 +0000 (-0600) Subject: vulkan: fuse mul_mat_id + mul (llama/17095) X-Git-Tag: upstream/1.8.3~344 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=db98e8c5b404d8b5f64a00efdd5310d064ec7b65;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp vulkan: fuse mul_mat_id + mul (llama/17095) * vulkan: fuse mul_mat_id + mul This comes up in qwen3 moe. * split mul_mat_id fusion tests into a separate class --- diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6da7bbd2..054e8cbd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -830,6 +830,7 @@ struct vk_mat_vec_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t enable_bias; + uint32_t enable_scale; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; @@ -852,6 +853,7 @@ struct vk_mat_vec_id_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t enable_bias; + uint32_t enable_scale; uint32_t nei0; uint32_t ne11; }; @@ -6863,7 +6865,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& // compute const vk_mat_vec_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, + stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, @@ -7684,13 +7686,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte groups_x = CEIL_DIV(groups_x, groups_z); } - uint32_t enable_bias = ctx->num_additional_fused_ops > 0; + uint32_t enable_bias = 0; + uint32_t enable_scale = 0; + if (ctx->num_additional_fused_ops > 0) { + if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) { + enable_scale = 1; + } else { + GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID); + enable_bias = 1; + } + } vk_buffer d_B = d_D; size_t b_buf_offset = 0; uint64_t b_sz = 0; - if (enable_bias) { + if (enable_bias || enable_scale) { const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1]; bool b_uma = false; @@ -7712,7 +7723,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), - enable_bias, + enable_bias, enable_scale, (uint32_t)nei0, (uint32_t)ne11, }; @@ -12490,6 +12501,40 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g } } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *mmid = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + const ggml_tensor *scale = mul->src[1]; + + if (mmid != mul->src[0]) { + return false; + } + // mat-vec only + if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) { + return false; + } + // shaders assume the types match + if (mmid->type != scale->type) { + return false; + } + // shaders assume the bias is contiguous + if (!ggml_is_contiguous(scale)) { + return false; + } + // unaligned bias isn't handled + if (get_misalign_bytes(ctx, scale) != 0) { + return false; + } + // shader only indexes by expert index + if (scale->ne[0] != 1 || + scale->ne[1] != mul->ne[1] || + scale->ne[2] != 1 || + scale->ne[3] != 1) { + return false; + } + } + return true; } @@ -12798,6 +12843,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = 1; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; + } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && @@ -13033,7 +13080,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * is_src_of(graph->nodes[j], graph->nodes[c]) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) { ok = false; break; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index bbb4d120..eb8fa6dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -49,6 +49,7 @@ layout (push_constant) uniform parameter uint batch_stride_d; uint enable_bias; + uint enable_scale; #ifdef MUL_MAT_ID uint nei0; @@ -129,6 +130,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); #endif } +#ifdef MUL_MAT_ID + if (p.enable_scale != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); + } +#endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); } } @@ -171,6 +178,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); #endif } +#ifdef MUL_MAT_ID + if (p.enable_scale != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); + } +#endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); } } @@ -203,6 +216,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]); #endif } +#ifdef MUL_MAT_ID + if (p.enable_scale != 0) { + const uint expert_idx = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]); + } +#endif data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); } }