]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
train-text-from-scratch : fix assert failure in ggml-alloc (#3618)
authorslaren <redacted>
Tue, 17 Oct 2023 17:00:58 +0000 (19:00 +0200)
committerGitHub <redacted>
Tue, 17 Oct 2023 17:00:58 +0000 (20:00 +0300)
examples/train-text-from-scratch/train-text-from-scratch.cpp

index be693b3ac7a43bc6bdf590b4d089378881f3373f..1ce6cef29cfd06952d6e2ff32c850ba365dae4dd 100644 (file)
@@ -253,13 +253,14 @@ static void init_model(struct my_llama_model * model) {
     set_param_model(model);
 
     // measure data size
-    struct ggml_allocr * alloc = NULL;
-    alloc = ggml_allocr_new_measure(tensor_alignment);
-    alloc_model(alloc, model);
+    size_t size = 0;
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        size += GGML_PAD(ggml_nbytes(t), tensor_alignment);
+    }
 
     // allocate data
-    model->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment);
-    ggml_allocr_free(alloc);
+    struct ggml_allocr * alloc = NULL;
+    model->data.resize(size + tensor_alignment);
     alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
     alloc_model(alloc, model);
     ggml_allocr_free(alloc);
@@ -1094,11 +1095,9 @@ int main(int argc, char ** argv) {
     struct ggml_tensor * target_probs  = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab,  n_tokens, n_batch);
 
     // measure required memory for input tensors
-    alloc = ggml_allocr_new_measure(tensor_alignment);
-    ggml_allocr_alloc(alloc, tokens_input);
-    ggml_allocr_alloc(alloc, target_probs);
-    size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment;
-    ggml_allocr_free(alloc);
+    size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) +
+                            GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) +
+                            tensor_alignment;
     printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
 
     // allocate input tensors