From: Georgi Gerganov Date: Wed, 1 Apr 2026 08:10:25 +0000 (+0300) Subject: ggml : fix RWKV ops thread assignment (llama/21226) X-Git-Tag: v0.9.10~4 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=9695adb06458e4feb283821818a287df35b38f31;p=pkg%2Fggml%2Fsources%2Fggml ggml : fix RWKV ops thread assignment (llama/21226) --- diff --git a/src/ggml-cpu/ggml-cpu.c b/src/ggml-cpu/ggml-cpu.c index df17cc55..7486acc2 100644 --- a/src/ggml-cpu/ggml-cpu.c +++ b/src/ggml-cpu/ggml-cpu.c @@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: { - n_tasks = n_threads; + const int64_t n_heads = node->src[1]->ne[1]; + n_tasks = MIN(n_threads, n_heads); } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: diff --git a/src/ggml-cpu/ops.cpp b/src/ggml-cpu/ops.cpp index d950972c..765ce07f 100644 --- a/src/ggml-cpu/ops.cpp +++ b/src/ggml-cpu/ops.cpp @@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * r = (float *) dst->src[0]->data; float * w = (float *) dst->src[1]->data;