CUDA: enable Gemma FA for HIP/Pascal (llama/9581)
author Johannes Gäßler <redacted>
Sun, 22 Sep 2024 07:34:52 +0000 (09:34 +0200)
committer Georgi Gerganov <redacted>
Tue, 24 Sep 2024 16:45:08 +0000 (19:45 +0300)
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/fattn.cu

index f940511980e59fe23ec05b5eb070fdb7133becb9..bf21c643d3a6b7720ff62eae2566ea72501c274c 100644 (file)
@@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
+        case GGML_OP_FLASH_ATTN_EXT: {
+            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
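
For context, the reworked GGML_OP_FLASH_ATTN_EXT predicate accepts head size 64 with an F16 K tensor, head size 128 unconditionally, and head size 256 with F16 K and V on any device (this is what enables Gemma-style heads on HIP and Pascal); everything else still requires an NVIDIA GPU at or above Volta with F16 K/V. A minimal standalone sketch of that decision logic follows; the function name, the DType enum, and the CC_VOLTA/CC_OFFSET_AMD values are stand-ins mirroring ggml-cuda conventions, not the library API itself.

// Sketch of the new flash-attention support check (assumed constants:
// CC_VOLTA = 700, CC_OFFSET_AMD = 1000000 as used to offset AMD devices).
#include <cstdint>
#include <cstdio>

enum class DType { F16, F32 };

constexpr int CC_VOLTA      = 700;
constexpr int CC_OFFSET_AMD = 1000000; // AMD compute capabilities sit past all NVIDIA ones

// head_size       -> op->src[0]->ne[0]
// k_type / v_type -> op->src[1]->type / op->src[2]->type
// cc              -> compute capability of the selected device
bool supports_flash_attn_ext(int64_t head_size, DType k_type, DType v_type, int cc) {
    if (head_size == 64 && k_type == DType::F16) {
        return true; // small heads with an F16 K tensor
    }
    if (head_size == 128) {
        return true; // head size 128 is supported on all CUDA/HIP devices
    }
    if (head_size == 256 && k_type == DType::F16 && v_type == DType::F16) {
        return true; // Gemma-style heads, now also allowed on HIP and pre-Volta
    }
    // everything else needs an NVIDIA GPU >= Volta and F16 K/V
    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && k_type == DType::F16 && v_type == DType::F16;
}

int main() {
    // e.g. head size 256 with an F16 KV cache on a Pascal GPU (cc 610):
    std::printf("%d\n", supports_flash_attn_ext(256, DType::F16, DType::F16, 610)); // prints 1
}
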
index f28a19d40b35617d34ff42c126c1d8bce2010433..83e5589a1cc244e3182a183f580b87fb13d0adb6 100644 (file)
@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8) {
+        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
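
The second hunk adjusts the FP32 fallback path: on devices without fast FP16 (e.g. Pascal), head size 256 is now routed to the FP32 vector kernel rather than the tile kernel, presumably because the tile kernel does not cover that head size. A minimal sketch of the dispatch decision is below; launch_vec_f32 and launch_tile_f32 are hypothetical stand-ins for the ggml_cuda_flash_attn_ext_vec_f32 / _tile_f32 launchers shown in the diff.

// Sketch of the FP32 fallback choice on GPUs without fast FP16.
#include <cstdint>
#include <cstdio>

void launch_vec_f32(int64_t head_size, int64_t n_q)  { std::printf("vec_f32  D=%lld n_q=%lld\n", (long long) head_size, (long long) n_q); }
void launch_tile_f32(int64_t head_size, int64_t n_q) { std::printf("tile_f32 D=%lld n_q=%lld\n", (long long) head_size, (long long) n_q); }

void flash_attn_f32_fallback(int64_t head_size /* Q->ne[0] */, int64_t n_q /* Q->ne[1] */) {
    // After this change, batched D=256 queries also take the vector kernel
    // instead of the tile kernel.
    if (n_q <= 8 || head_size == 256) {
        launch_vec_f32(head_size, n_q);
    } else {
        launch_tile_f32(head_size, n_q);
    }
}

int main() {
    flash_attn_f32_fallback(256, 32); // now: vec_f32
    flash_attn_f32_fallback(128, 32); //      tile_f32
}
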