uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t enable_bias;
+ uint32_t enable_scale;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t enable_bias;
+ uint32_t enable_scale;
uint32_t nei0;
uint32_t ne11;
};
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
+ stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
groups_x = CEIL_DIV(groups_x, groups_z);
}
- uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+ uint32_t enable_bias = 0;
+ uint32_t enable_scale = 0;
+ if (ctx->num_additional_fused_ops > 0) { // the op at node_idx + 1 decides which single flag gets set
+ if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
+ enable_scale = 1;
+ } else {
+ GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID); // anything other than MUL must be ADD_ID
+ enable_bias = 1;
+ }
+ }
vk_buffer d_B = d_D;
size_t b_buf_offset = 0;
uint64_t b_sz = 0;
- if (enable_bias) {
+ if (enable_bias || enable_scale) {
const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
bool b_uma = false;
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
- enable_bias,
+ enable_bias, enable_scale,
(uint32_t)nei0, (uint32_t)ne11,
};
}
}
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+ // additional constraints specific to the MUL_MAT_ID + MUL (scale) fusion
+ const ggml_tensor *mmid = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+ const ggml_tensor *scale = mul->src[1];
+
+ if (mmid != mul->src[0]) { // the MUL must take the MUL_MAT_ID result as its first operand
+ return false;
+ }
+ // mat-vec only
+ if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
+ return false;
+ }
+ // shaders assume the types match
+ if (mmid->type != scale->type) {
+ return false;
+ }
+ // shaders assume the scale is contiguous
+ if (!ggml_is_contiguous(scale)) {
+ return false;
+ }
+ // unaligned scale isn't handled
+ if (get_misalign_bytes(ctx, scale) != 0) {
+ return false;
+ }
+ // shader only indexes by expert index, so scale must be shaped [1, mul->ne[1], 1, 1]
+ if (scale->ne[0] != 1 ||
+ scale->ne[1] != mul->ne[1] ||
+ scale->ne[2] != 1 ||
+ scale->ne[3] != 1) {
+ return false;
+ }
+ }
+
return true;
}
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
ctx->num_additional_fused_ops = 1;
+ } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
+ ctx->num_additional_fused_ops = 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
is_src_of(graph->nodes[j], graph->nodes[c]) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
- !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
ok = false;
break;
}
uint batch_stride_d;
uint enable_bias;
+ uint enable_scale;
#ifdef MUL_MAT_ID
uint nei0;
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
}
+#ifdef MUL_MAT_ID
+ if (p.enable_scale != 0) {
+ const uint expert_idx = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); // scale factor is read from data_bias, indexed by expert_idx alone
+ }
+#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
}
+#ifdef MUL_MAT_ID
+ if (p.enable_scale != 0) {
+ const uint expert_idx = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]); // scale factor is read from data_bias, indexed by expert_idx alone
+ }
+#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
#endif
}
+#ifdef MUL_MAT_ID
+ if (p.enable_scale != 0) {
+ const uint expert_idx = gl_GlobalInvocationID.y;
+ tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]); // scale factor is read from data_bias, indexed by expert_idx alone
+ }
+#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
}
}