]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor sampling_info to use buffer_view template (#19368)
authorDaniel Bevenius <redacted>
Wed, 11 Feb 2026 04:38:13 +0000 (05:38 +0100)
committerGitHub <redacted>
Wed, 11 Feb 2026 04:38:13 +0000 (05:38 +0100)
* llama : refactor sampling_info to use buffer_view template

This commit updates the sampling_info struct in llama-context to use a
buffer_view template for the logits, probs, sampled tokens, and
candidates buffers.

The motivation for this is to simplify the code, improve type safety
and readability.

src/llama-context.cpp
src/llama-context.h
src/llama-impl.h

index 1a35ce369cbd036bd8c55368277aa74bc66337ad..6b43ca19267fead06af5525ad9bd4f7e8001c4ae 100644 (file)
@@ -677,7 +677,7 @@ enum llama_pooling_type llama_context::pooling_type() const {
 float * llama_context::get_logits() {
     output_reorder();
 
-    return logits;
+    return logits.data;
 }
 
 int64_t llama_context::output_resolve_row(int32_t i) const {
@@ -715,7 +715,7 @@ float * llama_context::get_logits_ith(int32_t i) {
     output_reorder();
 
     try {
-        if (logits == nullptr) {
+        if (logits.data == nullptr) {
             throw std::runtime_error("no logits");
         }
 
@@ -739,7 +739,7 @@ float * llama_context::get_logits_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
-        return logits + j*model.vocab.n_tokens();
+        return logits.data + j*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -753,11 +753,11 @@ float * llama_context::get_logits_ith(int32_t i) {
 float * llama_context::get_embeddings() {
     output_reorder();
 
-    return embd;
+    return embd.data;
 }
 
 llama_token * llama_context::get_sampled_tokens()  const{
-    return sampling.sampled;
+    return sampling.sampled.data;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
@@ -766,7 +766,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
     output_reorder();
 
     try {
-        if (embd == nullptr) {
+        if (embd.data == nullptr) {
             throw std::runtime_error("no embeddings");
         }
 
@@ -791,7 +791,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
         }
 
         const uint32_t n_embd_out = model.hparams.n_embd_out();
-        return embd + j*n_embd_out;
+        return embd.data + j*n_embd_out;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
@@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.sampled == nullptr) {
+    if (!sampling.sampled.has_data()) {
         return LLAMA_TOKEN_NULL;
     }
 
     try {
         const int64_t row = output_resolve_row(idx);
-        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
-        return sampling.sampled[row];
+        GGML_ASSERT(row < (int64_t) sampling.sampled.size);
+        return sampling.sampled.data[row];
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
         return LLAMA_TOKEN_NULL;
@@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
 float * llama_context::get_sampled_probs_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return nullptr;
     }
 
@@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
         if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
             return nullptr;
         }
-        return sampling.probs + row*model.vocab.n_tokens();
+        return sampling.probs.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
 float * llama_context::get_sampled_logits_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return nullptr;
     }
 
@@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
         if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
             return nullptr;
         }
-        return sampling.logits + row*model.vocab.n_tokens();
+        return sampling.logits.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -871,10 +871,10 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
 
     try {
         const int64_t row = output_resolve_row(idx);
-        if (sampling.candidates != nullptr &&
+        if (sampling.candidates.has_data() &&
             (size_t) row < sampling.candidates_count.size() &&
             sampling.candidates_count[row] > 0) {
-            return sampling.candidates + row*model.vocab.n_tokens();
+            return sampling.candidates.data + row*model.vocab.n_tokens();
         }
     } catch (const std::exception & err) {
         // fallback to full vocab list
@@ -886,7 +886,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
 size_t llama_context::get_sampled_candidates_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.candidates == nullptr) {
+    if (!sampling.candidates.has_data()) {
         return 0;
     }
 
@@ -905,7 +905,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
 size_t llama_context::get_sampled_logits_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return model.vocab.n_tokens();
     }
 
@@ -924,7 +924,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
 size_t llama_context::get_sampled_probs_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return 0;
     }
 
@@ -1254,16 +1254,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
     // extract logits
-    if (logits && t_logits) {
+    if (logits.data && t_logits) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
-        GGML_ASSERT(logits != nullptr);
+        GGML_ASSERT(logits.data != nullptr);
 
-        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+        ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float));
     }
 
     // extract embeddings
-    if (embd && t_embd) {
+    if (embd.data && t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
@@ -1271,11 +1271,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
-                    GGML_ASSERT(embd != nullptr);
+                    GGML_ASSERT(embd.data != nullptr);
                     const uint32_t n_embd_out = hparams.n_embd_out();
 
-                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
+                    GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size);
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float));
                 } break;
             case LLAMA_POOLING_TYPE_MEAN:
             case LLAMA_POOLING_TYPE_CLS:
@@ -1323,7 +1323,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
         cross.n_embd = t_embd->ne[0];
         cross.n_enc  = t_embd->ne[1];
         cross.v_embd.resize(cross.n_embd*cross.n_enc);
-        memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
+        memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
 
         const auto & batch = balloc->get_batch();
 
@@ -1363,11 +1363,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
 
 static void copy_tensor_async_ints(
     const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-    llama_token * sampled,
-    size_t sampled_size,
+    const buffer_view<llama_token> & sampled,
     const std::map<llama_seq_id, uint32_t> & seq_to_row,
     ggml_backend_sched_t sched) {
-    if (sampled == nullptr) {
+    if (!sampled.has_data()) {
         return;
     }
 
@@ -1378,23 +1377,23 @@ static void copy_tensor_async_ints(
         }
 
         const uint32_t row = it->second;
-        GGML_ASSERT(row < sampled_size);
+        GGML_ASSERT(row < sampled.size);
 
         GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+        ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
     }
 }
 
 static void copy_tensor_async_floats(
     const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-    float * dst,
+    const buffer_view<float> & dst,
     size_t stride,
     std::vector<uint32_t> & counts,
     const std::map<llama_seq_id, uint32_t> & seq_to_row,
     ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
 
@@ -1410,7 +1409,7 @@ static void copy_tensor_async_floats(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        float * row_ptr = dst + (size_t) row * stride;
+        float * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of logits/probabilities that were written for this row.
@@ -1420,12 +1419,12 @@ static void copy_tensor_async_floats(
 
 static void copy_tensor_async_candidates(
     const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-    llama_token * dst,
+    const buffer_view<llama_token> & dst,
     size_t stride,
     std::vector<uint32_t> & counts,
     const std::map<llama_seq_id, uint32_t> & seq_to_row,
     ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
 
@@ -1441,7 +1440,7 @@ static void copy_tensor_async_candidates(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        llama_token * row_ptr = dst + (size_t) row * stride;
+        llama_token * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of candidates that were written.
@@ -1671,22 +1670,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         // extract logits
-        if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
+        if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
             GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(logits != nullptr);
+            GGML_ASSERT(logits.data != nullptr);
 
-            float * logits_out = logits + n_outputs_prev*n_vocab;
+            float * logits_out = logits.data + n_outputs_prev*n_vocab;
 
             if (n_outputs) {
                 GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size);
                 ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
             }
         }
 
         // extract embeddings
-        if (embd && t_embd && n_outputs > 0) {
+        if (embd.data && t_embd && n_outputs > 0) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
 
@@ -1694,13 +1693,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 case LLAMA_POOLING_TYPE_NONE:
                     {
                         // extract token embeddings
-                        GGML_ASSERT(embd != nullptr);
+                        GGML_ASSERT(embd.data != nullptr);
                         const uint32_t n_embd_out = hparams.n_embd_out();
-                        float * embd_out = embd + n_outputs_prev*n_embd_out;
+                        float * embd_out = embd.data + n_outputs_prev*n_embd_out;
 
                         if (n_outputs) {
                             GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
+                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size);
                             ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
                         }
                     } break;
@@ -1747,7 +1746,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
             const auto stride = n_vocab;
 
             // async copy the sampling data from the backend to the host
-            copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+            copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
 
             copy_tensor_async_floats    (res->t_sampled_logits, sampling.logits,     stride, sampling.logits_count,     seq_to_output_row, sched.get());
             copy_tensor_async_floats    (res->t_sampled_probs,  sampling.probs,      stride, sampling.probs_count,      seq_to_output_row, sched.get());
@@ -1841,19 +1840,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     size_t backend_float_count = 0;
     size_t backend_token_count = 0;
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size   = has_embd ? n_embd_out*n_outputs_max : 0;
+    logits.size = has_logits ? n_vocab*n_outputs_max : 0;
+    embd.size   = has_embd ? n_embd_out*n_outputs_max : 0;
 
     // Allocate backend sampling output buffers if there are backend samplers configured.
     const bool has_sampling = !sampling.samplers.empty();
     if (has_sampling) {
-        sampling.logits_size     = n_vocab*n_outputs_max;
-        sampling.probs_size      = n_vocab*n_outputs_max;
-        sampling.sampled_size    =         n_outputs_max;
-        sampling.candidates_size = n_vocab*n_outputs_max;
-
-        backend_float_count = sampling.logits_size  + sampling.probs_size;
-        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+        backend_float_count = 2 * n_vocab * n_outputs_max;      // logits + probs
+        backend_token_count = (1 + n_vocab) * n_outputs_max;    // sampled + candidates
     }
 
     if (output_ids.empty()) {
@@ -1863,7 +1857,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 
     const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
     const size_t new_size  =
-        (logits_size + embd_size + backend_float_count) * sizeof(float) +
+        (logits.size + embd.size + backend_float_count) * sizeof(float) +
         (                          backend_token_count) * sizeof(llama_token);
 
     // alloc only when more than the current capacity is required
@@ -1878,8 +1872,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 
             // TODO: not needed?
             buf_output = nullptr;
-            logits = nullptr;
-            embd = nullptr;
+            logits.data = nullptr;
+            embd.data = nullptr;
         }
 
         auto * buft = ggml_backend_cpu_buffer_type();
@@ -1898,35 +1892,32 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 
     float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
 
-    logits = nullptr;
-    embd   = nullptr;
-
     size_t offset = 0;
     uint8_t * base = (uint8_t *) output_base;
 
-    logits = has_logits ? output_base : nullptr;
-    offset += logits_size * sizeof(float);
+    logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0};
+    offset += logits.size * sizeof(float);
 
-    embd = has_embd ? (float *) (base + offset) : nullptr;
-    offset += embd_size * sizeof(float);
+    embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
+    offset += embd.size * sizeof(float);
 
-    sampling.logits     = nullptr;
-    sampling.probs      = nullptr;
-    sampling.sampled    = nullptr;
-    sampling.candidates = nullptr;
+    sampling.logits     = {nullptr, 0};
+    sampling.probs      = {nullptr, 0};
+    sampling.sampled    = {nullptr, 0};
+    sampling.candidates = {nullptr, 0};
 
     if (has_sampling) {
-        sampling.logits = (float *) (base + offset);
-        offset += sampling.logits_size * sizeof(float);
+        sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.logits.size * sizeof(float);
 
-        sampling.probs = (float *) (base + offset);
-        offset += sampling.probs_size * sizeof(float);
+        sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.probs.size * sizeof(float);
 
-        sampling.sampled = (llama_token *) (base + offset);
-        offset += sampling.sampled_size * sizeof(llama_token);
+        sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
+        offset += sampling.sampled.size * sizeof(llama_token);
 
-        sampling.candidates = (llama_token *) (base + offset);
-        offset += sampling.candidates_size * sizeof(llama_token);
+        sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.candidates.size * sizeof(llama_token);
 
         // The count vectors keep track of the actual number of logits/probs/candidates
         // copied from the backend for each output row.
@@ -1939,7 +1930,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         std::fill(sampling.probs_count.begin(),      sampling.probs_count.end(),      0);
         std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
 
-        std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+        std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
     }
 
     // set all ids as invalid (negative)
@@ -1958,38 +1949,38 @@ void llama_context::output_reorder() {
         const uint64_t i0 = output_swaps[s].i0;
         const uint64_t i1 = output_swaps[s].i1;
 
-        if (logits_size > 0) {
+        if (logits.size > 0) {
             for (uint64_t k = 0; k < n_vocab; k++) {
-                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+                std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]);
             }
         }
 
-        if (embd_size > 0) {
+        if (embd.size > 0) {
             for (uint64_t k = 0; k < n_embd; k++) {
-                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+                std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]);
             }
         }
 
-        if (sampling.logits && sampling.logits_size > 0) {
+        if (sampling.logits.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+                std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.probs && sampling.probs_size > 0) {
+        if (sampling.probs.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+                std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.candidates && sampling.candidates_size > 0) {
+        if (sampling.candidates.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+                std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.sampled && sampling.sampled_size > 0) {
-            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        if (sampling.sampled.has_data()) {
+            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
         }
 
         if (!sampling.logits_count.empty()) {
@@ -2533,12 +2524,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
 
-        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+        const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
 
         io.write(&logits_size, sizeof(logits_size));
 
         if (logits_size) {
-            io.write(logits, logits_size * sizeof(float));
+            io.write(logits.data, logits_size * sizeof(float));
         }
     }
 
@@ -2546,12 +2537,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
 
-        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
+        const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
 
         io.write(&embd_size, sizeof(embd_size));
 
         if (embd_size) {
-            io.write(embd, embd_size * sizeof(float));
+            io.write(embd.data, embd_size * sizeof(float));
         }
     }
 
@@ -2619,12 +2610,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         uint64_t logits_size;
         io.read_to(&logits_size, sizeof(logits_size));
 
-        if (this->logits_size < logits_size) {
+        if (this->logits.size < logits_size) {
             throw std::runtime_error("logits buffer too small");
         }
 
         if (logits_size) {
-            io.read_to(this->logits, logits_size * sizeof(float));
+            io.read_to(this->logits.data, logits_size * sizeof(float));
         }
     }
 
@@ -2635,12 +2626,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         uint64_t embd_size;
         io.read_to(&embd_size, sizeof(embd_size));
 
-        if (this->embd_size < embd_size) {
+        if (this->embd.size < embd_size) {
             throw std::runtime_error("embeddings buffer too small");
         }
 
         if (embd_size) {
-            io.read_to(this->embd, embd_size * sizeof(float));
+            io.read_to(this->embd.data, embd_size * sizeof(float));
         }
     }
 
index 8e71cdd1dc5346797126322185babe5df883c394..d9951175744809de45d8f7dc756caad5148b6211 100644 (file)
@@ -4,6 +4,7 @@
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
+#include "llama-impl.h"
 
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
@@ -269,29 +270,19 @@ private:
     std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
-    size_t  logits_size = 0; // capacity (of floats) for logits
-    float * logits      = nullptr;
+    struct buffer_view<float>  logits = {nullptr, 0};
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    size_t  embd_size = 0; // capacity (of floats) for embeddings
-    float * embd      = nullptr;
+    struct buffer_view<float>  embd = {nullptr, 0};
 
-    // TODO: simplify
     struct sampling_info {
         std::map<llama_seq_id, llama_sampler *> samplers;
 
-        float       * logits      = nullptr;
-        size_t        logits_size = 0;
-
-        llama_token * sampled      = nullptr;
-        size_t        sampled_size = 0;
-
-        float       * probs        = nullptr;
-        size_t        probs_size   = 0;
-
-        llama_token * candidates   = nullptr;
-        size_t        candidates_size = 0;
+        struct buffer_view<float>       logits     = {nullptr, 0};
+        struct buffer_view<llama_token> sampled    = {nullptr, 0};
+        struct buffer_view<float>       probs      = {nullptr, 0};
+        struct buffer_view<llama_token> candidates = {nullptr, 0};
 
         std::vector<uint32_t> logits_count;
         std::vector<uint32_t> probs_count;
index c3391e79f51662cd9efd6805e70c2c7b0ca2ccd8..dfd9fee9f445e1e56ce7f0e3aecead5bf078c8aa 100644 (file)
@@ -49,6 +49,16 @@ struct time_meas {
     int64_t & t_acc;
 };
 
+template <typename T>
+struct buffer_view {
+    T * data;
+    size_t size = 0;
+
+    bool has_data() const {
+        return data && size > 0;
+    }
+};
+
 void replace_all(std::string & s, const std::string & search, const std::string & replace);
 
 // TODO: rename to llama_format ?