From: cmdr2 Date: Fri, 28 Feb 2025 10:29:55 +0000 (+0530) Subject: cuda/vulkan: specify fp32-only support for some operations in supports_op (ggml/1129) X-Git-Tag: upstream/1.7.4+203~33 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=6ac8e6b2cee599fab8afa1714f44b83dabb50dca;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp cuda/vulkan: specify fp32-only support for some operations in supports_op (ggml/1129) * cuda: restrict SILU_BACK to fp32, since fp16 exceeds the desired test threshold * vulkan: specify fp32-only support for certain ops (that are now tested for fp16 as well) * f32 sigmoid in vulkan supports op * Revert "f32 sigmoid in vulkan supports op" This reverts commit c6f04b3c19bf4504c2776149c6d8cd84e0b48acb. --- diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d4948057..fe30259f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; } break; case GGML_OP_SILU_BACK: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 131ee1ea..e0066c36 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8371,7 +8371,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; } @@ -8571,17 +8571,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_RMS_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: - case GGML_OP_ACC: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: