]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fixed save_imatrix to match old behaviour for MoE (#7099)
authorjukofyork <redacted>
Wed, 8 May 2024 00:24:16 +0000 (01:24 +0100)
committerGitHub <redacted>
Wed, 8 May 2024 00:24:16 +0000 (02:24 +0200)
* Fixed save_imatrix to match old behaviour for MoE

This fix is simple and clear, but unnecessarily doubles the memory overhead..

* Fixed missing idx variable

* Unconditionally increment ncall

Co-authored-by: slaren <redacted>
* Fixed 2 bugs in save_imatrix()

- Fixed segfault bug because the counts vector needed to be created.
- Fixed pre-existing bug didn't actually add to the counts for "--combine" option.

* ncall needs summing too

* Trailing whitespace

---------

Co-authored-by: slaren <redacted>
examples/imatrix/imatrix.cpp

index 71e7a727f1943224e93052badfc2ec293bc255b1..82b19fc4f3bae0ff2b5a83b592d289b685164342 100644 (file)
@@ -19,6 +19,7 @@
 
 struct Stats {
     std::vector<float> values;
+    std::vector<int> counts;
     int ncall = 0;
 };
 
@@ -121,12 +122,10 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         auto & e = m_stats[wname];
 
         ++e.ncall;
-        // NOTE: since we select top-k experts, the number of calls for the expert tensors will be k times larger
-        //       using the following line, we can correct for that if needed by replacing the line above with:
-        //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
 
         if (e.values.empty()) {
             e.values.resize(src1->ne[0]*n_as, 0);
+            e.counts.resize(src1->ne[0]*n_as, 0);
         }
         else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
             fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
@@ -153,6 +152,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
 
                     for (int j = 0; j < (int)src1->ne[0]; ++j) {
                         e.values[e_start + j] += x[j]*x[j];
+                        e.counts[e_start + j]++;
                     }
                 }
             }
@@ -170,6 +170,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         auto& e = m_stats[wname];
         if (e.values.empty()) {
             e.values.resize(src1->ne[0], 0);
+            e.counts.resize(src1->ne[0], 0);
         }
         else if (e.values.size() != (size_t)src1->ne[0]) {
             fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
@@ -183,6 +184,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
             const float * x = data + row * src1->ne[0];
             for (int j = 0; j < (int)src1->ne[0]; ++j) {
                 e.values[j] += x[j]*x[j];
+                e.counts[j]++;
             }
         }
         if (e.ncall > m_last_call) {
@@ -222,7 +224,13 @@ void IMatrixCollector::save_imatrix(const char * fname, const char * dataset) co
         out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
         int nval = p.second.values.size();
         out.write((const char *) &nval, sizeof(nval));
-        if (nval > 0) out.write((const char *) p.second.values.data(), nval * sizeof(float));
+        if (nval > 0) {
+            std::vector<float> tmp(nval);
+            for (int i = 0; i < nval; i++) {
+                tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
+            }
+            out.write((const char*)tmp.data(), nval*sizeof(float));
+        }
     }
 
     // Write the number of call the matrix was computed with
@@ -270,14 +278,28 @@ bool IMatrixCollector::load_imatrix(const char * imatrix_file, std::unordered_ma
             imatrix_data = {};
             return false;
         }
-        e.values.resize(nval);
-        in.read((char*)e.values.data(), nval*sizeof(float));
+
+        // When re-called from load_imatrix() with add set, this will already be created.
+        if (e.values.empty()) {
+            e.values.resize(nval, 0);
+            e.counts.resize(nval, 0);
+        }
+
+        std::vector<float> tmp(nval);
+        in.read((char*)tmp.data(), nval*sizeof(float));
         if (in.fail()) {
             printf("%s: failed reading data for entry %d\n",__func__,i);
             imatrix_data = {};
             return false;
         }
-        e.ncall = ncall;
+
+        // Recreate the state as expected by save_imatrix(), and corerct for weighted sum.
+        for (int i = 0; i < nval; i++) {
+            e.values[i] += tmp[i];
+            e.counts[i] += ncall;
+        }
+        e.ncall += ncall;
+
     }
     return true;
 }