]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Another bucket sort (#5109)
authorKawrakow <redacted>
Fri, 26 Jan 2024 07:14:39 +0000 (09:14 +0200)
committerGitHub <redacted>
Fri, 26 Jan 2024 07:14:39 +0000 (09:14 +0200)
* Initial bucket sort

* Bucket sort: slightly better version

* Bucket sort: another minor improvement

---------

Co-authored-by: Iwan Kawrakow <redacted>
llama.cpp

index 823d42d7fb50ab087192057b4ffb93a077abc474..b03b67e169955cc684148155c566d6db3f069ba2 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -7956,10 +7956,57 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
         auto comp = [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         };
-        if (k == (int) candidates->size) {
-            std::sort(candidates->data, candidates->data + candidates->size, comp);
-        } else {
+        if (k <= 128) {
             std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+        } else {
+            constexpr int   nbuckets     = 128;
+            constexpr float bucket_low   = -10.0f;
+            constexpr float bucket_high  =  10.0f;
+            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
+            constexpr float bucker_inter = -bucket_low * bucket_scale;
+
+            std::vector<int> bucket_idx(candidates->size);
+            std::vector<int> histo(nbuckets, 0);
+
+            for (int i = 0; i < (int)candidates->size; ++i) {
+                const float val = candidates->data[i].logit;
+                int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
+                ib = std::max(0, std::min(nbuckets-1, ib));
+                bucket_idx[i] = ib;
+                ++histo[ib];
+            }
+            int nhave = 0;
+            int ib = nbuckets - 1;
+            for ( ; ib >= 0; --ib) {
+                nhave += histo[ib];
+                if (nhave >= k) break;
+            }
+            std::vector<llama_token_data> tmp_tokens(nhave);
+            auto ptr = tmp_tokens.data();
+            std::vector<llama_token_data*> bucket_ptrs;
+            bucket_ptrs.reserve(nbuckets - ib);
+            for (int j = nbuckets - 1; j >= ib; --j) {
+                bucket_ptrs.push_back(ptr);
+                ptr += histo[j];
+            }
+            for (int i = 0; i < (int)candidates->size; ++i) {
+                int j = bucket_idx[i];
+                if (j >= ib) {
+                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+                }
+            }
+
+            ptr = tmp_tokens.data();
+            int ndone = 0;
+            for (int j = nbuckets-1; j > ib; --j) {
+                std::sort(ptr, ptr + histo[j], comp);
+                ptr += histo[j];
+                ndone += histo[j];
+            }
+            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
+
+            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+
         }
         candidates->sorted = true;
     }