]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama: refactor llama_decode_impl (#11381)
authorJohannes Gäßler <redacted>
Mon, 27 Jan 2025 11:07:12 +0000 (12:07 +0100)
committerGitHub <redacted>
Mon, 27 Jan 2025 11:07:12 +0000 (12:07 +0100)
src/llama.cpp

index 094157ccf2aa231e4b41e1b9d9e9e4f7f0a13c52..12e8f41fc8614951a392bd41690a0c0f13e39a25 100644 (file)
@@ -8432,74 +8432,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
+static int llama_prepare_sbatch(
+        llama_context     & lctx,
+        const llama_batch & batch,
+        uint32_t          & n_outputs) {
     const auto & model   = lctx.model;
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const  int64_t n_embd       = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();
 
     // count outputs
@@ -8515,7 +8474,7 @@ static int llama_decode_impl(
     }
 
     lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -8524,70 +8483,148 @@ static int llama_decode_impl(
         return -2;
     };
 
-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context          & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch           & ubatch,
+        const uint32_t           n_outputs,
+        const uint32_t           n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto       & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
         } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
         }
-        const uint32_t n_tokens = ubatch.n_tokens;
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
 
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
 
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
             }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
 
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
 
-        GGML_ASSERT(n_threads > 0);
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
+    return 0;
+}
 
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+         llama_context & lctx,
+           llama_batch   inp_batch) {
 
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
+    lctx.is_encoding = false;
 
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
+
+    uint32_t n_outputs = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
+    while (lctx.sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+        {
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
         }
 
+        const int         n_threads  = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool   : lctx.threadpool_batch;
+
+        GGML_ASSERT(n_threads > 0);
+
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -8640,7 +8677,7 @@ static int llama_decode_impl(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {