]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: enable Gemma FA for HIP/Pascal (llama/9581)
authorJohannes Gäßler <redacted>
Sun, 22 Sep 2024 07:34:52 +0000 (09:34 +0200)
committerGeorgi Gerganov <redacted>
Tue, 24 Sep 2024 10:04:37 +0000 (13:04 +0300)
src/ggml-cuda.cu
src/ggml-cuda/fattn.cu
tests/test-backend-ops.cpp

index f940511980e59fe23ec05b5eb070fdb7133becb9..bf21c643d3a6b7720ff62eae2566ea72501c274c 100644 (file)
@@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT:
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-            return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
-#else
+        case GGML_OP_FLASH_ATTN_EXT: {
+            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
             if (op->src[0]->ne[0] == 128) {
                 return true;
             }
-            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
+            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
index f28a19d40b35617d34ff42c126c1d8bce2010433..83e5589a1cc244e3182a183f580b87fb13d0adb6 100644 (file)
@@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8) {
+        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
index efa88688cde75fab5c631d5eebec1e0beb4037a4..9a96cfc4c99de05e48de79dd929f4039f60ac3ea 100644 (file)
@@ -3599,7 +3599,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                     if (hs != 128 && logit_softcap != 0.0f) continue;
                     for (int nh : { 32, }) {
                         for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 2, 4, 8, }) {
+                            for (int nb : { 1, 3, 32, 35, }) {
                                 for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
                                 }