]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda/vulkan: specify fp32-only support for some operations in supports_op (ggml/1129)
authorcmdr2 <redacted>
Fri, 28 Feb 2025 10:29:55 +0000 (15:59 +0530)
committerGeorgi Gerganov <redacted>
Sat, 8 Mar 2025 13:13:01 +0000 (15:13 +0200)
* cuda: restrict SILU_BACK to fp32, since fp16 exceeds the desired test threshold

* vulkan: specify fp32-only support for certain ops (that are now tested for fp16 as well)

* f32 sigmoid in vulkan supports op

* Revert "f32 sigmoid in vulkan supports op"

This reverts commit c6f04b3c19bf4504c2776149c6d8cd84e0b48acb.

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-vulkan/ggml-vulkan.cpp

index d49480573dd6af3fec239b8aa1f367cfa3d92b15..fe30259f78230e63bdccdf973745c8562be696ef 100644 (file)
@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             } break;
         case GGML_OP_SILU_BACK:
-            return ggml_is_contiguous(op->src[0]);
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
index 131ee1ea044dd9df6e93d7c3bae5ae5368d8877d..e0066c3642501be3bf01f7feb5c4768062978adf 100644 (file)
@@ -8371,7 +8371,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_TANH:
-                    return ggml_is_contiguous(op->src[0]);
+                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
             }
@@ -8571,17 +8571,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RMS_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
-        case GGML_OP_ACC:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
-        case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ACC:
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX: