]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : rename missed batch params/vars to ubatch (#10059)
authorDaniel Bevenius <redacted>
Mon, 6 Jan 2025 09:28:17 +0000 (10:28 +0100)
committerGitHub <redacted>
Mon, 6 Jan 2025 09:28:17 +0000 (11:28 +0200)
This commit renames the `batch` parameter to `ubatch` in the
`llama_kv_cache_find_slot`, `llm_build_inp_embd`, and
`llm_build_mamba` functions.

The motivation for this is that this should have been done as part of
Commit 19d900a7565b8f6b0a708836a57d26966cb9efe2 ("llama : rename batch
to ubatch (#9950)") but for some reason I missed these functions in
that commit and only noticed them now (sorry).

src/llama-kv-cache.cpp
src/llama.cpp

index 53379253a3cace999935b09b94bac715d9a2ce5b..90b6c56ed068c84564e368b2ac6ccdffd36a63f2 100644 (file)
@@ -119,10 +119,10 @@ bool llama_kv_cache_init(
 
 struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs   = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
+       const struct llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
@@ -130,16 +130,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // A slot should be always be contiguous.
 
         // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
 
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
         // everything should fit if all seq_ids are smaller than the max
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
+            const uint32_t n_seq_id = ubatch.n_seq_id[s];
             for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
 
                 if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
                     // too big seq_id
@@ -198,7 +198,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // find usable cell range
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
             llama_kv_cell & seq_meta = cache.cells[seq_id];
             bool has_cell = false;
             if (seq_meta.tail >= 0) {
@@ -237,7 +237,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
         // gather and re-order
         for (uint32_t s = 0; s < n_seqs; ++s) {
             int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
+            int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail;
             if (dst_id != src_id) {
                 llama_kv_cell & dst_cell = cache.cells[dst_id];
                 llama_kv_cell & src_cell = cache.cells[src_id];
@@ -258,7 +258,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
 
         // update the pos of the used seqs
         for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
             int32_t cell_id = s + min;
             llama_kv_cell & cell = cache.cells[cell_id];
 
@@ -266,12 +266,12 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
                 // What should happen when the pos backtracks or skips a value?
                 // Clearing the state mid-batch would require special-casing which isn't done.
                 LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
+                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
             }
             cell.pos = last_pos;
             cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+                const llama_seq_id seq_id = ubatch.seq_id[s][j];
                 cell.seq_id.insert(seq_id);
                 cache.cells[seq_id].tail = cell_id;
             }
@@ -325,10 +325,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
+            cache.cells[cache.head + k].pos = ubatch.pos[k];
 
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
+            for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) {
+                cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]);
             }
         }
     }
index 7337c34ce573ebf1e3b4c3fad47764d33447f985..60728e5bb91cabcaa504362d5b37c95aa23ecf98 100644 (file)
@@ -2540,21 +2540,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -3149,7 +3149,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -3165,17 +3165,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = batch.n_seqs;
+    const int64_t n_seqs  = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];