]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : prepare for 4D mask (#14515)
authorGeorgi Gerganov <redacted>
Fri, 4 Jul 2025 06:05:36 +0000 (09:05 +0300)
committerGitHub <redacted>
Fri, 4 Jul 2025 06:05:36 +0000 (09:05 +0300)
ggml-ci

src/llama-graph.cpp
src/llama-graph.h

index 4443420132f2b5860a02ce897268af5f7377a012..7f0e8c67f13253c27b9154551ece542df69ecf89 100644 (file)
@@ -1005,8 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1143,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp_kq_mask, "KQ_mask", -1);
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->kq_mask);
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1209,7 +1207,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1343,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
-    inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->cross_kq_mask);
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1457,7 +1455,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1471,7 +1469,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
index c8b74a14741e2677aeeadddecd2546f8ecd09e51..7bdf656768a0c1d5fda6e2cd4645b3ca2a5eb492 100644 (file)
@@ -228,8 +228,8 @@ public:
 
     ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
 
-    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch]
-    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch]
+    ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+    ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -257,8 +257,8 @@ public:
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -293,10 +293,10 @@ public:
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
 
-    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -313,8 +313,8 @@ public:
 
     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
 
-    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch]
-    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
 
     const llama_cross * cross = nullptr;
 };
@@ -343,8 +343,8 @@ public:
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
 
-    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;