]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add custom RoPE (#2054)
authorXiao-Yong Jin <redacted>
Sat, 15 Jul 2023 10:34:16 +0000 (06:34 -0400)
committerGitHub <redacted>
Sat, 15 Jul 2023 10:34:16 +0000 (13:34 +0300)
* Implement customizable RoPE

The original RoPE has pre-defined parameters

theta_i = 10000^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2]

Our customizable RoPE, ggml_rope_custom_inplace, uses

theta_i = scale * base^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2]

with the default matches the original

scale = 1.0
base = 10000

The new command line arguments
--rope-freq-base
--rope-freq-scale
set the two new RoPE parameter.

Recent researches show changing these two parameters extends the context limit with minimal loss.

1. Extending Context to 8K
   kaiokendev
   https://kaiokendev.github.io/til#extending-context-to-8k

2. Extending Context Window of Large Language Models via Positional Interpolation
   Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian
   https://arxiv.org/abs/2306.15595

3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
   https://www.reddit.com/user/bloc97
   https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

For the bold, try adding the following command line parameters to your favorite model:
-c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5

* ggml-metal: fix custom rope

* common: fix argument names in help

* llama: increase MEM_REQ_EVAL for MODEL_3B

It avoids crashing for quantized weights on CPU.
Better ways to calculate the required buffer size would be better.

* llama: make MEM_REQ_EVAL depend on n_ctx

* server: use proper Content-Type in curl examples

Without the header Content-Type: application/json, curl will POST with
Content-Type: application/x-www-form-urlencoded

Though our simple server doesn't care, the httplib.h used has a limit
with CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192

With Content-Type: application/json, we can send large json data.

* style : minor fixes, mostly indentations

* ggml : fix asserts

---------

Co-authored-by: Georgi Gerganov <redacted>
12 files changed:
examples/common.cpp
examples/common.h
examples/main/main.cpp
examples/server/README.md
examples/server/chat.sh
examples/server/server.cpp
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
llama.cpp
llama.h

