]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add ggml_row_size() (fixes llama out of space) (#4461)
authorLostRuins <redacted>
Thu, 14 Dec 2023 12:13:33 +0000 (20:13 +0800)
committerGitHub <redacted>
Thu, 14 Dec 2023 12:13:33 +0000 (14:13 +0200)
* Fixes "Not enough space in the context's memory pool" encountered on certain models, which seems to be caused by some imprecision related to the automatic casting of floating point values

* do not cast to size_t, instead just use doubles

* ggml : add ggml_row_size(), deprecate ggml_type_sizef()

* ggml : fix row size compute to avoid overflows

* tests : fix sizey -> sizez

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/benchmark/benchmark-matmult.cpp
ggml.c
ggml.h
llama.cpp

index 284733b1035c96b59d79ba45a3bdabb41f8df5ca..434e1d6bd509eaa155a58890f393a8d224575923 100644 (file)
@@ -129,13 +129,13 @@ int main(int argc, char ** argv)  {
     const ggml_type qtype = GGML_TYPE_Q4_1;
 
     size_t ctx_size = 0;
-    ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32);
-    ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32);
-    ctx_size += sizex*sizez*ggml_type_sizef(GGML_TYPE_F32);
-    ctx_size += sizex*sizey*ggml_type_sizef(qtype);
-    ctx_size += sizex*sizey*ggml_type_sizef(qtype);
-    ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS
-    ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS
+    ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
+    ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey);
+    ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizez);
+    ctx_size += ggml_row_size(qtype,         sizex*sizey);
+    ctx_size += ggml_row_size(qtype,         sizex*sizey);
+    ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
+    ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS
     ctx_size += 1024*1024*16;
 
     printf("Allocating Memory of size %zi bytes, %zi MB\n",ctx_size, (ctx_size/1024/1024));
diff --git a/ggml.c b/ggml.c
index 7e1272817388c57feda796cd2d6e7f770fad83fa..f0a972690aea5e12677169118e44f5a4a5726456 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2011,8 +2011,13 @@ size_t ggml_type_size(enum ggml_type type) {
     return type_traits[type].type_size;
 }
 
-float ggml_type_sizef(enum ggml_type type) {
-    return ((float)(type_traits[type].type_size))/type_traits[type].blck_size;
+size_t ggml_row_size(enum ggml_type type, int64_t ne) {
+    assert(ne % ggml_blck_size(type) == 0);
+    return ggml_type_size(type)*ne/ggml_blck_size(type);
+}
+
+double ggml_type_sizef(enum ggml_type type) {
+    return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
 }
 
 const char * ggml_type_name(enum ggml_type type) {
diff --git a/ggml.h b/ggml.h
index 1447646b13c439a94026cf9af1cc4b42aa617fef..ae8101fab3d6d023ae69db3ff1cdb853e9adf9b9 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -641,9 +641,13 @@ extern "C" {
     GGML_API size_t  ggml_nbytes_pad  (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
     GGML_API size_t  ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
 
-    GGML_API int     ggml_blck_size (enum ggml_type type);
-    GGML_API size_t  ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
-    GGML_API float   ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+    GGML_API int    ggml_blck_size(enum ggml_type type);
+    GGML_API size_t ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
+    GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
+
+    GGML_DEPRECATED(
+    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
+    "use ggml_row_size() instead");
 
     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
index 0e5ab044cdfa4fa8c471d76846726abeaa4bafbc..456807d9d5a3a21339912f3b86106b10829d59b8 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1555,7 +1555,7 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
-    cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
+    cache.buf.resize(ggml_row_size(ktype, n_elements) + ggml_row_size(vtype, n_elements) + 2u*n_layer*ggml_tensor_overhead());
     memset(cache.buf.data, 0, cache.buf.size);
 
     struct ggml_init_params params;
@@ -3822,8 +3822,8 @@ static void llm_build_k_shift(
             ggml_rope_custom_inplace(ctx,
                     ggml_view_3d(ctx, kv.k_l[il],
                         n_embd_head, n_head_kv, n_ctx,
-                        ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
-                        ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
+                        ggml_row_size(kv.k_l[il]->type, n_embd_head),
+                        ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
                         0),
                     K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
@@ -3852,7 +3852,7 @@ static void llm_build_kv_store(
     cb(v_cur_t, "v_cur_t", il);
 
     struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
-            (ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
+            (ggml_row_size(kv.k_l[il]->type, n_embd_gqa))*kv_head);
     cb(k_cache_view, "k_cache_view", il);
 
     struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
@@ -4011,8 +4011,8 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * k =
         ggml_view_3d(ctx, kv.k_l[il],
                 n_embd_head, n_kv, n_head_kv,
-                ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
-                ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
+                ggml_row_size(kv.k_l[il]->type, n_embd_gqa),
+                ggml_row_size(kv.k_l[il]->type, n_embd_head),
                 0);
     cb(k, "k", il);