]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : correct hparams comparison (#3446)
authorl3utterfly <redacted>
Fri, 6 Oct 2023 10:47:59 +0000 (18:47 +0800)
committerGitHub <redacted>
Fri, 6 Oct 2023 10:47:59 +0000 (13:47 +0300)
* fixed floating point comparison issues

* updated implementation for hparam comparison to handle inf and NaN

* fixed code review comments

* minor simplification

* rename is_float_eq -> is_float_close

---------

Co-authored-by: Cebtenzzre <redacted>
llama.cpp

index 08d6c162a5d7ce47aae104f4d60ea6bf2e0fcf49..56413f3a241ebf7389567140c8d1cbcea28c763f 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -125,6 +125,27 @@ static void replace_all(std::string & s, const std::string & search, const std::
     }
     s = std::move(result);
 }
+
+static bool is_float_close(float a, float b, float abs_tol) {
+    // Check for non-negative tolerance
+    if (abs_tol < 0.0) {
+        throw std::invalid_argument("Tolerance must be non-negative");
+    }
+
+    // Exact equality check
+    if (a == b) {
+        return true;
+    }
+
+    // Check for infinities
+    if (std::isinf(a) || std::isinf(b)) {
+        return false;
+    }
+
+    // Regular comparison using the provided absolute tolerance
+    return std::fabs(b - a) <= abs_tol;
+}
+
 #ifdef GGML_USE_CPU_HBM
 #include <hbwmalloc.h>
 #endif
@@ -969,7 +990,24 @@ struct llama_hparams {
     float rope_freq_scale_train;
 
     bool operator!=(const llama_hparams & other) const {
-        return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
+        if (this->vocab_only != other.vocab_only) return true;
+        if (this->n_vocab != other.n_vocab) return true;
+        if (this->n_ctx_train != other.n_ctx_train) return true;
+        if (this->n_embd != other.n_embd) return true;
+        if (this->n_head != other.n_head) return true;
+        if (this->n_head_kv != other.n_head_kv) return true;
+        if (this->n_layer != other.n_layer) return true;
+        if (this->n_rot != other.n_rot) return true;
+        if (this->n_ff != other.n_ff) return true;
+
+        const float EPSILON = 1e-9;
+
+        if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
+        if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
+        if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
+        if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
+
+        return false;
     }
 
     uint32_t n_gqa() const {