]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : fix supports_op condition for get_rows when number of blocks is too large...
authorGeorgi Gerganov <redacted>
Mon, 8 Sep 2025 10:56:51 +0000 (13:56 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* cuda : fix supports_op condition for get_rows when src1->ne2 > 1

ggml-ci

* ggml : add comment about ggml_get_rows

ggml-ci

* cuda : add FIXME [no ci]

* cuda : update support condition

ggml-ci

include/ggml.h
src/ggml-cuda/ggml-cuda.cu
src/ggml.c

index 058f4267f754402d1028f0191fb717e1a8469b82..b7b472c56ec61d3488bc20bd7e0e5ce89c443b7b 100644 (file)
@@ -1529,7 +1529,11 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // supports 3D: a->ne[2] == b->ne[1]
+    // supports 4D a:
+    // a     [n_embd, ne1, ne2, ne3]
+    // b I32 [n_rows, ne2, ne3, 1]
+    //
+    // return [n_embd, n_rows, ne2, ne3]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,  // data
index a88b9f75ef2982b3835d8346f6d70957b9d1b947..0c6bd363961f1c49fd805115c578a8a284624164 100644 (file)
@@ -3392,6 +3392,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_GET_ROWS:
             {
+                // FIXME: https://github.com/ggml-org/llama.cpp/pull/15868
+                if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) {
+                    return false;
+                }
                 switch (op->src[0]->type) {
                     case GGML_TYPE_F16:
                     case GGML_TYPE_F32:
index f35c337952ec398e3a90a4cf14c27922c9403615..50dc1aa24fff58938dcda094fac398e1a2126ce2 100644 (file)
@@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b) {
     GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
     GGML_ASSERT(b->ne[3] == 1);
     GGML_ASSERT(b->type == GGML_TYPE_I32);