]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
test-backend-ops : use flops for some performance tests (#9657)
authorslaren <redacted>
Sat, 28 Sep 2024 12:32:46 +0000 (14:32 +0200)
committerGitHub <redacted>
Sat, 28 Sep 2024 12:32:46 +0000 (14:32 +0200)
* test-backend-ops : use flops for some performance tests

- parallelize tensor quantization

- use a different set of cases for performance and correctness tests

- run each test for at least one second

tests/test-backend-ops.cpp

index 9a96cfc4c99de05e48de79dd929f4039f60ac3ea..d2cfe06b592cfc25a2753bc63462c1947d2ca092 100644 (file)
 #include <stdlib.h>
 #include <string>
 #include <thread>
+#include <future>
 #include <vector>
 
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
-    // static RNG initialization (revisit if n_threads stops being constant)
-    static const size_t n_threads = std::thread::hardware_concurrency();
-    static std::vector<std::default_random_engine> generators = []() {
-        std::random_device rd;
-        std::vector<std::default_random_engine> vec;
-        vec.reserve(n_threads);
-        //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
-        for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
-        return vec;
-    }();
-
-    size_t size = ggml_nelements(tensor);
-    std::vector<float> data(size);
+    size_t nels = ggml_nelements(tensor);
+    std::vector<float> data(nels);
+    {
+        // parallel initialization
+        static const size_t n_threads = std::thread::hardware_concurrency();
+        // static RNG initialization (revisit if n_threads stops being constant)
+        static std::vector<std::default_random_engine> generators = []() {
+            std::random_device rd;
+            std::vector<std::default_random_engine> vec;
+            vec.reserve(n_threads);
+            //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
+            for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
+            return vec;
+        }();
+
+        auto init_thread = [&](size_t ith, size_t start, size_t end) {
+            std::uniform_real_distribution<float> distribution(min, max);
+            auto & gen = generators[ith];
+            for (size_t i = start; i < end; i++) {
+                data[i] = distribution(gen);
+            }
+        };
 
-    auto init_thread = [&](size_t ith, size_t start, size_t end) {
-        std::uniform_real_distribution<float> distribution(min, max);
-        for (size_t i = start; i < end; i++) {
-            data[i] = distribution(generators[ith]);
+        std::vector<std::future<void>> tasks;
+        tasks.reserve(n_threads);
+        for (size_t i = 0; i < n_threads; i++) {
+            size_t start =     i*nels/n_threads;
+            size_t end   = (i+1)*nels/n_threads;
+            tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
         }
-    };
-
-    std::vector<std::thread> threads;
-    threads.reserve(n_threads);
-    for (size_t i = 0; i < n_threads; i++) {
-        size_t start =     i*size/n_threads;
-        size_t end   = (i+1)*size/n_threads;
-        threads.emplace_back(init_thread, i, start, end);
-    }
-    for (auto & t : threads) {
-        t.join();
-    }
-
-#if 0
-    const char * val_str = getenv("GGML_TEST_EPS");
-    float val = 1e-9f;
-    if (val_str != nullptr) {
-        val = std::stof(val_str);
-        printf("GGML_TEST_EPS=%e\n", val);
-    }
-
-    // test quantization with very small values that may result in nan scales due to division by zero
-    if (ggml_is_quantized(tensor->type)) {
-        for (int i = 0; i < 256; i++) {
-            data[i] = val;
+        for (auto & t : tasks) {
+            t.get();
         }
     }
-#endif
 
     if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
-        ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
+        ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
-        GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
-        std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
-        std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
+        GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0);
+
+         // dummy importance matrix
+        std::vector<float> imatrix(tensor->ne[0], 1.0f);
         const float * im = imatrix.data();
         if (!ggml_quantize_requires_imatrix(tensor->type)) {
             // when the imatrix is optional, we want to test both quantization with and without imatrix
@@ -98,15 +87,31 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
             }
         }
 
-        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
-        GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
-        // TODO: other cases
-        //#pragma omp parallel for
-        //for (int i = 0; i < tensor->ne[1]; i++) {
-        //    ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
-        //        i * tensor->ne[0], 1, tensor->ne[0], im);
-        //}
-
+        std::vector<uint8_t> dataq(ggml_row_size(tensor->type, nels));
+        {
+            // parallel quantization by block
+            size_t blck_size = ggml_blck_size(tensor->type);
+            size_t n_blocks = nels / blck_size;
+
+            auto quantize_thread = [&](size_t start, size_t end) {
+                ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
+                    start * blck_size, end - start, blck_size, im);
+            };
+
+            const size_t min_blocks_per_thread = 1;
+            const size_t n_threads = std::min<size_t>(std::thread::hardware_concurrency()/2,
+                                                      std::max<size_t>(1, n_blocks / min_blocks_per_thread));
+            std::vector<std::future<void>> tasks;
+            tasks.reserve(n_threads);
+            for (size_t i = 0; i < n_threads; i++) {
+                size_t start =     i*n_blocks/n_threads;
+                size_t end   = (i+1)*n_blocks/n_threads;
+                tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
+            }
+            for (auto & t : tasks) {
+                t.get();
+            }
+        }
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
@@ -160,60 +165,6 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
     return tv;
 }
 
