]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Docs: script to auto-generate ggml operations docs (llama/14598)
authorAman Gupta <redacted>
Thu, 10 Jul 2025 15:29:01 +0000 (23:29 +0800)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* Docs: script to auto-generate ggml operations docs

* Review: formatting changes + change github action

* Use built-in types instead of typing

* docs : add BLAS and Metal ops

---------

Co-authored-by: Georgi Gerganov <redacted>
tests/test-backend-ops.cpp

index 4eeeb6e43a40027dbfad22b52270177f89dc2dae..a743ad6e8c10dc8783f3532eed8947cfaa1e800f 100644 (file)
@@ -317,10 +317,11 @@ enum test_mode {
     MODE_TEST,
     MODE_PERF,
     MODE_GRAD,
+    MODE_SUPPORT,
 };
 
 // Output format support similar to llama-bench
-enum output_formats { CONSOLE, SQL };
+enum output_formats { CONSOLE, SQL, CSV };
 
 static const char * output_format_str(output_formats format) {
     switch (format) {
@@ -328,6 +329,8 @@ static const char * output_format_str(output_formats format) {
             return "console";
         case SQL:
             return "sql";
+        case CSV:
+            return "csv";
         default:
             GGML_ABORT("invalid output format");
     }
@@ -338,6 +341,8 @@ static bool output_format_from_str(const std::string & s, output_formats & forma
         format = CONSOLE;
     } else if (s == "sql") {
         format = SQL;
+    } else if (s == "csv") {
+        format = CSV;
     } else {
         return false;
     }
@@ -360,6 +365,8 @@ struct test_result {
     double      bandwidth_gb_s;
     size_t      memory_kb;
     int         n_runs;
+    std::string device_description;
+    std::string backend_reg_name;
 
     test_result() {
         // Initialize with default values
@@ -384,7 +391,7 @@ struct test_result {
     test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params,
                 const std::string & test_mode, bool supported, bool passed, const std::string & error_message = "",
                 double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0,
-                int n_runs = 0) :
+                int n_runs = 0, const std::string & device_description = "", const std::string & backend_reg_name = "") :
         backend_name(backend_name),
         op_name(op_name),
         op_params(op_params),
@@ -396,7 +403,9 @@ struct test_result {
         flops(flops),
         bandwidth_gb_s(bandwidth_gb_s),
         memory_kb(memory_kb),
-        n_runs(n_runs) {
+        n_runs(n_runs),
+        device_description(device_description),
+        backend_reg_name(backend_reg_name) {
         // Set test time
         time_t t = time(NULL);
         char   buf[32];
@@ -410,7 +419,8 @@ struct test_result {
     static const std::vector<std::string> & get_fields() {
         static const std::vector<std::string> fields = {
             "test_time", "build_commit",  "backend_name", "op_name", "op_params",      "test_mode", "supported",
-            "passed",    "error_message", "time_us",      "flops",   "bandwidth_gb_s", "memory_kb", "n_runs"
+            "passed",    "error_message", "time_us",      "flops",   "bandwidth_gb_s", "memory_kb", "n_runs",
+            "device_description", "backend_reg_name"
         };
         return fields;
     }
@@ -444,7 +454,9 @@ struct test_result {
                  std::to_string(flops),
                  std::to_string(bandwidth_gb_s),
                  std::to_string(memory_kb),
-                 std::to_string(n_runs) };
+                 std::to_string(n_runs),
+                 device_description,
+                 backend_reg_name };
     }
 };
 
@@ -633,6 +645,8 @@ struct console_printer : public printer {
             print_test_console(result);
         } else if (result.test_mode == "perf") {
             print_perf_console(result);
+        } else if (result.test_mode == "support") {
+            print_support_console(result);
         }
     }
 
@@ -799,6 +813,17 @@ struct console_printer : public printer {
         }
         printf("\n");
     }
+
+    void print_support_console(const test_result & result) {
+        printf("  %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
+        fflush(stdout);
+
+        if (result.supported) {
+            printf("\033[1;32mSUPPORTED\033[0m\n");
+        } else {
+            printf("\033[1;31mNOT SUPPORTED\033[0m\n");
+        }
+    }
 };
 
 struct sql_printer : public printer {
@@ -841,12 +866,39 @@ struct sql_printer : public printer {
     }
 };
 
+struct csv_printer : public printer {
+    void print_header() override {
+        std::vector<std::string> fields = test_result::get_fields();
+        for (size_t i = 0; i < fields.size(); i++) {
+            printf("\"%s\"%s", fields[i].c_str(), i < fields.size() - 1 ? "," : "");
+        }
+        printf("\n");
+    }
+
+    void print_test_result(const test_result & result) override {
+        std::vector<std::string> values = result.get_values();
+        for (size_t i = 0; i < values.size(); i++) {
+            // Escape quotes and wrap in quotes for CSV
+            std::string escaped_value = values[i];
+            size_t pos = 0;
+            while ((pos = escaped_value.find("\"", pos)) != std::string::npos) {
+                escaped_value.replace(pos, 1, "\"\"");
+                pos += 2;
+            }
+            printf("\"%s\"%s", escaped_value.c_str(), i < values.size() - 1 ? "," : "");
+        }
+        printf("\n");
+    }
+};
+
 static std::unique_ptr<printer> create_printer(output_formats format) {
     switch (format) {
         case CONSOLE:
             return std::make_unique<console_printer>();
         case SQL:
             return std::make_unique<sql_printer>();
+        case CSV:
+            return std::make_unique<csv_printer>();
     }
     GGML_ABORT("invalid output format");
 }
@@ -928,7 +980,7 @@ struct test_case {
     std::vector<ggml_tensor *> sentinels;
 
     void add_sentinel(ggml_context * ctx) {
-        if (mode == MODE_PERF || mode == MODE_GRAD) {
+        if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
             return;
         }
         ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
@@ -1153,15 +1205,12 @@ struct test_case {
             return true;
         }
 
-        // check if backends support op
         if (!ggml_backend_supports_op(backend, out)) {
             // Create test result for unsupported performance test
             test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false,
                                "not supported");
 
-            if (output_printer) {
-                output_printer->print_test_result(result);
-            }
+            output_printer->print_test_result(result);
 
             return true;
         }
@@ -1266,6 +1315,38 @@ struct test_case {
         return true;
     }
 
+    bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) {
+        mode = MODE_SUPPORT;
+
+        static const size_t graph_nodes = 8192;
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
+        GGML_ASSERT(ctx);
+
+        ggml_tensor * out             = build_graph(ctx.get());
+        std::string   current_op_name = op_desc(out);
+        if (op_name != nullptr && current_op_name != op_name) {
+            return true;
+        }
+
+        bool supported = ggml_backend_supports_op(backend, out);
+
+        std::string device_desc = ggml_backend_dev_description(ggml_backend_get_device(backend));
+        std::string backend_reg_name = ggml_backend_reg_name(ggml_backend_dev_backend_reg(ggml_backend_get_device(backend)));
+
+        test_result result(ggml_backend_name(backend), current_op_name, vars(), "support", supported, supported,
+                           supported ? "yes" : "no", 0.0, 0.0, 0.0, 0, 0, device_desc, backend_reg_name);
+
+        output_printer->print_test_result(result);
+
+        return true;
+    }
+
     bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) {
         mode = MODE_GRAD;
         const std::vector<float> expect = grad_expect();
@@ -5599,17 +5680,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         return true;
     }
 
+    if (mode == MODE_SUPPORT) {
+        auto test_cases = make_test_cases_eval();
+        filter_test_cases(test_cases, params_filter);
+        for (auto & test : test_cases) {
+            test->eval_support(backend, op_name, output_printer);
+        }
+        return true;
+    }
+
     GGML_ABORT("fatal error");
 }
 
 static void usage(char ** argv) {
-    printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql>]\n", argv[0]);
+    printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
     printf("    valid modes:\n");
     printf("      - test (default, compare with CPU backend for correctness)\n");
     printf("      - grad (compare gradients from backpropagation with method of finite differences)\n");
     printf("      - perf (performance evaluation)\n");
+    printf("      - support (probe backend operation support)\n");
     printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
-    printf("    --output specifies output format (default: console)\n");
+    printf("    --output specifies output format (default: console, options: console, sql, csv)\n");
 }
 
 int main(int argc, char ** argv) {
@@ -5626,6 +5717,8 @@ int main(int argc, char ** argv) {
             mode = MODE_PERF;
         } else if (strcmp(argv[i], "grad") == 0) {
             mode = MODE_GRAD;
+        } else if (strcmp(argv[i], "support") == 0) {
+            mode = MODE_SUPPORT;
         } else if (strcmp(argv[i], "-o") == 0) {
             if (i + 1 < argc) {
                 op_name_filter = argv[++i];