cuda : update supports_op for matrix multiplication (llama/8245)
author     slaren <redacted>
Tue, 2 Jul 2024 06:39:38 +0000 (08:39 +0200)
committer  Georgi Gerganov <redacted>
Mon, 8 Jul 2024 10:03:28 +0000 (13:03 +0300)
src/ggml-cuda.cu
tests/test-backend-ops.cpp
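
Note (annotation, not part of the commit): the diff below rewrites ggml_backend_cuda_supports_op so that MUL_MAT / MUL_MAT_ID support is decided by an explicit switch over the type of src[0] instead of the previous ad-hoc IQ-type check, and it adds GGML_TYPE_BF16 to the type list exercised by test-backend-ops. The following is a minimal, hypothetical sketch of how a caller might query this check through ggml's public backend API; the include paths, device index, and tensor shapes are illustrative assumptions, not code from this commit.

// Hedged sketch: ask the CUDA backend whether it supports a MUL_MAT op,
// which routes through the supports_op function changed in this commit.
#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"
#include <stdio.h>

int main(void) {
    // host-side context only; no_alloc means tensor data is never allocated,
    // which is enough because supports_op inspects types and shapes only
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a: quantized weight matrix (src[0]), b: f32 activations (src[1])
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_K, 4096, 4096);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 8);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);   // op = GGML_OP_MUL_MAT

    ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0 (assumed present)
    if (backend != NULL) {
        // after this commit, the answer is driven by the switch on a->type
        printf("MUL_MAT(Q4_K, F32) supported: %d\n", ggml_backend_supports_op(backend, c));
        ggml_backend_free(backend);
    }

    ggml_free(ctx);
    return 0;
}
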

diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index 16ca3c69577aab1a4c1eb2b7b60930ecf322ab16..ed784ea1c20246d7c6854092a9507dd53a69b84c 100644
--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -2715,27 +2715,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
+                struct ggml_tensor * a = op->src[0];
                 if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != b->ne[3]) {
-                    return false;
-                }
-                ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
-                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                    a_type == GGML_TYPE_IQ1_M   || a_type == GGML_TYPE_IQ2_S  || a_type == GGML_TYPE_IQ4_XS) {
-                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                    struct ggml_tensor * b = op->src[1];
+                    if (a->ne[3] != b->ne[3]) {
                         return false;
                     }
                 }
-                return true;
+                switch (a->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q2_K:
+                    case GGML_TYPE_Q3_K:
+                    case GGML_TYPE_Q4_K:
+                    case GGML_TYPE_Q5_K:
+                    case GGML_TYPE_Q6_K:
+                    case GGML_TYPE_Q8_K:
+                    case GGML_TYPE_IQ1_M:
+                    case GGML_TYPE_IQ1_S:
+                    case GGML_TYPE_IQ2_S:
+                    case GGML_TYPE_IQ2_XS:
+                    case GGML_TYPE_IQ2_XXS:
+                    case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ3_XXS:
+                    case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_IQ4_XS:
+                        return true;
+                    default:
+                        return false;
+                }
             } break;
         case GGML_OP_GET_ROWS:
             {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index c615f60d8ccfe3d43dda17d1039a3a166a3b8ce7..83d4a0eb153cf4c34860e8c10b2d7e6f946a7366 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2082,6 +2082,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
         GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+        GGML_TYPE_BF16,
     };
 
     // unary ops