-/*
-static double cosine_similarity(const float * v1, const float * v2, size_t n) {
-    double dot = 0.0;
-    double mag1 = 0.0;
-    double mag2 = 0.0;
-
-    for (size_t i = 0; i < n; i++) {
-        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
-            return -1.0f;
-        }
-        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
-            continue;
-        }
-        dot  += v1[i]*v2[i];
-        mag1 += v1[i]*v1[i];
-        mag2 += v2[i]*v2[i];
-    }
-
-    return dot/sqrt(mag1*mag2);
-}
-
-static float distance(const float * v1, const float * v2, size_t n) {
-    double d = 0.0;
-
-    for (size_t i = 0; i < n; i++) {
-        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
-            return INFINITY;
-        }
-        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
-            continue;
-        }
-        d += (v1[i] - v2[i])*(v1[i] - v2[i]);
-    }
-
-    return sqrt(d);
-}
-
-static float vec_len(const float * v, size_t n) {
-    double d = 0.0;
-
-    for (size_t i = 0; i < n; i++) {
-        if (std::isnan(v[i])) {
-            return INFINITY;
-        }
-        if (std::isinf(v[i])) {
-            continue;
-        }
-        d += v[i]*v[i];
-    }
-
-    return sqrt(d);
-}
-*/
-
 // normalized mean squared error = mse(a, b) / mse(a, 0)
 static double nmse(const float * a, const float * b, size_t n) {
     double mse_a_b = 0.0;
@@ -264,7 +215,6 @@ static double mean_abs_asymm(const float * a, const float * b, const size_t n, c
 }
 
 // utils for printing the variables of the test cases
-#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
 template<typename T>
 static std::string var_to_str(const T & x) {
@@ -297,10 +247,6 @@ static std::string var_to_str(const std::array<T, N> & x) {
     return s;
 }
 
-//static std::string var_to_str(ggml_unary_op unary_op) {
-//    return ggml_unary_op_name(unary_op);
-//}
-
 static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }
@@ -313,6 +259,8 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }
 
+#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
+
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
 #define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
 #define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
@@ -370,13 +318,13 @@ struct test_case {
         return 1e-4;
     }
 
