]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : alloc ggml_contexts on the heap (#2525)
authorGeorgi Gerganov <redacted>
Thu, 31 Oct 2024 20:00:09 +0000 (22:00 +0200)
committerGitHub <redacted>
Thu, 31 Oct 2024 20:00:09 +0000 (22:00 +0200)
* whisper : reduce ggml_context usage

* ggml : allocate contexts on the heap (v2)

* ggml : aligned malloc -> malloc

ggml/include/ggml.h
ggml/src/ggml-metal.m
ggml/src/ggml.c
src/whisper.cpp

index e7678d07135ab7c6c47c8cf8b00a7fbd23748c06..8d36b3d4d42e38aaed193e74636d4c38c9491897 100644 (file)
 
 #define GGML_MAX_DIMS           4
 #define GGML_MAX_PARAMS         2048
-#define GGML_MAX_CONTEXTS       64
 #define GGML_MAX_SRC            10
 #define GGML_MAX_N_THREADS      512
 #define GGML_MAX_OP_PARAMS      64
@@ -657,6 +656,7 @@ extern "C" {
     };
 
     // scratch buffer
+    // TODO: deprecate and remove
     struct ggml_scratch {
         size_t offs;
         size_t size;
@@ -760,8 +760,9 @@ extern "C" {
 
     // main
 
-    GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
-    GGML_API void                  ggml_free(struct ggml_context * ctx);
+    GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);
+    GGML_API void                  ggml_reset(struct ggml_context * ctx);
+    GGML_API void                  ggml_free (struct ggml_context * ctx);
 
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 
index 7baee41749ae75920ed5d992c34afada7300adf3..7e0b866a99b938d15f5256fbb9c85781b3048b29 100644 (file)
@@ -3129,7 +3129,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
 // default buffer
 static id<MTLDevice> g_backend_device = nil;
-static int g_backend_device_ref_count = 0;
+static int g_backend_device_ref_count = 0; // TODO: make thread-safe
 
 static id<MTLDevice> ggml_backend_metal_get_device(void) {
     if (g_backend_device == nil) {
index 03b832d0f2189ea224e8db91f89df60f1e7db81a..264ffb5195ebf0f7d525bb0ef5a58d2ccfc87c37 100644 (file)
@@ -308,6 +308,7 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
 }
 
 #define GGML_DEBUG 0
+
 #define GGML_GELU_FP16
 #define GGML_GELU_QUICK_FP16
 
@@ -1985,7 +1986,7 @@ static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
 
 struct ggml_context {
     size_t mem_size;
-    void* mem_buffer;
+    void * mem_buffer;
     bool   mem_buffer_owned;
     bool   no_alloc;
     bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
@@ -3234,7 +3235,6 @@ struct ggml_numa_nodes {
 //
 
 struct ggml_state {
-    struct ggml_context_container contexts[GGML_MAX_CONTEXTS];
     struct ggml_numa_nodes numa;
 };
 
@@ -3816,17 +3816,12 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
             g_state = (struct ggml_state) {
-                /*.contexts =*/ { { 0 } },
                 /*.numa =*/ {
                     .n_nodes = 0,
                     .total_cpus = 0,
                 },
             };
 
-            for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) {
-                g_state.contexts[i].used = false;
-            }
-
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
             GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
@@ -3839,26 +3834,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         is_first_call = false;
     }
 
-    // find non-used context in g_state
-    struct ggml_context * ctx = NULL;
-
-    for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
-        if (!g_state.contexts[i].used) {
-            g_state.contexts[i].used = true;
-            ctx = &g_state.contexts[i].context;
-
-            GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i);
-            break;
-        }
-    }
-
-    if (ctx == NULL) {
-        GGML_PRINT_DEBUG("%s: no unused context found\n", __func__);
-
-        ggml_critical_section_end();
+    ggml_critical_section_end();
 
-        return NULL;
-    }
+    struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));
 
     // allow to call ggml_init with 0 size
     if (params.mem_size == 0) {
@@ -3886,42 +3864,31 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
 
-    ggml_critical_section_end();
-
     return ctx;
 }
 
