]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
examples : make n_ctx warning work again (#3066)
authorCebtenzzre <redacted>
Fri, 8 Sep 2023 15:43:35 +0000 (11:43 -0400)
committerGitHub <redacted>
Fri, 8 Sep 2023 15:43:35 +0000 (11:43 -0400)
This was broken by commit e36ecdcc ("build : on Mac OS enable Metal by
default (#2901)").

examples/embedding/embedding.cpp
examples/main/main.cpp
examples/perplexity/perplexity.cpp
llama.cpp
llama.h

index 49ab3e0635abb9dac2ffa050eb4eef49e068b819..e4a0a38c831730401f4921a2f756ef9de3b5b2fa 100644 (file)
@@ -17,11 +17,6 @@ int main(int argc, char ** argv) {
 
     params.embedding = true;
 
-    if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
-    }
-
     fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
 
     if (params.seed == LLAMA_DEFAULT_SEED) {
@@ -47,6 +42,12 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    const int n_ctx_train = llama_n_ctx_train(ctx);
+    if (params.n_ctx > n_ctx_train) {
+        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
index be030fffbc1af315102ea5bd5dd9f0fe2e13d53b..baec6ba129da0b744fbef903d903d5719f50772f 100644 (file)
@@ -182,8 +182,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (params.n_ctx > llama_n_ctx(ctx)) {
-        LOG_TEE("%s: warning: base model only supports context sizes no greater than %d tokens (%d specified)\n", __func__, llama_n_ctx(ctx), params.n_ctx);
+    const int n_ctx_train = llama_n_ctx_train(ctx);
+    if (params.n_ctx > n_ctx_train) {
+        LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
     } else if (params.n_ctx < 8) {
         LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
index 1b760683b0b03fe737080a53d28bd7d2a4a828bc..3a1c8c28da09b322a83a76ccaad902067090785b 100644 (file)
@@ -693,9 +693,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (params.n_ctx > llama_n_ctx(ctx)) {
-        fprintf(stderr, "%s: warning: model might not support context sizes greater than %d tokens (%d specified);"
-                "expect poor results\n", __func__, llama_n_ctx(ctx), params.n_ctx);
+    const int n_ctx_train = llama_n_ctx_train(ctx);
+    if (params.n_ctx > n_ctx_train) {
+        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
     }
 
     // print system information
index 3f11902214da9c7b693c2f6ecbd8783ec4aeebf0..2a2a0c9c63cef1b529ff41f502a2579c321d81df 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -5633,15 +5633,19 @@ void llama_free(struct llama_context * ctx) {
 }
 
 int llama_n_vocab(const struct llama_context * ctx) {
-    return ctx->model.vocab.id_to_token.size();
+    return llama_model_n_vocab(&ctx->model);
 }
 
 int llama_n_ctx(const struct llama_context * ctx) {
-    return ctx->model.hparams.n_ctx;
+    return llama_model_n_ctx(&ctx->model);
+}
+
+int llama_n_ctx_train(const struct llama_context * ctx) {
+    return llama_model_n_ctx_train(&ctx->model);
 }
 
 int llama_n_embd(const struct llama_context * ctx) {
-    return ctx->model.hparams.n_embd;
+    return llama_model_n_embd(&ctx->model);
 }
 
 enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
@@ -5656,6 +5660,10 @@ int llama_model_n_ctx(const struct llama_model * model) {
     return model->hparams.n_ctx;
 }
 
+int llama_model_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
 int llama_model_n_embd(const struct llama_model * model) {
     return model->hparams.n_embd;
 }
diff --git a/llama.h b/llama.h
index 5b95aaa8776dd86eca2203205fef2bed1ea6bf51..37975bebed22e239b1f4d04ac1aae7b329ec9f93 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -245,15 +245,17 @@ extern "C" {
     LLAMA_API bool llama_mmap_supported (void);
     LLAMA_API bool llama_mlock_supported(void);
 
-    LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
-    LLAMA_API int llama_n_ctx  (const struct llama_context * ctx);
-    LLAMA_API int llama_n_embd (const struct llama_context * ctx);
+    LLAMA_API int llama_n_vocab    (const struct llama_context * ctx);
+    LLAMA_API int llama_n_ctx      (const struct llama_context * ctx);
+    LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
+    LLAMA_API int llama_n_embd     (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
 
-    LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
-    LLAMA_API int llama_model_n_ctx  (const struct llama_model * model);
-    LLAMA_API int llama_model_n_embd (const struct llama_model * model);
+    LLAMA_API int llama_model_n_vocab    (const struct llama_model * model);
+    LLAMA_API int llama_model_n_ctx      (const struct llama_model * model);
+    LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
+    LLAMA_API int llama_model_n_embd     (const struct llama_model * model);
 
     // Get a string describing the model type
     LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);