#include <memory>
#include <limits>
#include <map>
+#include <set>
#include <unordered_map>
#include <memory>
#include <mutex>
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
uint32_t padded_N;
};
+
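+// Bits packed into the fusion_flags push constant: BIAS0/BIAS1 add the first/second
+// fused operand (data_fuse0/data_fuse1), SCALE0/SCALE1 multiply by it.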
+#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
+#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
+#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
+#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
+
struct vk_mat_vec_push_constants {
uint32_t ncols;
uint32_t stride_a;
uint32_t batch_stride_a;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
- uint32_t enable_bias;
- uint32_t enable_scale;
+ uint32_t fusion_flags;
uint32_t ne02;
uint32_t ne12;
uint32_t broadcast2;
uint32_t nchannels_y;
uint32_t b_offset;
uint32_t d_offset;
- uint32_t enable_bias;
+ uint32_t fusion_flags;
};
struct vk_mat_vec_nc_push_constants {
uint32_t nb03;
uint32_t nb13;
uint32_t nb23;
- uint32_t enable_bias;
+ uint32_t fusion_flags;
};
struct vk_mat_vec_id_push_constants {
uint32_t batch_stride_a;
uint32_t batch_stride_b;
uint32_t batch_stride_d;
- uint32_t enable_bias;
- uint32_t enable_scale;
+ uint32_t fusion_flags;
uint32_t nei0;
uint32_t ne11;
};
const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
- static constexpr uint32_t mul_mat_vec_num_bindings = 4;
- static constexpr uint32_t mul_mat_vec_id_num_bindings = 5;
+ static constexpr uint32_t mul_mat_vec_num_bindings = 5;
+ static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
groups_x = CEIL_DIV(groups_x, groups_z);
}
- uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+ uint32_t fusion_flags = 0;
- vk_subbuffer d_B = d_D;
-
- if (enable_bias) {
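+ // Fused-operand buffers; they default to d_D so the extra bindings are always valid.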
+ vk_subbuffer d_F0 = d_D;
+ if (ctx->num_additional_fused_ops > 0) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
- d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+ d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
+ }
+
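+ // Second fused ADD (another bias); the bias may be on either operand of the add.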
+ vk_subbuffer d_F1 = d_D;
+ if (ctx->num_additional_fused_ops > 1) {
+ const ggml_tensor * add = cgraph->nodes[node_idx + 2];
+ const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? add->src[1] : add->src[0];
+
+ d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
+ stride_batch_x, stride_batch_y, stride_batch_d,
+ fusion_flags,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
d_X,
d_Y,
d_D,
- d_B,
+ d_F0,
+ d_F1,
},
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
- vk_subbuffer d_B = d_D;
+ vk_subbuffer d_F0 = d_D;
- uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+ uint32_t fusion_flags = 0;
- if (enable_bias) {
+ if (ctx->num_additional_fused_ops > 0) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
- d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+ d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
+ }
+
+ vk_subbuffer d_F1 = d_D;
+ if (ctx->num_additional_fused_ops > 1) {
+ const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
+
+ d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
vk_mat_vec_p021_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12,
- 0, 0, enable_bias
+ 0, 0, fusion_flags
};
init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
d_Qx,
d_Qy,
d_D,
- d_B,
+ d_F0,
+ d_F1,
}, pc, { 1, (uint32_t)ne01, workgroups_z });
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
- vk_subbuffer d_B = d_D;
+ vk_subbuffer d_F0 = d_D;
- uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+ uint32_t fusion_flags = 0;
- if (enable_bias) {
+ if (ctx->num_additional_fused_ops > 0) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
- d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+ d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
+ }
+
+ vk_subbuffer d_F1 = d_D;
+ if (ctx->num_additional_fused_ops > 1) {
+ const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
+
+ d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
}
// compute
row_stride_x, channel_stride_x, channel_stride_y,
(uint32_t)(ne12 / ne02), (uint32_t)ne12,
0, 0,
- nb03, nb13, nb23, enable_bias
+ nb03, nb13, nb23, fusion_flags
};
init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
d_Qx,
d_Qy,
d_D,
- d_B,
+ d_F0,
+ d_F1,
}, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
}
vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids);
- vk_subbuffer d_B = d_D;
+ vk_subbuffer d_F0 = d_D;
vk_subbuffer d_X, d_Y;
if (qx_needs_dequant) {
groups_x = CEIL_DIV(groups_x, groups_z);
}
- uint32_t enable_bias = 0;
- uint32_t enable_scale = 0;
+ uint32_t fusion_flags = 0;
+
if (ctx->num_additional_fused_ops > 0) {
+ const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
+
+ d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+
if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
- enable_scale = 1;
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0;
} else {
GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
- enable_bias = 1;
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
}
}
- if (enable_bias || enable_scale) {
- const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
+ vk_subbuffer d_F1 = d_D;
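+ // Optional trailing GGML_OP_MUL applies a per-expert scale after the fused bias.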
+ if (ctx->num_additional_fused_ops > 1) {
+ const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1];
- d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+ d_F1 = ggml_vk_tensor_subbuffer(ctx, scale);
+ fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
}
// compute
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
-
- enable_bias, enable_scale,
-
+ fusion_flags,
(uint32_t)nei0, (uint32_t)ne11,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
d_X,
d_Y,
d_D,
- d_B,
+ d_F0,
+ d_F1,
d_ids,
},
pc, { groups_x, (uint32_t)nei0, groups_z });
return false;
}
}
- if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
- // additional constraints specific to this fusion
- const ggml_tensor *mul = cgraph->nodes[node_idx];
- const ggml_tensor *add = cgraph->nodes[node_idx + 1];
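+ // Common constraints for fusing an ADD (bias) onto a mat-vec result.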
+ auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) {
const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];
// mat-vec only
if (get_misalign_bytes(ctx, bias) != 0) {
return false;
}
- }
- if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
+ return true;
+ };
+
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
// additional constraints specific to this fusion
const ggml_tensor *mul = cgraph->nodes[node_idx];
const ggml_tensor *add = cgraph->nodes[node_idx + 1];
- const ggml_tensor *bias = add->src[1];
- if (mul != add->src[0]) {
+ if (!mm_add_ok(mul, add)) {
+ return false;
+ }
+ if (ops.size() == 3) {
+ if (ops.begin()[2] != GGML_OP_ADD) {
+ return false;
+ }
+ if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) {
+ return false;
+ }
+ }
+ }
+
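+ // Common constraints for fusing a per-expert scale (GGML_OP_MUL) onto a mul_mat_id result.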
+ auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) {
+ const ggml_tensor *scale = mul->src[1];
+
+ if (mmid != mul->src[0]) {
return false;
}
// mat-vec only
return false;
}
// shaders assume the types match
- if (mul->type != bias->type) {
+ if (mmid->type != scale->type) {
return false;
}
- // shaders assume the bias is contiguous
+ // shaders assume the scale is contiguous
- if (!ggml_is_contiguous(bias)) {
+ if (!ggml_is_contiguous(scale)) {
return false;
}
- // the ID tensor must be the same for mul_mat_id and add_id
- if (mul->src[2] != add->src[2]) {
+ // unaligned scale isn't handled
+ if (get_misalign_bytes(ctx, scale) != 0) {
return false;
}
- // unaligned bias isn't handled
- if (get_misalign_bytes(ctx, bias) != 0) {
+ // shader only indexes by expert index
+ if (scale->ne[0] != 1 ||
+ scale->ne[1] != mul->ne[1] ||
+ scale->ne[2] != 1 ||
+ scale->ne[3] != 1) {
return false;
}
- }
+ return true;
+ };
- if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
// additional constraints specific to this fusion
- const ggml_tensor *mmid = cgraph->nodes[node_idx];
- const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
- const ggml_tensor *scale = mul->src[1];
+ const ggml_tensor *mul = cgraph->nodes[node_idx];
+ const ggml_tensor *add = cgraph->nodes[node_idx + 1];
+ const ggml_tensor *bias = add->src[1];
- if (mmid != mul->src[0]) {
+ if (mul != add->src[0]) {
return false;
}
// mat-vec only
return false;
}
// shaders assume the types match
- if (mmid->type != scale->type) {
+ if (mul->type != bias->type) {
return false;
}
// shaders assume the bias is contiguous
- if (!ggml_is_contiguous(scale)) {
+ if (!ggml_is_contiguous(bias)) {
+ return false;
+ }
+ // the ID tensor must be the same for mul_mat_id and add_id
+ if (mul->src[2] != add->src[2]) {
return false;
}
// unaligned bias isn't handled
- if (get_misalign_bytes(ctx, scale) != 0) {
+ if (get_misalign_bytes(ctx, bias) != 0) {
return false;
}
- // shader only indexes by expert index
- if (scale->ne[0] != 1 ||
- scale->ne[1] != mul->ne[1] ||
- scale->ne[2] != 1 ||
- scale->ne[3] != 1) {
+
+ if (ops.size() == 3) {
+ if (ops.begin()[2] != GGML_OP_MUL) {
+ return false;
+ }
+ const ggml_tensor *mul = cgraph->nodes[node_idx + 2];
+ return mmid_mul_ok(add, mul);
+ }
+ }
+
+ if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+ // additional constraints specific to this fusion
+ const ggml_tensor *mmid = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+
+ if (!mmid_mul_ok(mmid, mul)) {
return false;
}
}
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
if (num_adds) {
ctx->num_additional_fused_ops = num_adds - 1;
+ } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
+ ctx->num_additional_fused_ops = 2;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
ctx->num_additional_fused_ops = 1;
+ } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
+ ctx->num_additional_fused_ops = 2;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
std::vector<ggml_tensor *> new_order;
std::vector<bool> used(graph->n_nodes, false);
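+ // Nodes already pushed to new_order; fused follow-up ops require their extra operand to be weights or an already-processed node.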
+ std::set<ggml_tensor *> used_node_set;
+
int first_unused = 0;
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
if (match_pattern(pattern, first_unused)) {
for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
+ used_node_set.insert(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
used[set_rows_idx] = true;
}
}
+ // Look for MUL_MAT_ID + ADD_ID + MUL
+ if (j > 0 &&
+ graph->nodes[j]->op == GGML_OP_ADD_ID &&
+ graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {
+ for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+ if (graph->nodes[k]->op == GGML_OP_MUL &&
+ graph->nodes[k]->src[0] == graph->nodes[j] &&
+ // src1 must either be weights or already processed
+ (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
+ current_set.push_back(k);
+ used[k] = true;
+ break;
+ }
+ }
+ }
+ // Look for MUL_MAT + ADD + ADD
+ if (j > 0 &&
+ graph->nodes[j]->op == GGML_OP_ADD &&
+ graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {
+ for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+ if (graph->nodes[k]->op == GGML_OP_ADD &&
+ graph->nodes[k]->src[0] == graph->nodes[j] &&
+ // src1 must either be weights or already processed
+ (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
+ current_set.push_back(k);
+ used[k] = true;
+ break;
+ }
+ }
+ }
}
}
// Second pass grabs view nodes.
// Push the current set into new_order
for (auto c : current_set) {
new_order.push_back(graph->nodes[c]);
+ used_node_set.insert(graph->nodes[c]);
used[c] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
#define EXPERT_COUNT 8
#endif
-#include "types.glsl"
-
-#ifndef MMQ
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-#else
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
-#endif
-
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-#ifdef B_TYPE_VEC2
-layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
-#endif
-#ifdef B_TYPE_VEC4
-layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
-#endif
-
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
-
-#ifdef MUL_MAT_ID
-layout (binding = 4) readonly buffer IDS {int data_ids[];};
-#endif
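+// The shared interface header declares the A/B/D buffers, the fused-operand buffers
+// data_fuse0/data_fuse1, and (for MUL_MAT_ID) the ids buffer.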
+#include "mul_mat_vec_iface.glsl"
#include "dequant_funcs.glsl"
uint batch_stride_b;
uint batch_stride_d;
- uint enable_bias;
- uint enable_scale;
+ uint fusion_flags;
#ifdef MUL_MAT_ID
uint nei0;
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
- if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
- temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
-#else
- temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
-#endif
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+ temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
-#ifdef MUL_MAT_ID
- if (p.enable_scale != 0) {
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
- temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+ temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+ }
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
+ const uint expert_idx = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+ }
+#else
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+ temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
+ }
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+ temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
temp[j][n] += tmpsh[j][n][s];
}
- if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
- temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
-#else
- temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
-#endif
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+ temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
-#ifdef MUL_MAT_ID
- if (p.enable_scale != 0) {
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
- temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+ temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+ }
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
+ const uint expert_idx = gl_GlobalInvocationID.y;
+ temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+ }
+#else
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+ temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
+ }
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+ temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
- if (p.enable_bias != 0) {
#ifdef MUL_MAT_ID
- tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
-#else
- tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
-#endif
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+ tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
-#ifdef MUL_MAT_ID
- if (p.enable_scale != 0) {
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
+ const uint expert_idx = gl_GlobalInvocationID.y;
+ tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+ }
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
- tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
+ tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+ }
+#else
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+ tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
+ }
+ if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+ tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
}
#endif
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
const bool b; // broadcast b matrix (only for use_id)
const bool with_bias;
const bool with_gate;
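+ // {channels, samples} used by the batched (non-id) test graph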
+ std::array<int64_t, 2> batch_dims;
test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
- bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
- : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
+ bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,
+ std::array<int64_t, 2> batch_dims = {4, 2})
+ : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {
if (use_id) {
GGML_ASSERT(n_used <= n_mats);
}
}
std::string vars() override {
- return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
+ return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);
}
std::string op_desc(ggml_tensor * t) override {
ggml_tensor * build_graph(ggml_context * ctx) override {
if (!use_id) {
- const int channels = 4;
- const int samples = 2;
+ const int channels = batch_dims[0];
+ const int samples = batch_dims[1];
std::array<int64_t, 4> ne = { k, m, channels, samples };
std::array<int64_t, 4> ne0 = { k, n, channels, samples };
}
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+
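+ // second bias on the output; with a fused bias and no gate this forms the mul_mat + add + add chain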
+ std::array<int64_t, 4> bias2_ne = { out->ne[0], 1, channels, samples };
+ ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());
+ out = ggml_add(ctx, out, bias2);
+
ggml_set_name(out, "out");
return out;
} else {
}
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+
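+ // trailing scale with ne0 == 1 on the output, covering the fused mul_mat_id scale epilogue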
+ std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };
+ ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());
+ out = ggml_mul(ctx, out, scale);
+
ggml_set_name(out, "out");
return out;
}
}
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
use_id, 16, 8, b, with_bias, with_gate));
+ test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
+ use_id, 16, 8, b, with_bias, with_gate, {1, 1}));
}
}
}