ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
}
}
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+ const ggml_tensor * weights,
+ const ggml_tensor * get_rows,
+ const ggml_tensor * argsort,
+ const ggml_tensor * clamp,
+ int n_expert) {
+ ggml_tensor * probs = get_rows->src[0];
+ if (probs->op != GGML_OP_RESHAPE) {
+ return false;
+ }
+ probs = probs->src[0];
+ ggml_tensor * selection_probs = argsort->src[0];
+
+ if (probs != selection_probs) {
+ return false;
+ }
+
float scale = 1.0f;
float max_bias = 0.0f;
return false;
}
- const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+ const ggml_tensor * weights,
+ const ggml_tensor * get_rows,
+ const ggml_tensor * argsort,
+ const ggml_tensor * clamp,
+ int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
const ggml_tensor * softmax;
const ggml_tensor * weights;
+ const ggml_tensor * get_rows;
+ const ggml_tensor * argsort;
switch (mode) {
case TOPK_MOE_EARLY_SOFTMAX_NORM:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 9];
+ get_rows = cgraph->nodes[node_idx + 4];
+ argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4];
+ get_rows = cgraph->nodes[node_idx + 4];
+ argsort = cgraph->nodes[node_idx + 2];
break;
case TOPK_MOE_LATE_SOFTMAX:
softmax = cgraph->nodes[node_idx + 4];
weights = cgraph->nodes[node_idx + 5];
+ get_rows = cgraph->nodes[node_idx + 2];
+ argsort = cgraph->nodes[node_idx + 0];
break;
default:
return false;
}
+ ggml_tensor * probs = get_rows->src[0];
+ if (probs->op != GGML_OP_RESHAPE) {
+ return false;
+ }
+ probs = probs->src[0];
+ ggml_tensor * selection_probs = argsort->src[0];
+
+ if (probs != selection_probs) {
+ return false;
+ }
+
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];