GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
+
+// Fused op pattern for sigmoid-gated MoE routing with an expert-selection bias
+// (UNARY must be GGML_UNARY_OP_SIGMOID; the ADD applies a 1-D bias tensor before
+// ARGSORT selects experts) followed by weight normalization (SUM_ROWS/CLAMP/DIV).
+// Node order here must stay in sync with topk_moe_sigmoid_norm_bias_edges.
+static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
+                                                                            GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+                                                                            GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                                            GGML_OP_DIV, GGML_OP_RESHAPE };
+
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
+
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
{ 9, 0, 8 }, // reshape->src[0] == div
};
+//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
+//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
+//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
+//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
+//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
+//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
+//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
+//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
+//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
+//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
+//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
+// Required producer/consumer edges for the sigmoid+bias pattern. Each entry is
+// { node, src_slot, required_src_node }, with node indices counted from the start
+// of the fused subgraph (0 = sigmoid ... 10 = final reshape) — see the graph
+// trace above. Note the ADD (node 2) consumes the raw sigmoid output (node 0),
+// while GET_ROWS (node 5) reads the reshaped probabilities (node 1), so the
+// selection is biased but the returned weights are not.
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
+    { 1, 0, 0 }, // reshape->src[0] == sigmoid
+    { 2, 0, 0 }, // add->src[0] == sigmoid
+    { 3, 0, 2 }, // argsort->src[0] == add
+    { 4, 0, 3 }, // view->src[0] == argsort
+    { 5, 0, 1 }, // get_rows->src[0] == reshape
+    { 5, 1, 4 }, // get_rows->src[1] == view
+    { 6, 0, 5 }, // reshape->src[0] == get_rows
+    { 7, 0, 6 }, // sum_rows->src[0] == reshape
+    { 8, 0, 7 }, // clamp->src[0] == sum_rows
+    { 9, 0, 6 }, // div->src[0] == reshape
+    { 9, 1, 8 }, // div->src[1] == clamp
+    {10, 0, 9 }, // reshape->src[0] == div
+};
+
// same as early_softmax_norm but ending after the get_rows
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax
TOPK_MOE_EARLY_SOFTMAX,
TOPK_MOE_EARLY_SOFTMAX_NORM,
TOPK_MOE_LATE_SOFTMAX,
+ TOPK_MOE_SIGMOID_NORM_BIAS,
TOPK_MOE_COUNT,
};
-static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
- topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
- num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
- TOPK_MOE_LATE_SOFTMAX;
- return mode;
-}
-
static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
{ 1, 0, 0 }, // view->src[0] == rope
{ 2, 0, 1 }, // set_rows->src[0] == view
vk_pipeline pipeline_count_experts;
// [2] is for whether to take n_experts from spec constant (0) or push constant (1)
- vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
std::vector<vk_pipeline_ref> all_pipelines;
uint32_t n_expert_used;
float clamp_min;
float clamp_max;
+ uint32_t gating_func;
+ uint32_t has_bias;
+ uint32_t with_norm;
+ float output_scale;
+ float output_bias;
};
struct vk_op_add_id_push_constants {
// Bit 'i' means nodes[start_of_fusion + i] writes to memory.
// If there's no fusion, bit 0 is still set.
int fused_ops_write_mask {};
+ topk_moe_mode fused_topk_moe_mode {};
+ bool fused_topk_moe_scale {};
// for GGML_VK_PERF_LOGGER
std::unique_ptr<vk_perf_logger> perf_logger;
for (uint32_t use_push = 0; use_push < 2; ++use_push) {
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
- ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
}
}
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
- topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
// use n_experts from push constant if it's not equal to the power of two spec constant
bool use_push = dst->ne[0] != (1u << idx);
- return ctx->device->pipeline_topk_moe[idx][mode][use_push];
+ return ctx->device->pipeline_topk_moe[idx][use_push];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
- topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+ topk_moe_mode mode = ctx->fused_topk_moe_mode;
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
- ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
- (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
- cgraph->nodes[node_idx + 5];
- ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
+ ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
+ ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+ ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
+ (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
+ cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
+ GGML_ASSERT(bias->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
+ vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
pc.n_rows = n_rows;
pc.n_experts_push = n_experts;
pc.n_expert_used = n_expert_used;
+ pc.clamp_min = -std::numeric_limits<float>::infinity();
+ pc.clamp_max = std::numeric_limits<float>::infinity();
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+ GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
+ pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+ pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+ }
+ if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
+ ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
+ GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
}
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
+ pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
+ mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
+ GATING_FUNC_SOFTMAX;
+ pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+ pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+ if (ctx->fused_topk_moe_scale) {
+ GGML_ASSERT(weights->op == GGML_OP_SCALE);
+ pc.output_scale = ggml_get_op_params_f32(weights, 0);
+ pc.output_bias = ggml_get_op_params_f32(weights, 1);
+ } else {
+ pc.output_scale = 1.0f;
+ pc.output_bias = 0.0f;
+ }
+
GGML_ASSERT(n_expert_used <= n_experts);
const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
}
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
break;
case GGML_OP_UNARY:
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
+ break;
+ }
+
switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SILU:
break;
case GGML_OP_SOFT_MAX:
- if (ctx->num_additional_fused_ops) {
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
} else {
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
break;
case GGML_OP_ARGSORT:
- if (ctx->num_additional_fused_ops) {
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
} else {
ggml_vk_argsort(ctx, compute_ctx, src0, node);
get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2];
break;
+ case TOPK_MOE_SIGMOID_NORM_BIAS:
+ softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
+ weights = cgraph->nodes[node_idx + 10];
+ get_rows = cgraph->nodes[node_idx + 5];
+ argsort = cgraph->nodes[node_idx + 3];
+ if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
+ return false;
+ }
+ // bias is expected to be 1D
+ if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
+ !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
+ return false;
+ }
+ // sigmoid fusion seems to generate infinities on moltenvk
+ if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
+ return false;
+ }
+ break;
case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4];
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
- if (probs != selection_probs) {
+ if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
return false;
}
- const float * op_params = (const float *)softmax->op_params;
-
- float scale = op_params[0];
- float max_bias = op_params[1];
-
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false;
}
- if (scale != 1.0f || max_bias != 0.0f) {
- return false;
- }
+ if (softmax->op == GGML_OP_SOFT_MAX) {
+ const float * op_params = (const float *)softmax->op_params;
- // don't fuse when masks or sinks are present
- if (softmax->src[1] || softmax->src[2]) {
- return false;
+ float scale = op_params[0];
+ float max_bias = op_params[1];
+
+ if (scale != 1.0f || max_bias != 0.0f) {
+ return false;
+ }
+
+ // don't fuse when masks or sinks are present
+ if (softmax->src[1] || softmax->src[2]) {
+ return false;
+ }
}
const int n_expert = softmax->ne[0];
total_mul_mat_bytes += bytes;
}
+ ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+ ctx->fused_topk_moe_scale = false;
const char *fusion_string {};
if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3;
+ ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
+ } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
+ ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
+ ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
+ ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
+ // view of argsort writes to memory
+ ctx->fused_ops_write_mask |= 1 << 4;
+ ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
+ fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3;
+ ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 1;
+ ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
fusion_string = "TOPK_MOE_LATE_SOFTMAX";
}
+ if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+ // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
+ if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
+ ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
+ ctx->fused_topk_moe_scale = true;
+ ctx->num_additional_fused_ops++;
+ }
+ }
}
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
if (keep_pattern(topk_moe_early_softmax_norm)) {
continue;
}
+ if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
+ continue;
+ }
if (keep_pattern(topk_moe_early_softmax)) {
continue;
}
}
// Don't pull forward nodes from fusion patterns
if (match_pattern(topk_moe_early_softmax_norm, j) ||
+ match_pattern(topk_moe_sigmoid_norm_bias, j) ||
match_pattern(topk_moe_early_softmax, j) ||
match_pattern(topk_moe_late_softmax, j)) {
continue;
#include "types.glsl"
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
float clamp_min;
float clamp_max;
+ uint gating_func;
+ uint has_bias;
+ uint with_norm;
+ float output_scale;
+ float output_bias;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts_spec = 512;
-layout(constant_id = 2) const bool with_norm = true;
-layout(constant_id = 3) const bool late_softmax = false;
-layout(constant_id = 4) const bool nexperts_use_push = false;
+layout(constant_id = 2) const bool nexperts_use_push = false;
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
-layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
-layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
+layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
+layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
+layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
const float INFINITY = 1.0 / 0.0;
}
const uint logits_offset = n_experts * row;
+ const uint bias_offset = 0; // 1D
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
const uint lane = gl_SubgroupInvocationID;
- float wt[experts_per_thread];
+ float probs[experts_per_thread];
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
- wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+ }
+
+ if (gating_func == GATING_FUNC_SOFTMAX) {
+ softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
+ } else if (gating_func == GATING_FUNC_SIGMOID) {
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ probs[i] = 1.f / (1.f + exp(-probs[i]));
+ }
}
- if (!late_softmax) {
- softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
+ float selection_probs[experts_per_thread];
+ if (has_bias != 0) {
+ [[unroll]]
+ for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+ const uint expert = i + lane;
+ selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
+ }
+ } else {
+ [[unroll]]
+ for (int i = 0; i < experts_per_thread; i++) {
+ selection_probs[i] = probs[i];
+ }
}
// at this point, each thread holds a portion of softmax,
}
for (int k = 0; k < n_expert_used; k++) {
- float max_val = wt[0];
+ float max_val = probs[0];
+ float max_val_s = selection_probs[0];
uint max_expert = lane;
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const uint expert = lane + i * WARP_SIZE;
- if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
- max_val = wt[i];
+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
+ max_val = probs[i];
+ max_val_s = selection_probs[i];
max_expert = expert;
}
}
[[unroll]]
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = subgroupShuffleXor(max_val, mask);
+ const float val_s = subgroupShuffleXor(max_val_s, mask);
const uint expert = subgroupShuffleXor(max_expert, mask);
- if (val > max_val || (val == max_val && expert < max_expert)) {
+ if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
max_val = val;
+ max_val_s = val_s;
max_expert = expert;
}
}
}
if ((max_expert & (WARP_SIZE - 1)) == lane) {
- wt[max_expert / WARP_SIZE] = -INFINITY;
+ selection_probs[max_expert / WARP_SIZE] = -INFINITY;
ids[ids_offset + k] = max_expert;
- if (with_norm) {
- wt_sum += max_val;
- }
+ wt_sum += max_val;
}
}
- if (with_norm) {
+ if (with_norm != 0) {
wt_sum = subgroupAdd(wt_sum);
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
}
}
- if (late_softmax) {
+ if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
softmax_warp_inplace(output_weights, n_expert_used, lane, true);
}
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + lane;
if (idx < n_expert_used) {
- weights[weights_offset + idx] = output_weights[i];
+ weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
}
}
}