]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper : load the model into multiple buffers of max size 1GB (#1763)
authorGeorgi Gerganov <redacted>
Sat, 13 Jan 2024 15:47:40 +0000 (17:47 +0200)
committerGitHub <redacted>
Sat, 13 Jan 2024 15:47:40 +0000 (17:47 +0200)
whisper.cpp

index ca39b58ac0f9bc7b2db01858843508f58d070517..2d8a87e3ad67df8c83c8f529ed2beb67c15fe00d 100644 (file)
@@ -701,7 +701,7 @@ struct whisper_model {
     struct ggml_context * ctx;
 
     // the model backend data is read-only and can be shared between processors
-    struct ggml_backend_buffer * buffer;
+    std::vector<struct ggml_backend_buffer *> buffers;
 
     // tensors
     int n_loaded;
@@ -1514,24 +1514,64 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
 
     wctx.backend = whisper_backend_init(wctx.params);
 
+    // some devices have a limit on the maximum size of single memory buffer
+    // for example, iPhones are limited to 1GB per buffer
+    // to workaround this, we will allocate multiple buffers of smaller size and will split the tensors with the
+    // model weights between them
+    //
+    // the map_t2b maps tensor names to buffer indices
+    // as we iterate over the tensors, we will allocate new buffers when the current one is full
+    //
+    // finally, we create a separate allocator for each buffer and use it to allocate the tensors
+    // we keep the allocators alive until all the tensors are loaded
+
+    GGML_ASSERT(model.buffers.empty());
+
+    std::map<std::string, int> map_t2b;
+
     {
         size_t size_main = 0;
+        size_t size_cur  = 0;
+
+        static const size_t GB = 1024ull*1024ull*1024ull;
 
         for (const auto & t : model.tensors) {
-            size_main += ggml_nbytes(t.second) + ggml_tensor_overhead();
+            const size_t cur = ggml_nbytes(t.second) + ggml_tensor_overhead();
+
+            // adding the tensor to the current buffer will exceed the limit, so we need to allocate a new buffer
+            if (size_cur + cur > GB) {
+                GGML_ASSERT(size_cur > 0 && "A tensor is too large to fit in a single buffer");
+
+                model.buffers.emplace_back(ggml_backend_alloc_buffer(wctx.backend, size_cur));
+
+                size_cur = cur;
+            }
+
+            map_t2b[t.first] = model.buffers.size();
+
+            size_cur  += cur;
+            size_main += cur;
+        }
+
+        // allocate the last buffer if needed
+        if (size_cur > 0) {
+            model.buffers.emplace_back(ggml_backend_alloc_buffer(wctx.backend, size_cur));
         }
 
-        model.buffer = ggml_backend_alloc_buffer(wctx.backend, size_main);
+        GGML_ASSERT(model.buffers.size() > 0);
 
-        WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6);
+        WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB (%d buffers)\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6, (int) model.buffers.size());
     }
 
-    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
+    std::vector<ggml_allocr *> allocs(model.buffers.size());
+    for (size_t i = 0; i < allocs.size(); ++i) {
+        allocs[i] = ggml_allocr_new_from_buffer(model.buffers[i]);
+    }
 
     // allocate tensors in the backend buffers
     {
         for (const auto & t : model.tensors) {
-            ggml_allocr_alloc(alloc, t.second);
+            ggml_allocr_alloc(allocs[map_t2b[t.first]], t.second);
         }
     }
 
@@ -1632,7 +1672,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    ggml_allocr_free(alloc);
+    for (auto & alloc : allocs) {
+        ggml_allocr_free(alloc);
+    }
 
     wctx.t_load_us = ggml_time_us() - t_start_us;
 
@@ -3376,8 +3418,10 @@ void whisper_free(struct whisper_context * ctx) {
             ggml_free(ctx->model.ctx);
         }
 
-        if (ctx->model.buffer) {
-            ggml_backend_buffer_free(ctx->model.buffer);
+        for (auto & buffer : ctx->model.buffers) {
+            if (buffer) {
+                ggml_backend_buffer_free(buffer);
+            }
         }
 
         whisper_free_state(ctx->state);