-    virtual float grad_eps(){
+    virtual float grad_eps() {
         return 1e-1f;
     }
 
     // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.
     // If true,  estimate gradient with 4 points, neglects 5th order derivative and higher.
-    virtual bool grad_precise(){
+    virtual bool grad_precise() {
         return false;
     }
 
@@ -409,6 +357,11 @@ struct test_case {
         return size;
     }
 
+    virtual uint64_t op_flops(ggml_tensor * t) {
+        GGML_UNUSED(t);
+        return 0;
+    }
+
     ggml_cgraph * gf = nullptr;
     ggml_cgraph * gb = nullptr;
 
@@ -651,12 +604,11 @@ struct test_case {
         }
 
         // align while also leaving some margin for variations in parameters
-        int align = 20;
+        int align = 8;
         int last = (len + align - 1) / align * align;
         if (last - len < 5) {
             last += align;
         }
-        last = std::max(last, 60);
         printf("%*s", last - len, "");
 
         // allocate
@@ -677,9 +629,25 @@ struct test_case {
         // warmup run
         ggml_backend_graph_compute(backend, gf);
 
+        // determine number of runs
+        int n_runs;
+        if (op_flops(out) > 0) {
+            // based on flops
+            const uint64_t GFLOP = 1000 * 1000 * 1000;
+            const uint64_t target_flops_cpu =   8ULL * GFLOP;
+            const uint64_t target_flops_gpu = 100ULL * GFLOP;
+            uint64_t target_flops = ggml_backend_is_cpu(backend) ? target_flops_cpu : target_flops_gpu;
+            n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
+        } else {
+            // based on memory size
+            const size_t GB = 1ULL << 30;
+            const size_t target_size_cpu =  8 * GB;
+            const size_t target_size_gpu = 32 * GB;
+            size_t target_size = ggml_backend_is_cpu(backend) ? target_size_cpu : target_size_gpu;
+            n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
+        }
+
         // duplicate the op
-        size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
-        int n_runs = std::min((size_t) ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
         for (int i = 1; i < n_runs; i++) {
             ggml_graph_add_node(gf, out);
         }
@@ -706,17 +674,47 @@ struct test_case {
         // run
         ggml_backend_synchronize(backend);
 
-        int64_t start_time = ggml_time_us();
-        ggml_backend_graph_compute(backend, gf);
-        ggml_backend_synchronize(backend);
-        int64_t end_time = ggml_time_us();
-        double time_us = end_time - start_time;
+        int64_t total_time_us = 0;
+        int total_runs = 0;
+        do {
+            int64_t start_time = ggml_time_us();
+            ggml_backend_graph_compute(backend, gf);
+            ggml_backend_synchronize(backend);
+            int64_t end_time = ggml_time_us();
+
+            total_time_us += end_time - start_time;
+            total_runs += n_runs;
+        } while (total_time_us < 1000*1000); // run for at least 1 second
+
+        printf("    %8d runs - %8.2f us/run - ",
+            total_runs,
+            (double)total_time_us / total_runs);
+
+        if (op_flops(out) > 0) {
+            double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6);
+            auto format_flops = [](double flops) -> std::string {
+                char buf[256];
+                if (flops >= 1e12) {
+                    snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12);
+                } else if (flops >= 1e9) {
+                    snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9);
+                } else if (flops >= 1e6) {
+                    snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6);
+                } else {
+                    snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3);
+                }
+                return buf;
+            };
+            printf("%s/run - \033[1;34m%sS\033[0m",
+                format_flops(op_flops(out)).c_str(),
+                format_flops(flops_per_sec).c_str());
 
-        printf("    %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n",
-            n_runs,
-            time_us / n_runs,
-            op_size(out) / 1024,
-            mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0);
+        } else {
+            printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
+                op_size(out) / 1024,
+                mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
+        }
+        printf("\n");
 
         ggml_backend_buffer_free(buf);
 
@@ -1591,13 +1589,9 @@ struct test_mul_mat : public test_case {
         return 5e-4;
     }
 
-    size_t op_size(ggml_tensor * t) override {
-        size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1];
-        size_t b = ggml_nbytes(t->src[1]) * m;
-        size_t c  = ggml_nbytes(t);
-        return a + b + c;
-
+    uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
+        return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
     }
 
     test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
@@ -1641,13 +1635,9 @@ struct test_mul_mat_id : public test_case {
         return 5e-4;
     }
 
