]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix RWKV ops thread assignment (llama/21226)
authorGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 08:10:25 +0000 (11:10 +0300)
committerGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:00:26 +0000 (16:00 +0300)
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp

index df17cc553006472a3c0f9c4d728e2c60a9f2c2bf..7486acc2b5d6fde5abd54d09ac1a7165c1342850 100644 (file)
@@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_FLASH_ATTN_BACK:
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
             {
-                n_tasks = n_threads;
+                const int64_t n_heads = node->src[1]->ne[1];
+                n_tasks = MIN(n_threads, n_heads);
             } break;
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
index d950972c83e9e2df33a9a8f57b959b4813755c3a..765ce07f06c24bbd9d7c63647fddf215aeb8f0da 100644 (file)
@@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    if (ith >= HEADS) {
-        return;
-    }
-
-    const int h_start = (HEADS * ith) / nth;
-    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
-                (HEADS * (ith + 1)) / nth : HEADS;
+    const int h_start =  (HEADS * (ith    )) / nth;
+    const int h_end   = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                         (HEADS * (ith + 1)) / nth : HEADS;
 
     float * k =          (float *) dst->src[0]->data;
     float * v =          (float *) dst->src[1]->data;
@@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    if (ith >= HEADS) {
-        return;
-    }
-
-    const int h_start = (HEADS * ith) / nth;
-    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
-                (HEADS * (ith + 1)) / nth : HEADS;
+    const int h_start =  (HEADS * (ith    )) / nth;
+    const int h_end   = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                         (HEADS * (ith + 1)) / nth : HEADS;
 
     float * k = (float *) dst->src[0]->data;
     float * v = (float *) dst->src[1]->data;
@@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    if (ith >= HEADS) {
-        return;
-    }
-
-    const int h_start = (HEADS * ith) / nth;
-    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
-                (HEADS * (ith + 1)) / nth : HEADS;
+    const int h_start =  (HEADS * (ith    )) / nth;
+    const int h_end   = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                         (HEADS * (ith + 1)) / nth : HEADS;
 
     float * r = (float *) dst->src[0]->data;
     float * w = (float *) dst->src[1]->data;