-void ggml_free(struct ggml_context * ctx) {
+void ggml_reset(struct ggml_context * ctx) {
     if (ctx == NULL) {
         return;
     }
 
-    // make this function thread safe
-    ggml_critical_section_start();
-
-    bool found = false;
-
-    for (int i = 0; i < GGML_MAX_CONTEXTS; i++) {
-        if (&g_state.contexts[i].context == ctx) {
-            g_state.contexts[i].used = false;
-
-            GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
-                    __func__, i, ggml_used_mem(ctx));
-
-            if (ctx->mem_buffer_owned) {
-                GGML_ALIGNED_FREE(ctx->mem_buffer);
-            }
+    ctx->n_objects     = 0;
+    ctx->objects_begin = NULL;
+    ctx->objects_end   = NULL;
+    ctx->scratch       = (struct ggml_scratch) { 0, 0, NULL, };
+    ctx->scratch_save  = (struct ggml_scratch) { 0, 0, NULL, };
+}
 
-            found = true;
-            break;
-        }
+void ggml_free(struct ggml_context * ctx) {
+    if (ctx == NULL) {
+        return;
     }
 
-    if (!found) {
-        GGML_PRINT_DEBUG("%s: context not found\n", __func__);
+    if (ctx->mem_buffer_owned) {
+        GGML_ALIGNED_FREE(ctx->mem_buffer);
     }
 
-    ggml_critical_section_end();
+    GGML_FREE(ctx);
 }
 
 size_t ggml_used_mem(const struct ggml_context * ctx) {
index 37214f0b507dd30e436a82fe7ba1449494025faa..6e62d103b17214228a040384350b47eef783ecb1 100644 (file)
@@ -699,9 +699,9 @@ struct whisper_kv_cache {
     struct ggml_tensor * k;
     struct ggml_tensor * v;
 
-    struct ggml_context * ctx = nullptr;
-
     ggml_backend_buffer_t buffer = nullptr;
+
+    std::vector<uint8_t> ctx_buf;
 };
 
 struct whisper_model {
@@ -941,9 +941,11 @@ static bool whisper_kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
+    cache.ctx_buf.resize(2*ggml_tensor_overhead());
+
     struct ggml_init_params params = {
-        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
-        /*.mem_buffer =*/ nullptr,
+        /*.mem_size   =*/ cache.ctx_buf.size(),
+        /*.mem_buffer =*/ cache.ctx_buf.data(),
         /*.no_alloc   =*/ true,
     };
 
@@ -953,17 +955,17 @@ static bool whisper_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
-    cache.ctx = ggml_init(params);
+    struct ggml_context * ctx = ggml_init(params);
 
-    if (!cache.ctx) {
+    if (!ctx) {
         WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
         return false;
     }
 
-    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);
 
-    cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend);
+    cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
     if (!cache.buffer) {
         WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
         return false;
@@ -971,13 +973,13 @@ static bool whisper_kv_cache_init(
 
     ggml_backend_buffer_clear(cache.buffer, 0);
 
+    ggml_free(ctx);
+
     return true;
 }
 
 static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
-    ggml_free(cache.ctx);
     ggml_backend_buffer_free(cache.buffer);
-    cache.ctx = nullptr;
 }
 
 static bool whisper_kv_cache_find_slot(
@@ -2002,7 +2004,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     auto & kv_pad = wstate.kv_pad;
 
-    WHISPER_ASSERT(!!kv_pad.ctx);
+    WHISPER_ASSERT(!!kv_pad.buffer);
 
     const int n_ctx_pad = GGML_PAD(n_ctx, 256);
 
@@ -2416,7 +2418,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     auto & kv_self = wstate.kv_self;
 
-    WHISPER_ASSERT(!!kv_self.ctx);
+    WHISPER_ASSERT(!!kv_self.buffer);
 
     const int n_ctx   = kv_self.size;
     const int n_state = hparams.n_text_state;