]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda/vulkan: specify fp32-only support for some operations in supports_op (ggml/1129)
authorcmdr2 <redacted>
Fri, 28 Feb 2025 10:36:46 +0000 (12:36 +0200)
committerGeorgi Gerganov <redacted>
Mon, 3 Mar 2025 16:18:11 +0000 (18:18 +0200)
ggml-ci

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-vulkan/ggml-vulkan.cpp
tests/test-backend-ops.cpp

index a28b7037bd3a05348db9a5fca67475b254e13166..b5d2c84111e60aee1cb1265fa25f84f541d5d02b 100644 (file)
@@ -3155,7 +3155,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             } break;
         case GGML_OP_SILU_BACK:
-            return ggml_is_contiguous(op->src[0]);
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
index a413441ebda94bd7e044862d2403a8a1f3c1b077..ff53bdfbe171c42259036e1c2152ead4fab4fa87 100644 (file)
@@ -8452,7 +8452,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_SIGMOID:
-                    return ggml_is_contiguous(op->src[0]);
+                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
             }
@@ -8653,19 +8653,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RMS_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
-        case GGML_OP_ACC:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case GGML_OP_CONCAT:
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
index c393bf16b8a1f6e29a479b2fcf050d6a9c392ca8..b4e3631ed081a062d92136fb1fc39444f305cdf3 100644 (file)
@@ -3998,10 +3998,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
-
-    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
-        test_cases.emplace_back(new test_silu_back());
-    }
+    test_cases.emplace_back(new test_silu_back());
 
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
         for (bool v : {false, true}) {