]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
perplexity : fix MSVC build after #5020 (#5043)
authorJared Van Bortel <redacted>
Sat, 20 Jan 2024 15:08:08 +0000 (10:08 -0500)
committerGitHub <redacted>
Sat, 20 Jan 2024 15:08:08 +0000 (17:08 +0200)
* perplexity : fix MSVC build after #5020

* try a differerent fix

examples/perplexity/perplexity.cpp

index b0732019082497631331385906f21dfae70b7aba..f91f5795a9851434723c9ea53886353441be7a33 100644 (file)
@@ -458,23 +458,24 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
     return true;
 }
 
+#define K_TOKEN_CHUNK 4
+
 static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
         const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
-    constexpr int k_token_chunk = 4;
     if (eval_results.size() != eval_pairs.size()) {
         eval_results.resize(eval_pairs.size());
     }
     if (eval_pairs.empty()) return;
 
-    size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
+    size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
 
     std::atomic<int> counter(0);
     auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
-        float local_logprobs[k_token_chunk];
+        float local_logprobs[K_TOKEN_CHUNK];
         while (true) {
-            size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
+            size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
             if (first >= eval_results.size()) break;
-            size_t last = std::min(first + k_token_chunk, eval_results.size());
+            size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
             for (size_t i = first; i < last; ++i) {
                 auto logits = batch_logits + eval_pairs[i].first * n_vocab;
                 float max_logit = logits[0];
@@ -497,7 +498,6 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     for (size_t it = 0; it < max_threads; ++it) {
         workers[it].join();
     }
-
 }
 
 static void hellaswag_score(llama_context * ctx, const gpt_params & params) {