]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
tests : allow to set threads to test-grad0
authorGeorgi Gerganov <redacted>
Sat, 24 Jun 2023 16:39:32 +0000 (19:39 +0300)
committerGeorgi Gerganov <redacted>
Sat, 24 Jun 2023 16:39:32 +0000 (19:39 +0300)
.github/workflows/ci.yml
tests/test-grad0.c

index 070783c9437b4e046b2ace3c1cfdb82d8e57df82..8332deff48e86026c34d40e62e450c306d050694 100644 (file)
@@ -22,12 +22,12 @@ jobs:
     steps:
     - uses: actions/checkout@v2
 
-    - name: Set GGML_NTHREADS for Ubuntu
-      run: echo "GGML_NTHREADS=2" >> $GITHUB_ENV
+    - name: Set GGML_N_THREADS for Ubuntu
+      run: echo "GGML_N_THREADS=2" >> $GITHUB_ENV
       if: matrix.os == 'ubuntu-latest'
 
-    - name: Set GGML_NTHREADS for MacOS
-      run: echo "GGML_NTHREADS=3" >> $GITHUB_ENV
+    - name: Set GGML_N_THREADS for MacOS
+      run: echo "GGML_N_THREADS=2" >> $GITHUB_ENV
       if: matrix.os == 'macos-latest'
 
     - name: Create Build Environment
index b6371395239339772694af90a13c6d8b27a17212..b5a499c1db57e4dae69d92df31f5b4f32498b16a 100644 (file)
@@ -202,8 +202,23 @@ bool check_gradient(
         float max_error_abs,
         float max_error_rel) {
 
+    static int n_threads = -1;
+    if (n_threads < 0) {
+        n_threads = GGML_DEFAULT_N_THREADS;
+
+        const char *env = getenv("GGML_N_THREADS");
+        if (env) {
+            n_threads = atoi(env);
+        }
+
+        printf("GGML_N_THREADS = %d\n", n_threads);
+    }
+
     struct ggml_cgraph gf = ggml_build_forward (f);
+    gf.n_threads = n_threads;
+
     struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
+    gb.n_threads = n_threads;
 
     ggml_graph_compute(ctx0, &gf);
     ggml_graph_reset  (&gf);