Add `--no-op-offload` to improve `-ot` pp perf in MoE models like llama4 400B (llama...
authorDavid Huang <redacted>
Sun, 11 May 2025 12:18:39 +0000 (20:18 +0800)
committerGeorgi Gerganov <redacted>
Tue, 13 May 2025 10:02:19 +0000 (13:02 +0300)
include/ggml-backend.h
src/ggml-backend.cpp
tests/test-opt.cpp
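
This commit threads a new boolean, op_offload, through ggml_backend_sched_new(). When a weight tensor lives in a host (CPU) buffer, the scheduler previously always asked higher-priority backends whether they wanted to offload the op; the new flag makes that probe optional. With tensor overrides (-ot) pinning MoE expert weights to the CPU, disabling the probe keeps those ops on the CPU and avoids copying the large expert weights to the GPU for every batch, which is where the prompt-processing speedup comes from. A minimal sketch of the updated call, assuming two GPU backends plus the mandatory CPU backend (variable names are illustrative):

    // Hypothetical caller: op_offload=false keeps ops on the backend that
    // owns their weights instead of offloading them to a faster backend.
    ggml_backend_t backends[] = { backend_gpu, backend_gpu2, backend_cpu };
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends, /*bufts=*/NULL, /*n_backends=*/3,
        GGML_DEFAULT_GRAPH_SIZE, /*parallel=*/false, /*op_offload=*/false);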

diff --git a/include/ggml-backend.h b/include/ggml-backend.h
index ea2c1a402cca102f5c4b3efd9e8a9b0e7593b443..778927f68217ae2eaad89b080a3f03829b4321e0 100644
--- a/include/ggml-backend.h
+++ b/include/ggml-backend.h
@@ -248,7 +248,7 @@ extern "C" {
         // preferably to run on the same backend as the buffer
         ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
 
-        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);
 
         // initialize buffers from a max size graph (optional)
         reserve_graph = build_graph(sched, max_batch_size);
@@ -289,7 +289,7 @@ extern "C" {
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
     // Initialize a backend scheduler; backends with a lower index are given priority over backends with a higher index
-    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload);
     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
 
     // Initialize backend buffers from a measure graph
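
Note that this is a breaking API change: every existing caller must be updated to pass the extra argument, with true preserving the previous behavior. For example (illustrative):

    // Passing true keeps the pre-change behavior of offloading ops on
    // host-resident weights to a higher-priority backend when profitable.
    sched = ggml_backend_sched_new(backends, NULL, n_backends,
                                   GGML_DEFAULT_GRAPH_SIZE, false, /*op_offload=*/true);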
diff --git a/src/ggml-backend.cpp b/src/ggml-backend.cpp
index c36b5abfb74224cfdfcd6d3ec9a1c4eb612aee78..6f69d895f170d039a1dcea8f89403c26e64ab8b5 100644
--- a/src/ggml-backend.cpp
+++ b/src/ggml-backend.cpp
@@ -674,6 +674,8 @@ struct ggml_backend_sched {
     char * context_buffer;
     size_t context_buffer_size;
 
+    bool op_offload;
+
     int debug;
 };
 
@@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
-            if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
+            if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
                 for (int b = 0; b < src_backend_id; b++) {
                     if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
                         SET_CAUSE(tensor, "1.off");
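
The gating happens in ggml_backend_sched_backend_id_from_cur(): a weight source that sits in a host buffer on the lowest-priority (CPU) backend is only offered to faster backends when op_offload is enabled. A condensed view of the decision (illustrative, not the verbatim source):

    // Only probe higher-priority backends when op offload is enabled.
    if (sched->op_offload && src_backend_id == sched->n_backends - 1 &&
            ggml_backend_buffer_is_host(src->buffer)) {
        for (int b = 0; b < src_backend_id; b++) {
            // ggml_backend_offload_op() lets a backend claim the op, e.g. a
            // GPU backend accepting large-batch matmuls on CPU weights.
            if (ggml_backend_supports_op(sched->backends[b], tensor) &&
                    ggml_backend_offload_op(sched->backends[b], tensor)) {
                return b; // run the op on the higher-priority backend
            }
        }
    }
    return src_backend_id; // otherwise, stay with the weight's backend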
@@ -1452,7 +1454,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
         ggml_backend_buffer_type_t * bufts,
         int n_backends,
         size_t graph_size,
-        bool parallel) {
+        bool parallel,
+        bool op_offload) {
     GGML_ASSERT(n_backends > 0);
     GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
     GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
@@ -1497,6 +1500,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
     }
 
     sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
+    sched->op_offload = op_offload;
 
     ggml_backend_sched_reset(sched);
 
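
A frontend can then expose the flag on the command line; a hedged sketch of llama.cpp-style wiring (the params.no_op_offload field and its flag parsing are hypothetical here):

    // Hypothetical CLI wiring: --no-op-offload sets params.no_op_offload,
    // which is inverted when constructing the scheduler.
    const bool op_offload = !params.no_op_offload;
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends.data(), nullptr, (int) backends.size(),
        GGML_DEFAULT_GRAPH_SIZE, /*parallel=*/false, op_offload);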
diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp
index f90c92b4b8ecfc8eff87c8dbbb9b368c55b8b1ce..1bc160511357186c827287e6a7a9c9f59cc47764 100644
--- a/tests/test-opt.cpp
+++ b/tests/test-opt.cpp
@@ -853,7 +853,7 @@ int main(void) {
         backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
 
         ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
-            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false);
+            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);
 
         printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
         printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));