uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t enable_bias;
+ uint32_t enable_scale;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
uint32_t enable_bias;
+ uint32_t enable_scale;
uint32_t nei0;
uint32_t ne11;
};
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
+ stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
groups_x = CEIL_DIV(groups_x, groups_z);
}
- uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+ uint32_t enable_bias = 0;
+ uint32_t enable_scale = 0;
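+ // decide which fused epilogue follows the mat-vec: GGML_OP_MUL applies a
+ // per-row scale, GGML_OP_ADD_ID adds a per-expert bias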
+ if (ctx->num_additional_fused_ops > 0) {
+ if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
+ enable_scale = 1;
+ } else {
+ GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
+ enable_bias = 1;
+ }
+ }
vk_buffer d_B = d_D;
size_t b_buf_offset = 0;
uint64_t b_sz = 0;
- if (enable_bias) {
+ if (enable_bias || enable_scale) {
const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
bool b_uma = false;
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
- enable_bias,
+ enable_bias, enable_scale,
(uint32_t)nei0, (uint32_t)ne11,
};
}
}
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+ // additional constraints specific to this fusion
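+ // the target pattern is the MoE expert-weighting step (ffn_moe_weighted),
+ // roughly (tensor names illustrative):
+ //   cur = ggml_mul_mat_id(ctx, experts, x, ids);
+ //   cur = ggml_mul(ctx, cur, weights); // weights->ne = {1, ne1, 1, 1}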
+ const ggml_tensor *mmid = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+ const ggml_tensor *scale = mul->src[1];
+
+ if (mmid != mul->src[0]) {
+ return false;
+ }
+ // mat-vec only
+ if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
+ return false;
+ }
+ // shaders assume the types match
+ if (mmid->type != scale->type) {
+ return false;
+ }
+ // shaders assume the scale is contiguous
+ if (!ggml_is_contiguous(scale)) {
+ return false;
+ }
+ // unaligned scale isn't handled
+ if (get_misalign_bytes(ctx, scale) != 0) {
+ return false;
+ }
+ // shader only indexes by expert index
+ if (scale->ne[0] != 1 ||
+ scale->ne[1] != mul->ne[1] ||
+ scale->ne[2] != 1 ||
+ scale->ne[3] != 1) {
+ return false;
+ }
+ }
+
return true;
}
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
ctx->num_additional_fused_ops = 1;
+ } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
+ ctx->num_additional_fused_ops = 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
is_src_of(graph->nodes[j], graph->nodes[c]) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
- !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
ok = false;
break;
}
}
};
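+// Fill every I32 tensor with shuffled expert ids; all other tensors get uniform random data.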
+static void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
+ // ids
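+ // when the row length equals n_mats this is a shuffled permutation of
+ // 0..n_mats-1, so a view of the first n_used columns selects distinct experts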
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<int32_t> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i % n_mats;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+ }
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+}
+
// GGML_OP_MUL_MAT_ID
struct test_mul_mat_id : public test_case {
const ggml_type type_a;
const int64_t m;
const int64_t n;
const int64_t k;
- const uint32_t o; // number of outputs
std::string vars() override {
- return VARS_TO_STR9(type_a, type_b, n_mats, n_used, b, m, n, k, o);
+ return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
}
double max_nmse_err() override {
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 8, int n_used = 2, bool b = false,
- int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1)
+ int64_t m = 32, int64_t n = 32, int64_t k = 32)
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
- m(m), n(n), k(k), o(o) {
+ m(m), n(n), k(k) {
+ GGML_ASSERT(n_used <= n_mats);
+ }
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+ ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
+ ggml_set_name(as, "as");
+
+ ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+ ggml_set_name(ids, "ids");
+ if (n_used != n_mats) {
+ ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
+ ggml_set_name(ids, "view_of_ids");
+ }
+
+ ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
+ ggml_set_name(b, "b");
+
+ ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ init_mul_mat_id_tensors(ctx, n_mats);
+ }
+};
+
+// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL
+struct test_mul_mat_id_fusion : public test_case {
+ const ggml_type type_a;
+ const ggml_type type_b;
+ const int n_mats;
+ const int n_used;
+ const bool b; // broadcast b matrix
+ const int64_t m;
+ const int64_t n;
+ const int64_t k;
+ const uint32_t o; // number of outputs
+ const bool mul;
+
+ std::string vars() override {
+ return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul);
+ }
+
+ double max_nmse_err() override {
+ return 5e-4;
+ }
+
+ uint64_t op_flops(ggml_tensor * t) override {
+ GGML_UNUSED(t);
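+ // count only the matmul FLOPs; the fused add/mul epilogue is negligible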
+ return 2 * m * k * n * n_used;
+ }
+
+ test_mul_mat_id_fusion(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+ int n_mats = 8, int n_used = 2, bool b = false,
+ int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false)
+ : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
+ m(m), n(n), k(k), o(o), mul(mul) {
GGML_ASSERT(n_used <= n_mats);
}
out = ggml_add(ctx, out, out2);
}
+ if (mul) {
+ // one scale value per output row, broadcast across the row dimension
+ std::array<int64_t, 4> ne { 1, out->ne[1], out->ne[2], out->ne[3] };
+ ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, ne.data());
+ out = ggml_mul(ctx, out, scale);
+ }
+
return out;
}
void initialize_tensors(ggml_context * ctx) override {
- for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- if (t->type == GGML_TYPE_I32) {
- if (ggml_is_view_op(t->op)) { continue; }
- std::random_device rd;
- std::default_random_engine rng(rd());
- // ids
- for (int64_t r = 0; r < ggml_nrows(t); r++) {
- std::vector<int32_t> data(t->ne[0]);
- for (int i = 0; i < t->ne[0]; i++) {
- data[i] = i % n_mats;
- }
- std::shuffle(data.begin(), data.end(), rng);
- ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
- }
- } else {
- init_tensor_uniform(t);
- }
- }
+ init_mul_mat_id_tensors(ctx, n_mats);
}
- bool run_whole_graph() override { return o > 1; }
+ bool run_whole_graph() override { return true; }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
- return ggml_op_name(GGML_OP_MUL_MAT_ID);
+ return "MUL_MAT_ID_FUSION";
}
};
init_tensor_uniform(t);
}
} else {
- std::random_device rd;
- std::default_random_engine rng(rd());
- for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
- if (t->type == GGML_TYPE_I32) {
- if (ggml_is_view_op(t->op)) { continue; }
- // ids
- for (int64_t r = 0; r < ggml_nrows(t); r++) {
- std::vector<int32_t> data(t->ne[0]);
- for (int i = 0; i < t->ne[0]; i++) {
- data[i] = i % n_mats;
- }
- std::shuffle(data.begin(), data.end(), rng);
- ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
- }
- } else {
- init_tensor_uniform(t);
- }
- }
+ init_mul_mat_id_tensors(ctx, n_mats);
}
}
}
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
- test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
+ test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
// gpt-oss issue with Vulkan mmq_id
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
}
}
+ for (int bs : {1, 4, 512}) {
+ for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) {
+ for (ggml_type type_b : {GGML_TYPE_F32}) {
+ // test with mul after (ffn_moe_weighted)
+ test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true));
+ }
+ }
+ }
+
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (int n : {1, 16}) {
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
- test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
+ test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
}
}
}
for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
- test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
+ test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
}
}
}
for (int bs : {1, 4, 8, 512}) {
for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
- test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
+ test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
}
}
}