#endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes
- std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
- std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+ std::initializer_list<enum ggml_op> topk_moe_ops =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/false);
+ std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
+ ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
+ std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
+ ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
}
}
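+    // delayed-softmax variant: top-k runs on the raw logits and softmax is applied only over
+    // the selected experts. Fused nodes (see ggml_cuda_topk_moe_ops) are node_idx + 0..5 =
+    // argsort, view (ids), get_rows, reshape, soft_max, reshape (weights).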
+ if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
+ ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) {
+ ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
+ ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+
+ if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ return true;
+ }
+ }
+
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
ggml_tensor * weights = cgraph->nodes[i+8];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
- ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+ /*delayed softmax*/ false);
i += 8;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
ggml_tensor * weights = cgraph->nodes[i+4];
ggml_tensor * selected_experts = cgraph->nodes[i+3];
- ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
+ /*delayed softmax*/ false);
i += 4;
continue;
}
+ if (ggml_cuda_can_fuse(cgraph, i,
+ ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
+ ggml_tensor * weights = cgraph->nodes[i + 5];
+ ggml_tensor * ids = cgraph->nodes[i + 1];
+
+ ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
+                                  /*delayed softmax*/ true);
+ i += 5;
+ continue;
+ }
+
if (node->op == GGML_OP_ADD) {
int n_fuse = 0;
ggml_op ops[8];
#include <initializer_list>
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
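+// Each thread owns experts_per_thread values strided across the warp: vals[i] holds the entry
+// for column lane + i*WARP_SIZE. When use_limit is set, entries at or beyond `limit` are
+// excluded from the max/sum and zeroed.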
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+ float max_val = -INFINITY;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ max_val = max(max_val, vals[i]);
+ }
+ }
+
+ max_val = warp_reduce_max(max_val);
+
+ float sum = 0.f;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ const float val = expf(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
+ }
+
+ sum = warp_reduce_sum(sum);
+
+ const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ vals[i] *= inv_sum;
+ }
+ }
+}
+
/*
This kernel does the following:
- 1. softmax over the logits per token [n_experts, n_tokens]
+ 1. optionally softmax over the logits per token [n_experts, n_tokens]
  2. argmax reduce over the top-k (n_expert_used) logits
3. write weights + ids to global memory
- 4. optionally normalize the weights
+ 4. optionally normalize the weights or apply softmax over the selected logits
  It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
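+/*
+   For reference, a rough sketch of the unfused ggml graph this kernel stands in for,
+   mirroring the graph built in test-backend-ops.cpp (names illustrative, reshapes omitted):
+
+       probs    = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
+       selected = ggml_top_k(ctx, probs, n_expert_used);              // argsort + view
+       weights  = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected);
+       if (delayed_softmax) weights = ggml_soft_max(ctx, weights);    // softmax over the selected logits
+       if (with_norm)       weights = ggml_div(ctx, weights, ggml_sum_rows(ctx, weights));
+*/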
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
- float logits_r[experts_per_thread];
+ float wt[experts_per_thread];
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
- const int expert = i + threadIdx.x;
- logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+ const int expert = i + threadIdx.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
}
- float max_val = logits_r[0];
-
-#pragma unroll
- for (int i = 1; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- max_val = max(val, max_val);
+ if constexpr (!delayed_softmax) {
+ softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
}
- max_val = warp_reduce_max(max_val);
-
- float wt[experts_per_thread];
- float tmp = 0.f;
-
-#pragma unroll
- for (int i = 0; i < experts_per_thread; i++) {
- const float val = logits_r[i];
- wt[i] = expf(val - max_val);
- tmp += wt[i];
- }
+    //at this point, each thread holds either a portion of the softmax distribution
+    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    //the selected expert's weight as -inf so it is excluded from the next iteration
- tmp = warp_reduce_sum(tmp);
+ float wt_sum = 0.f;
- const float inv_sum = 1.0f / tmp;
+ float output_weights[experts_per_thread];
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
- wt[i] = wt[i] * inv_sum;
+ output_weights[i] = 0.f;
}
- //at this point, each thread holds a portion of softmax,
- //we do the argmax reduce over n_expert_used, each time marking
- //the expert weight as -inf to exclude from the next iteration
-
- float wt_sum = 0.f;
-
- float output_weights[experts_per_thread];
-
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
int max_expert = threadIdx.x;
}
}
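+    // output_weights follows the same lane + i*WARP_SIZE layout as wt, so the delayed softmax
+    // can reuse softmax_warp_inplace restricted to the first n_expert_used entries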
+ if constexpr (delayed_softmax) {
+ softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+ }
+
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x;
}
}
-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits,
float * weights,
const int n_rows,
const int n_expert,
const int n_expert_used) {
+ static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
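+    // one token row per warp; rows_per_block rows are processed per thread block,
+    // matching __launch_bounds__(4 * WARP_SIZE, 1) on the kernel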
const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
switch (n_expert) {
case 1:
- topk_moe_cuda<1, with_norm>
+ topk_moe_cuda<1, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 2:
- topk_moe_cuda<2, with_norm>
+ topk_moe_cuda<2, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 4:
- topk_moe_cuda<4, with_norm>
+ topk_moe_cuda<4, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 8:
- topk_moe_cuda<8, with_norm>
+ topk_moe_cuda<8, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 16:
- topk_moe_cuda<16, with_norm>
+ topk_moe_cuda<16, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 32:
- topk_moe_cuda<32, with_norm>
+ topk_moe_cuda<32, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 64:
- topk_moe_cuda<64, with_norm>
+ topk_moe_cuda<64, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 128:
- topk_moe_cuda<128, with_norm>
+ topk_moe_cuda<128, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 256:
- topk_moe_cuda<256, with_norm>
+ topk_moe_cuda<256, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
case 512:
- topk_moe_cuda<512, with_norm>
+ topk_moe_cuda<512, with_norm, delayed_softmax>
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
break;
default:
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
- const bool with_norm) {
+ const bool with_norm,
+ const bool delayed_softmax) {
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];
- const float * logits_d = (const float *) logits->src[0]->data;
+ const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data;
if (with_norm) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
} else {
- launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+ if (delayed_softmax) {
+ launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+ } else {
+ launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+ }
}
}
return true;
}
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
+ static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
+ GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
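+    // the delayed-softmax list matches ggml_top_k (argsort + view) on the raw logits,
+    // get_rows/reshape to gather the selected logits, then soft_max + reshape over them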
+
+ GGML_ASSERT(!norm || !delayed_softmax);
+
+ if (delayed_softmax) {
+ return delayed_softmax_ops;
+ }
+
if (norm) {
return norm_ops;
}
+
return no_norm_ops;
}
const std::array<int64_t, 4> ne;
const int n_expert_used;
const bool with_norm;
- test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
- : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
+ const bool delayed_softmax;
+
+ test_topk_moe(std::array<int64_t, 4> ne = { 10, 5, 1, 1 },
+ int n_expert_used = 1,
+ bool with_norm = false,
+ bool delayed_softmax = false) :
+ ne(ne),
+ n_expert_used(n_expert_used),
+ with_norm(with_norm),
+ delayed_softmax(delayed_softmax) {
GGML_ASSERT(n_expert_used <= ne[0]);
+ GGML_ASSERT(!(with_norm && delayed_softmax));
}
- std::string vars() override {
- return VARS_TO_STR3(ne, n_expert_used, with_norm);
- }
+ std::string vars() override { return VARS_TO_STR4(ne, n_expert_used, with_norm, delayed_softmax); }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
const int n_tokens = ne[1];
ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
- ggml_tensor * probs = ggml_soft_max(ctx, logits);
+ ggml_tensor * probs = delayed_softmax ? logits : ggml_soft_max(ctx, logits);
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+ if (delayed_softmax) {
+ out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
+ out = ggml_soft_max(ctx, out); // [n_expert_used, n_tokens]
+ out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+ }
+
if (with_norm) {
out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
}
+ test_cases.emplace_back(new test_topk_moe({ 8, 22, 1, 1 }, 4, /*with_norm*/ false, /*delayed_softmax*/ true));
+ test_cases.emplace_back(new test_topk_moe({ 32, 22, 1, 1 }, 8, /*with_norm*/ false, /*delayed_softmax*/ true));
+
#if 0
    // these tests are disabled to save execution time, but they can be handy for debugging
test_cases.emplace_back(new test_llama(2, true));