index 94875b0542dad0ddb8df9316db142ecfc48e8a5b..8705127cb11dc8b6fc09f970ddb378ce95efae95 100644 (file)
@@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--rope-freq-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_base = std::stof(argv[i]);
+        } else if (arg == "--rope-freq-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -493,6 +505,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
     fprintf(stderr, "  --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -573,6 +587,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.use_mlock    = params.use_mlock;
     lparams.logits_all   = params.perplexity;
     lparams.embedding    = params.embedding;
+    lparams.rope_freq_base  = params.rope_freq_base;
+    lparams.rope_freq_scale = params.rope_freq_scale;
 
     return lparams;
 }
index 6315df9613445697e63db3d492732a17c8d9c91a..f52fef6292bf2178deb49ba22b8079352cea786c 100644 (file)
@@ -32,6 +32,8 @@ struct gpt_params {
     int32_t main_gpu                        = 0;   // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs                         = 0;   // if greater than 0, output the probabilities of top n_probs tokens.
+    float   rope_freq_base                  = 10000.0f; // RoPE base frequency
+    float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor
 
     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
index 2248c245875b04b88ed84c17e0a6b50ba9825134..bcbcf12b0ac19d93a36ce13e215b095e147fa67e 100644 (file)
@@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
+    if (params.rope_freq_base != 10000.0) {
+        fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    }
+
+    if (params.rope_freq_scale != 1.0) {
+        fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    }
+
     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
+                " you are on your own\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
index ad9b6bb08184525c3b9532ba96a04b677de813e5..e5ca8269b9d5646e93744ed250d2404196e6e8ca 100644 (file)
@@ -66,6 +66,7 @@ Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the
 ```sh
 curl --request POST \
     --url http://localhost:8080/completion \
+    --header "Content-Type: application/json" \
     --data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
 ```
 
index a89f8e908c403b22b656471d95261812fbcffc87..0143601214b15cf2f08ee3747a123855ec6aae7d 100644 (file)
@@ -32,6 +32,7 @@ tokenize() {
         --silent \
         --request POST \
         --url "${API_URL}/tokenize" \
+        --header "Content-Type: application/json" \
         --data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
     | jq '.tokens[]'
 }
@@ -64,6 +65,7 @@ chat_completion() {
         --no-buffer \
         --request POST \
         --url "${API_URL}/completion" \
+        --header "Content-Type: application/json" \
         --data-raw "${DATA}")
 
     printf "\n"
index 296c5d6468f1618acd1a748d886d1b1a6c376947..f442f2b56d368429fdf83d45fb91d7e34d971c4b 100644 (file)
@@ -608,6 +608,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stderr, "  -v, --verbose         verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stderr, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -722,6 +724,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_ctx = std::stoi(argv[i]);
         }
+        else if (arg == "--rope-freq-base")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_base = std::stof(argv[i]);
+        }
+        else if (arg == "--rope-freq-scale")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = std::stof(argv[i]);
+        }
         else if (arg == "--memory-f32" || arg == "--memory_f32")
         {
             params.memory_f16 = false;
index c795ee22784bda5b7b0dfd0e0631b16ba398f1fb..ee205bcdf773ca1666016ac63aae16a7aafdfb96 100644 (file)
@@ -881,28 +881,35 @@ void ggml_metal_graph_compute(
 
                             const int n_past = ((int32_t *)(src1->data))[0];
 
+                            float freq_base;
+                            float freq_scale;
+                            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00   length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01   length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02   length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03   length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00   length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01   length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02   length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03   length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0    length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1    length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2    length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3    length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0    length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1    length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2    length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3    length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&n_past length:sizeof(     int) atIndex:18];
-                            [encoder setBytes:&n_dims length:sizeof(     int) atIndex:19];
-                            [encoder setBytes:&mode   length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
+                            [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
+                            [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
+                            [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
+                            [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
+                            [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
+                            [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
+                            [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
+                            [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
+                            [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
+                            [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
+                            [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
+                            [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+                            [encoder setBytes:&n_past  length:sizeof(     int) atIndex:18];
+                            [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:19];
+                            [encoder setBytes:&mode    length:sizeof(     int) atIndex:20];
+                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
index f094a1d407f08b21f6493c105218e891a3cc3f37..9f9a4fbd74446e2c203fe2f8c65653ee6f5360f7 100644 (file)
@@ -656,17 +656,19 @@ kernel void kernel_rope(
         constant       int & n_past,
         constant       int & n_dims,
         constant       int & mode,
+        constant     float & freq_base,
+        constant     float & freq_scale,
         uint3 tpig[[thread_position_in_grid]]) {
     const int64_t i3 = tpig[2];
     const int64_t i2 = tpig[1];
     const int64_t i1 = tpig[0];
 
     const bool is_neox = mode & 2;
-    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+    const float theta_scale = pow(freq_base, -2.0f/n_dims);
 
     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
 
-    float theta = (float)p;
+    float theta = freq_scale * (float)p;
 
     if (!is_neox) {
         for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
diff --git a/ggml.c b/ggml.c
index 3ea8ba6ec92f65c098bd6f2ec65c43713ad17edd..5ce1da0e9df4dc503903efacf926b35497ffee70 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -6956,6 +6956,8 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        float                 freq_base,
+        float                 freq_scale,
         int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
@@ -6969,12 +6971,14 @@ struct ggml_tensor * ggml_rope_impl(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
     ((int32_t *) b->data)[3] = n_ctx;
+    memcpy((int32_t *) b->data + 4, &freq_base,  sizeof(float));
+    memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
 
     ggml_scratch_load(ctx);
 
@@ -6993,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7003,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode,
+        float                 freq_base,
+        float                 freq_scale,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
 }
 
 // ggml_rope_back
@@ -12074,16 +12090,21 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 4);
+    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float freq_base;
+    float freq_scale;
+
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
     const int n_ctx  = ((int32_t *) src1->data)[3];
+    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12112,7 +12133,7 @@ static void ggml_compute_forward_rope_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12124,7 +12145,7 @@ static void ggml_compute_forward_rope_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta = freq_scale * (float)p;
 
                 if (is_glm) {
                     theta = MIN(p, n_ctx - 2);
@@ -12201,16 +12222,21 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 4);
+    GGML_ASSERT(ggml_nelements(src1) == 6);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float freq_base;
+    float freq_scale;
+
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
     const int n_ctx  = ((int32_t *) src1->data)[3];
+    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
 
     assert(n_past >= 0);
 
@@ -12239,7 +12265,7 @@ static void ggml_compute_forward_rope_f16(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12251,7 +12277,7 @@ static void ggml_compute_forward_rope_f16(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta = freq_scale * (float)p;
 
                 if (is_glm) {
                     theta = MIN(p, n_ctx - 2);
@@ -12312,7 +12338,7 @@ static void ggml_compute_forward_rope_f16(
                             const float x0 = GGML_FP16_TO_FP32(src[0]);
                             const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
 
-                            dst_data[0]     = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                             dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                         }
                     }
@@ -15710,7 +15736,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 4);
+                    assert(ggml_nelements(src1) == 6);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
@@ -15731,7 +15757,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 4);
+                    assert(ggml_nelements(src1) == 3);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode   = ((int32_t *) src1->data)[2];
diff --git a/ggml.h b/ggml.h
index b88c35bae72670d52ca6ce4ae2ca2dfe5a69e8f8..24856a255c9a6d57683b024c979f159e7f78f64e 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1121,6 +1121,17 @@ extern "C" {
             int                   mode,
             int                   n_ctx);
 
+    // custom RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            float                 freq_base,
+            float                 freq_scale,
+            int                   n_ctx);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
index b0cd9417c07d8eac4191e27219d9c60928966f45..27e1ee9646448638c172ca9162959bc8b6ca6ebb 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -101,14 +101,15 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
 // memory sizes
 //
 
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
+static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,    256ull * MB },
-        { MODEL_7B,    512ull * MB },
-        { MODEL_13B,   512ull * MB },
-        { MODEL_30B,   512ull * MB },
-        { MODEL_65B,  1024ull * MB },
+        /* empirical scaling, still a guess */
+        { MODEL_3B,   ((size_t) n_ctx / 16ull + 128ull) * MB },
+        { MODEL_7B,   ((size_t) n_ctx / 16ull + 256ull) * MB },
+        { MODEL_13B,  ((size_t) n_ctx / 12ull + 256ull) * MB },
+        { MODEL_30B,  ((size_t) n_ctx / 10ull + 256ull) * MB },
+        { MODEL_65B,  ((size_t) n_ctx /  8ull + 512ull) * MB },
     };
     return k_sizes;
 }
@@ -140,14 +141,14 @@ static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
 
 // this is mostly needed for temporary mul_mat buffers to dequantize the data
 // not actually needed if BLAS is disabled
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
+static const std::map<e_model, size_t> & MEM_REQ_EVAL(int n_ctx)
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   512ull * MB },
-        { MODEL_7B,   768ull * MB },
-        { MODEL_13B, 1024ull * MB },
-        { MODEL_30B, 1280ull * MB },
-        { MODEL_65B, 1536ull * MB },
+        { MODEL_3B,  ((size_t) n_ctx / 256ull +  512ull) * MB },
+        { MODEL_7B,  ((size_t) n_ctx / 256ull +  768ull) * MB },
+        { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB },
+        { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB },
+        { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB },
     };
     return k_sizes;
 }
@@ -189,6 +190,10 @@ struct llama_hparams {
     uint32_t n_head  = 32;
     uint32_t n_layer = 32;
     uint32_t n_rot   = 64;
+
+    float rope_freq_base  = 10000.0f;
+    float rope_freq_scale = 1.0f;
+
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
 
     bool operator!=(const llama_hparams & other) const {
@@ -647,7 +652,7 @@ struct llama_model_loader {
         *ctx_size_p = *mmapped_size_p = 0;
         for (const llama_load_tensor & lt : tensors_map.tensors) {
             *ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE;
-            *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size;
+            *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size + 16;
         }
     }
 
@@ -843,6 +848,8 @@ struct llama_context_params llama_context_default_params() {
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ {0},
+        /*.rope_freq_base              =*/ 10000.0f,
+        /*.rope_freq_scale             =*/ 1.0f,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.low_vram                    =*/ false,
@@ -966,6 +973,8 @@ static void llama_model_load_internal(
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
+        float rope_freq_base,
+        float rope_freq_scale,
         bool low_vram,
         ggml_type memory_type,
         bool use_mmap,
@@ -1000,22 +1009,27 @@ static void llama_model_load_internal(
         }
 
         hparams.n_ctx = n_ctx;
+
+        hparams.rope_freq_base  = rope_freq_base;
+        hparams.rope_freq_scale = rope_freq_scale;
     }
 
     const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
 
     {
-        fprintf(stderr, "%s: format     = %s\n",  __func__, llama_file_version_name(file_version));
-        fprintf(stderr, "%s: n_vocab    = %u\n",  __func__, hparams.n_vocab);
-        fprintf(stderr, "%s: n_ctx      = %u\n",  __func__, hparams.n_ctx);
-        fprintf(stderr, "%s: n_embd     = %u\n",  __func__, hparams.n_embd);
-        fprintf(stderr, "%s: n_mult     = %u\n",  __func__, hparams.n_mult);
-        fprintf(stderr, "%s: n_head     = %u\n",  __func__, hparams.n_head);
-        fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot);
+        fprintf(stderr, "%s: format     = %s\n",   __func__, llama_file_version_name(file_version));
+        fprintf(stderr, "%s: n_vocab    = %u\n",   __func__, hparams.n_vocab);
+        fprintf(stderr, "%s: n_ctx      = %u\n",   __func__, hparams.n_ctx);
+        fprintf(stderr, "%s: n_embd     = %u\n",   __func__, hparams.n_embd);
+        fprintf(stderr, "%s: n_mult     = %u\n",   __func__, hparams.n_mult);
+        fprintf(stderr, "%s: n_head     = %u\n",   __func__, hparams.n_head);
+        fprintf(stderr, "%s: n_layer    = %u\n",   __func__, hparams.n_layer);
+        fprintf(stderr, "%s: n_rot      = %u\n",   __func__, hparams.n_rot);
+        fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
+        fprintf(stderr, "%s: freq_scale = %g\n",   __func__, hparams.rope_freq_scale);
         fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
-        fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
-        fprintf(stderr, "%s: model size = %s\n",  __func__, llama_model_type_name(model.type));
+        fprintf(stderr, "%s: n_ff       = %u\n",   __func__, n_ff);
+        fprintf(stderr, "%s: model size = %s\n",   __func__, llama_model_type_name(model.type));
     }
 
     if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
@@ -1164,9 +1178,9 @@ static void llama_model_load_internal(
         const size_t mem_required =
             ctx_size +
             mmapped_size - vram_weights + // weights in VRAM not in memory
-            MEM_REQ_SCRATCH0().at(model.type) +
+            MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
             MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at    (model.type);
+            MEM_REQ_EVAL(hparams.n_ctx).at(model.type);
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
@@ -1270,6 +1284,8 @@ static bool llama_model_load(
         int n_gpu_layers,
         int main_gpu,
         float * tensor_split,
+        float rope_freq_base,
+        float rope_freq_scale,
         bool low_vram,
         ggml_type memory_type,
         bool use_mmap,
@@ -1278,7 +1294,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1330,6 +1346,9 @@ static bool llama_eval_internal(
     const int n_rot        = hparams.n_embd/hparams.n_head;
     const int n_gpu_layers = model.n_gpu_layers;
 
+    const float freq_base  = hparams.rope_freq_base;
+    const float freq_scale = hparams.rope_freq_scale;
+
     auto & mem_per_token = lctx.mem_per_token;
     auto & buf_compute   = lctx.buf_compute;
 
@@ -1427,11 +1446,11 @@ static bool llama_eval_internal(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
@@ -2674,8 +2693,9 @@ struct llama_model * llama_load_model_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock,
-                params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
+                params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
+                memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
+                params.progress_callback_user_data)) {
         delete model;
         fprintf(stderr, "%s: failed to load model\n", __func__);
         return nullptr;
@@ -2750,9 +2770,9 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
+        ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type));
 
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type));
+        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
     }
 
diff --git a/llama.h b/llama.h
index e7c60f40c60463429e7e76cfb168e1449c6c6526..e744584f23fb3480268b2d90e88314c83ac09188 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -89,6 +89,11 @@ extern "C" {
         int32_t  n_gpu_layers;                 // number of layers to store in VRAM
         int32_t  main_gpu;                     // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+
+        // ref: https://github.com/ggerganov/llama.cpp/pull/2054
+        float    rope_freq_base;  // RoPE base frequency
+        float    rope_freq_scale; // RoPE frequency scaling factor
+
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;
         // context pointer passed to the progress callback