]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
tests : sync test-grad0 from ggml
authorGeorgi Gerganov <redacted>
Sat, 24 Jun 2023 16:40:18 +0000 (19:40 +0300)
committerGeorgi Gerganov <redacted>
Sat, 24 Jun 2023 16:40:18 +0000 (19:40 +0300)
tests/test-grad0.c

index c8c2c0f717e320ff699190134d559b7420e6feac..b5a499c1db57e4dae69d92df31f5b4f32498b16a 100644 (file)
@@ -1,3 +1,4 @@
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
 #include "ggml.h"
 
 #include <math.h>
@@ -5,6 +6,10 @@
 #include <stdlib.h>
 #include <assert.h>
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 #define MAX_NARGS 3
 
 #undef MIN
@@ -197,8 +202,23 @@ bool check_gradient(
         float max_error_abs,
         float max_error_rel) {
 
+    static int n_threads = -1;
+    if (n_threads < 0) {
+        n_threads = GGML_DEFAULT_N_THREADS;
+
+        const char *env = getenv("GGML_N_THREADS");
+        if (env) {
+            n_threads = atoi(env);
+        }
+
+        printf("GGML_N_THREADS = %d\n", n_threads);
+    }
+
     struct ggml_cgraph gf = ggml_build_forward (f);
+    gf.n_threads = n_threads;
+
     struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+    gb.n_threads = n_threads;
 
     ggml_graph_compute(ctx0, &gf);
     ggml_graph_reset  (&gf);