]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : fix ggml_flash_attn to use op_params (#2387)
authorslaren <redacted>
Tue, 25 Jul 2023 14:20:12 +0000 (16:20 +0200)
committerGitHub <redacted>
Tue, 25 Jul 2023 14:20:12 +0000 (16:20 +0200)
* ggml : fix ggml_flash_attn to use op_params

ggml.c

diff --git a/ggml.c b/ggml.c
index 33a6ffdc6e62cdce758fbb63459f3c4441ad11b3..35c56151b8f7c7bbbe9ca5f4b7e4372a78cc173d 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -7030,14 +7030,16 @@ struct ggml_tensor * ggml_flash_attn(
     }
 
     //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne);
+
+    int32_t t = masked ? 1 : 0;
+    ggml_set_op_params(result, &t, sizeof(t));
 
     result->op   = GGML_OP_FLASH_ATTN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
-    result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
@@ -7061,7 +7063,7 @@ struct ggml_tensor * ggml_flash_ff(
     }
 
     //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne);
 
     result->op   = GGML_OP_FLASH_FF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -7127,13 +7129,15 @@ struct ggml_tensor * ggml_flash_attn_back(
 
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
+    int32_t masked_i = masked ? 1 : 0;
+    ggml_set_op_params(result, &masked_i, sizeof(masked_i));
+
     result->op   = GGML_OP_FLASH_ATTN_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = q;
     result->src[1] = k;
     result->src[2] = v;
     result->src[3] = d;
-    result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
 
     return result;
 }
@@ -14773,7 +14777,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                const int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
@@ -14784,7 +14788,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
-                int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
+                int32_t t = ggml_get_op_params_i32(tensor, 0);
                 GGML_ASSERT(t == 0 || t == 1);
                 bool masked = t != 0;
                 ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
@@ -15402,7 +15406,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
-                    int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+                    int32_t t = ggml_get_op_params_i32(tensor, 0);
                     GGML_ASSERT(t == 0 || t == 1);
                     bool masked = t != 0;
                     flash_grad =