-    size_t op_size(ggml_tensor * t) override {
-        size_t a = ggml_nbytes(t->src[2]) * n;
-        size_t b = ggml_nbytes(t->src[1]) * m;
-        size_t c  = ggml_nbytes(t);
-        return a + b + c;
-
+    uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
+        return 2 * m * k * n * n_used;
     }
 
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
@@ -3163,47 +3153,46 @@ struct test_falcon : public test_llm {
 // ###########################################
 // ## Section 3: GGML Op Test Instantiation ##
 // ###########################################
+static const ggml_type all_types[] = {
+    GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
+    GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+    GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+    GGML_TYPE_Q6_K,
+    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
+    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
+    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+};
+
+static const ggml_type base_types[] = {
+    GGML_TYPE_F32, GGML_TYPE_F16,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_K,
+    GGML_TYPE_IQ2_XXS
+};
 
+static const ggml_type other_types[] = {
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+    GGML_TYPE_Q5_K,
+    GGML_TYPE_Q6_K,
+    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
+    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
+    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+    GGML_TYPE_BF16,
+};
 
-static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
+// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
+static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     std::vector<std::unique_ptr<test_case>> test_cases;
     std::default_random_engine rng(0);
 
-    const ggml_type all_types[] = {
-        GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
-        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
-        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
-        GGML_TYPE_Q8_0,
-        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
-        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K,
-        // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
-        GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
-        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
-        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
-    };
-
-    const ggml_type base_types[] = {
-        GGML_TYPE_F32, GGML_TYPE_F16,
-        GGML_TYPE_Q4_0,
-        GGML_TYPE_Q4_K,
-        GGML_TYPE_IQ2_XXS
-    };
-
-    const ggml_type other_types[] = {
-        GGML_TYPE_Q4_1,
-        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
-        GGML_TYPE_Q8_0,
-        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
-        GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K,
-        // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
-        GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
-        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
-        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
-        GGML_TYPE_BF16,
-    };
-
     // unary ops
     for (int v : {0, 1}) {
         for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
@@ -3392,6 +3381,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
         }
     }
+    for (ggml_type type_a : other_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            if (ggml_blck_size(type_a) != 256) {
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1,  1}, {1, 1}));
+            }
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1,  1}, {1, 1}));
+        }
+    }
 #else
     // m = a rows
     // n = b rows
@@ -3411,15 +3408,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 #endif
 
-    for (ggml_type type_a : other_types) {
-        for (ggml_type type_b : {GGML_TYPE_F32}) {
-            if (ggml_blck_size(type_a) != 256) {
-                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1,  1}, {1, 1}));
-            }
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1,  1}, {1, 1}));
-        }
-    }
-
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));
@@ -3624,20 +3612,30 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_falcon(2));
 #endif
 
-    // run tests
-    if (mode == MODE_GRAD) {
-        size_t n_ok = 0;
-        for (auto & test : test_cases) {
-            if (test->eval_grad(backend, op_name)) {
-                n_ok++;
+    return test_cases;
+}
+
+// Test cases for performance evaluation: should be representative of real-world use cases
+static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
+    std::vector<std::unique_ptr<test_case>> test_cases;
+
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1,   1, 1, 1}));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
+
+    for (int bs : {1, 512}) {
+        for (ggml_type type_a : all_types) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1,  1}, {1, 1}));
             }
         }
-        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
-
-        return n_ok == test_cases.size();
     }
 
+    return test_cases;
+}
+
+static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     if (mode == MODE_TEST) {
+        auto test_cases = make_test_cases_eval();
         ggml_backend_t backend_cpu = ggml_backend_cpu_init();
 
         size_t n_ok = 0;
@@ -3653,7 +3651,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         return n_ok == test_cases.size();
     }
 
+    if (mode == MODE_GRAD) {
+        auto test_cases = make_test_cases_eval();
+        size_t n_ok = 0;
+        for (auto & test : test_cases) {
+            if (test->eval_grad(backend, op_name)) {
+                n_ok++;
+            }
+        }
+        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+
+        return n_ok == test_cases.size();
+    }
+
     if (mode == MODE_PERF) {
+        auto test_cases = make_test_cases_perf();
         for (auto & test : test_cases) {
             test->eval_perf(backend, op_name);
         }
@@ -3667,9 +3679,9 @@ static void usage(char ** argv) {
     printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
     printf("    valid modes:\n");
     printf("      - test (default, compare with CPU backend for correctness)\n");
-    printf("      - perf (performance evaluation)\n");
     printf("      - grad (compare gradients from backpropagation with method of finite differences)\n");
-    printf("    op names are as given by ggml_op_desc() (e.g. GGML_ADD)\n");
+    printf("      - perf (performance evaluation)\n");
+    printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
 }
 
 int main(int argc, char ** argv) {
@@ -3728,6 +3740,11 @@ int main(int argc, char ** argv) {
             continue;
         }
 
+        if (ggml_backend_is_cpu(backend)) {
+            // TODO: better value for n_threads
+            ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
+        }
+
         printf("  Backend name: %s\n", ggml_backend_name(backend));
 
         bool ok = test_backend(backend, mode, op_name_filter);