]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
imatrix : handle partial entries (#7833)
authorGeorgi Gerganov <redacted>
Sun, 9 Jun 2024 17:19:35 +0000 (20:19 +0300)
committerGitHub <redacted>
Sun, 9 Jun 2024 17:19:35 +0000 (20:19 +0300)
examples/imatrix/imatrix.cpp

index e18f495630616470947203082e9c476ddb8aaa52..574f5ed9c2e657c3d7884bd4a61d73966de6e271 100644 (file)
@@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const {
         fname += std::to_string(ncall);
     }
 
+    // avoid writing imatrix entries that do not have full data
+    // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
+
+    int n_entries = 0;
+    std::vector<std::string> to_store;
+
+    bool is_first = true; // for printing
+    for (const auto & kv : m_stats) {
+        const int n_all = kv.second.counts.size();
+
+        if (n_all == 0) {
+            continue;
+        }
+
+        int n_zeros = 0;
+        for (const int c : kv.second.counts) {
+            if (c == 0) {
+                n_zeros++;
+            }
+        }
+
+        if (n_zeros != 0 && is_first) {
+            fprintf(stderr, "\n");
+            is_first = false;
+        }
+
+        if (n_zeros == n_all) {
+            fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
+            continue;
+        }
+
+        if (n_zeros > 0) {
+            fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
+            continue;
+        }
+
+        n_entries++;
+        to_store.push_back(kv.first);
+    }
+
+    if (to_store.size() < m_stats.size()) {
+        fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
+    }
+
     std::ofstream out(fname, std::ios::binary);
-    int n_entries = m_stats.size();
     out.write((const char *) &n_entries, sizeof(n_entries));
-    for (const auto & p : m_stats) {
-        int len = p.first.size();
+    for (const auto & name : to_store) {
+        const auto & stat = m_stats.at(name);
+        int len = name.size();
         out.write((const char *) &len, sizeof(len));
-        out.write(p.first.c_str(), len);
-        out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
-        int nval = p.second.values.size();
+        out.write(name.c_str(), len);
+        out.write((const char *) &stat.ncall, sizeof(stat.ncall));
+        int nval = stat.values.size();
         out.write((const char *) &nval, sizeof(nval));
         if (nval > 0) {
             std::vector<float> tmp(nval);
             for (int i = 0; i < nval; i++) {
-                tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
+                tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
             }
             out.write((const char*)tmp.data(), nval*sizeof(float));
         }