]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : update supports_op for matrix multiplication (#8245)
authorslaren <redacted>
Tue, 2 Jul 2024 06:39:38 +0000 (08:39 +0200)
committerGitHub <redacted>
Tue, 2 Jul 2024 06:39:38 +0000 (09:39 +0300)
ggml/src/ggml-cuda.cu
tests/test-backend-ops.cpp

index 649ef5a081910734a0d7f46d7956edc241e4c528..1c9ccc8a15e54e61540b3ffc63b4af5fd609a8c8 100644 (file)
@@ -2711,27 +2711,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
+                struct ggml_tensor * a = op->src[0];
                 if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
-                if (a->ne[3] != b->ne[3]) {
-                    return false;
-                }
-                ggml_type a_type = a->type;
-                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
-                    a_type == GGML_TYPE_IQ1_S   || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S   ||
-                    a_type == GGML_TYPE_IQ1_M   || a_type == GGML_TYPE_IQ2_S  || a_type == GGML_TYPE_IQ4_XS) {
-                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                    struct ggml_tensor * b = op->src[1];
+                    if (a->ne[3] != b->ne[3]) {
                         return false;
                     }
                 }
-                return true;
+                switch (a->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q2_K:
+                    case GGML_TYPE_Q3_K:
+                    case GGML_TYPE_Q4_K:
+                    case GGML_TYPE_Q5_K:
+                    case GGML_TYPE_Q6_K:
+                    case GGML_TYPE_Q8_K:
+                    case GGML_TYPE_IQ1_M:
+                    case GGML_TYPE_IQ1_S:
+                    case GGML_TYPE_IQ2_S:
+                    case GGML_TYPE_IQ2_XS:
+                    case GGML_TYPE_IQ2_XXS:
+                    case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ3_XXS:
+                    case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_IQ4_XS:
+                        return true;
+                    default:
+                        return false;
+                }
             } break;
         case GGML_OP_GET_ROWS:
             {
index f74c0db475e2e2512fe1e41bab961fde4187a82c..2bb71ac03817fd1b18577227276ff6affb20a607 100644 (file)
@@ -2052,6 +2052,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
         GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
         GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+        GGML_TYPE_BF16,
     };
 
     // unary ops