]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : clean up t5 input builders (#18795)
authorGabe Goodhart <redacted>
Tue, 13 Jan 2026 08:43:51 +0000 (01:43 -0700)
committerGitHub <redacted>
Tue, 13 Jan 2026 08:43:51 +0000 (09:43 +0100)
* fix: Remove unnecessary `h` loops where `h` was only ever 0

Branch: CleanUpT5InputBuilders

Signed-off-by: Gabe Goodhart <redacted>
* fix: Remove unnecessary padding loop that is never hit anymore

The upper bound used to use GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), but was
removed in https://github.com/ggml-org/llama.cpp/pull/17910 leaving the
loop dead.

Branch: CleanUpT5InputBuilders

Signed-off-by: Gabe Goodhart <redacted>
---------

Signed-off-by: Gabe Goodhart <redacted>
src/llama-graph.cpp

index 374ff1ebf3a2a41d498b6a0476593272b0707518..944c7e53bd2b711016d63f0fa7e9d37e0b813487 100644 (file)
@@ -96,11 +96,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
         int32_t * data = (int32_t *) pos_bucket->data;
 
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_tokens; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
-                }
+        for (int j = 0; j < n_tokens; ++j) {
+            for (int i = 0; i < n_tokens; ++i) {
+                data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true);
             }
         }
     }
@@ -323,34 +321,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     const int64_t n_tokens = ubatch->n_tokens;
 
     const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) {
-        for (int h = 0; h < 1; ++h) {
-            for (int i1 = 0; i1 < n_tokens; ++i1) {
-                const llama_seq_id s1 = ubatch->seq_id[i1][0];
-                const llama_pos    p1 = ubatch->pos[i1];
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+            const llama_pos    p1 = ubatch->pos[i1];
 
-                const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv;
+            const uint64_t idst = i1*n_kv;
 
-                for (int i0 = 0; i0 < n_tokens; ++i0) {
-                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
-                    const llama_pos p0    = ubatch->pos[i0];
-
-                    // mask different sequences
-                    if (s0 != s1) {
-                        continue;
-                    }
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                const llama_seq_id s0 = ubatch->seq_id[i0][0];
+                const llama_pos p0    = ubatch->pos[i0];
 
-                    // mask future tokens
-                    if (cparams.causal_attn && p0 > p1) {
-                        continue;
-                    }
+                // mask different sequences
+                if (s0 != s1) {
+                    continue;
+                }
 
-                    // apply SWA if any
-                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
-                        continue;
-                    }
+                // mask future tokens
+                if (cparams.causal_attn && p0 > p1) {
+                    continue;
+                }
 
-                    data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
+                // apply SWA if any
+                if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                    continue;
                 }
+
+                data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
             }
         }
     };
@@ -454,27 +450,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
     float * data = (float *) cross_kq_mask->data;
 
-    for (int h = 0; h < 1; ++h) {
-        for (int i = 0; i < n_tokens; ++i) {
-            for (int j = 0; j < n_enc; ++j) {
-                float f = -INFINITY;
+    for (int i = 0; i < n_tokens; ++i) {
+        for (int j = 0; j < n_enc; ++j) {
+            float f = -INFINITY;
 
-                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
-                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[i][s];
 
-                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
-                        f = 0.0f;
-                    }
+                if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                    f = 0.0f;
                 }
-
-                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
-        }
 
-        for (int i = n_tokens; i < n_tokens; ++i) {
-            for (int j = 0; j < n_enc; ++j) {
-                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
-            }
+            data[i*n_enc + j] = f;
         }
     }
 }