]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
test-backend-ops : extend test case filtering (#14865)
authorLeonard Mosescu <redacted>
Mon, 28 Jul 2025 16:04:27 +0000 (09:04 -0700)
committerGitHub <redacted>
Mon, 28 Jul 2025 16:04:27 +0000 (18:04 +0200)
* Extend test case filtering

1. Allow passing multiple (comma-separated?) ops to test-backend-ops. This can be convenient when working on a set of ops, when you'd want to test them together (but without having to run every single op). For example:

`test-backend-ops.exe test -o "ADD,RMS_NORM,ROPE,SILU,SOFT_MAX"`

2. Support full test-case variation string in addition to basic op names. This would make it easy to select a single variation, either for testing or for benchmarking. It can be particularly useful for profiling a particular variation (ex. a CUDA kernel), for example:

`test-backend-ops.exe perf -b CUDA0 -o "MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=2)"`

These two can be combined. As the current `-o`, this change doesn't try to detect/report an error if an filter doesn't name existing ops (ex. misspelled)

* Updating the usage help text

* Update tests/test-backend-ops.cpp

tests/test-backend-ops.cpp

index 7fb02a78899a63a1a4fa4547a7ebf3f145930ac2..3cc318a9cf96004871fd2c6fe2697e4b7dc6cd98 100644 (file)
@@ -35,6 +35,7 @@
 #include <random>
 #include <regex>
 #include <string>
+#include <string_view>
 #include <thread>
 #include <vector>
 
@@ -1047,7 +1048,37 @@ struct test_case {
         return t;
     }
 
-    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
+    // Checks an op against the test filter, which is a comma separated list of OP names or specific variations
+    bool matches_filter(ggml_tensor * op, const char * op_names_filter) {
+        if (op_names_filter) {
+            const auto op_name = op_desc(op);
+            const auto op_full_name = op_name + "(" + vars() + ")";
+            std::string_view filter(op_names_filter);
+            while (!filter.empty()) {
+                auto comma_pos = filter.find_first_of(',');
+                const auto lparen_pos = filter.find_first_of('(');
+                if (lparen_pos < comma_pos) {
+                    auto rparen_pos = filter.find_first_of(')');
+                    comma_pos = filter.find_first_of(',', rparen_pos);
+                    const auto op_filter = filter.substr(0, comma_pos);
+                    if (op_filter == op_full_name) {
+                        return true;
+                    }
+                } else {
+                    const auto op_filter = filter.substr(0, comma_pos);
+                    if (op_filter == op_name) {
+                        return true;
+                    }
+                }
+                filter = comma_pos != std::string_view::npos ? filter.substr(comma_pos + 1) : "";
+            }
+            return false;
+        } else {
+            return true;
+        }
+    }
+
+    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
         mode = MODE_TEST;
 
         ggml_init_params params = {
@@ -1065,7 +1096,7 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
         std::string current_op_name = op_desc(out);
-        if (op_name != nullptr && current_op_name != op_name) {
+        if (!matches_filter(out, op_names_filter)) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
@@ -1212,7 +1243,7 @@ struct test_case {
         return test_passed;
     }
 
-    bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) {
+    bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
         mode = MODE_PERF;
 
         static const size_t graph_nodes = 8192;
@@ -1227,7 +1258,7 @@ struct test_case {
 
         ggml_tensor * out             = build_graph(ctx.get());
         std::string   current_op_name = op_desc(out);
-        if (op_name != nullptr && current_op_name != op_name) {
+        if (!matches_filter(out, op_names_filter)) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
             return true;
         }
@@ -1342,7 +1373,7 @@ struct test_case {
         return true;
     }
 
-    bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) {
+    bool eval_support(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
         mode = MODE_SUPPORT;
 
         static const size_t graph_nodes = 8192;
@@ -1357,7 +1388,7 @@ struct test_case {
 
         ggml_tensor * out             = build_graph(ctx.get());
         std::string   current_op_name = op_desc(out);
-        if (op_name != nullptr && current_op_name != op_name) {
+        if (!matches_filter(out, op_names_filter)) {
             return true;
         }
 
@@ -1374,7 +1405,7 @@ struct test_case {
         return true;
     }
 
-    bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) {
+    bool eval_grad(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
         mode = MODE_GRAD;
         const std::vector<float> expect = grad_expect();
 
@@ -1391,7 +1422,7 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx.get());
 
-        if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
+        if (!matches_filter(out, op_names_filter) || out->op == GGML_OP_OPT_STEP_ADAMW) {
             return true;
         }
 
@@ -5922,7 +5953,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     return test_cases;
 }
 
-static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter,
+static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
                          printer * output_printer) {
     auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
         if (params_filter == nullptr) {
@@ -5954,7 +5985,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
         size_t n_ok = 0;
         for (auto & test : test_cases) {
-            if (test->eval(backend, backend_cpu, op_name, output_printer)) {
+            if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
                 n_ok++;
             }
         }
@@ -5970,7 +6001,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         filter_test_cases(test_cases, params_filter);
         size_t n_ok = 0;
         for (auto & test : test_cases) {
-            if (test->eval_grad(backend, op_name, output_printer)) {
+            if (test->eval_grad(backend, op_names_filter, output_printer)) {
                 n_ok++;
             }
         }
@@ -5983,7 +6014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         auto test_cases = make_test_cases_perf();
         filter_test_cases(test_cases, params_filter);
         for (auto & test : test_cases) {
-            test->eval_perf(backend, op_name, output_printer);
+            test->eval_perf(backend, op_names_filter, output_printer);
         }
         return true;
     }
@@ -5992,7 +6023,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         auto test_cases = make_test_cases_eval();
         filter_test_cases(test_cases, params_filter);
         for (auto & test : test_cases) {
-            test->eval_support(backend, op_name, output_printer);
+            test->eval_support(backend, op_names_filter, output_printer);
         }
         return true;
     }
@@ -6001,20 +6032,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 }
 
 static void usage(char ** argv) {
-    printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
+    printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
     printf("    valid modes:\n");
     printf("      - test (default, compare with CPU backend for correctness)\n");
     printf("      - grad (compare gradients from backpropagation with method of finite differences)\n");
     printf("      - perf (performance evaluation)\n");
     printf("      - support (probe backend operation support)\n");
-    printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
+    printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\n");
+    printf("        optionally including the full test case string (e.g. \"ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\")\n");
     printf("    --output specifies output format (default: console, options: console, sql, csv)\n");
 }
 
 int main(int argc, char ** argv) {
     test_mode mode = MODE_TEST;
     output_formats output_format = CONSOLE;
-    const char * op_name_filter = nullptr;
+    const char * op_names_filter = nullptr;
     const char * backend_filter = nullptr;
     const char * params_filter = nullptr;
 
@@ -6029,7 +6061,7 @@ int main(int argc, char ** argv) {
             mode = MODE_SUPPORT;
         } else if (strcmp(argv[i], "-o") == 0) {
             if (i + 1 < argc) {
-                op_name_filter = argv[++i];
+                op_names_filter = argv[++i];
             } else {
                 usage(argv);
                 return 1;
@@ -6110,7 +6142,7 @@ int main(int argc, char ** argv) {
                                                              false, "", ggml_backend_dev_description(dev),
                                                              total / 1024 / 1024, free / 1024 / 1024, true));
 
-        bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());
+        bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get());
 
         if (ok) {
             n_ok++;