]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : implement YaRN RoPE scaling (#2268)
authorcebtenzzre <redacted>
Wed, 1 Nov 2023 22:04:33 +0000 (18:04 -0400)
committerGitHub <redacted>
Wed, 1 Nov 2023 22:04:33 +0000 (18:04 -0400)
Co-authored-by: cebtenzzre <redacted>
Co-authored-by: Jeffrey Quesnelle <redacted>
15 files changed:
common/common.cpp
common/common.h
convert-baichuan-hf-to-gguf.py
convert.py
examples/finetune/finetune.cpp
examples/server/server.cpp
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
gguf-py/gguf/gguf.py
llama.cpp
llama.h

index 7a48e9d11e85939e992386049bb473de709bcac5..b182ffaaef48ec204c6f6fdb07f3aaa4920e10e6 100644 (file)
@@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = std::stof(argv[i]);
+        } else if (arg == "--rope-scaling") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+            else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            else { invalid_param = true; break; }
         } else if (arg == "--rope-scale") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+        } else if (arg == "--yarn-orig-ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_orig_ctx = std::stoi(argv[i]);
+        } else if (arg == "--yarn-ext-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_ext_factor = std::stof(argv[i]);
+        } else if (arg == "--yarn-attn-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_attn_factor = std::stof(argv[i]);
+        } else if (arg == "--yarn-beta-fast") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_fast = std::stof(argv[i]);
+        } else if (arg == "--yarn-beta-slow") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_slow = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --cfg-negative-prompt-file FNAME\n");
     printf("                        negative prompt file to use for guidance. (default: empty)\n");
     printf("  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
-    printf("  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
+    printf("  --rope-scaling {none,linear,yarn}\n");
+    printf("                        RoPE frequency scaling method, defaults to linear unless specified by the model\n");
+    printf("  --rope-scale N        RoPE context scaling factor, expands context by a factor of N\n");
     printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
-    printf("  --rope-freq-scale N   RoPE frequency linear scaling factor (default: loaded from model)\n");
+    printf("  --rope-freq-scale N   RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-orig-ctx N     YaRN: original context size of model (default: 0 = model training context size)\n");
+    printf("  --yarn-ext-factor N   YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+    printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+    printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+    printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf("  --no-penalize-nl      do not penalize newline token\n");
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -826,17 +873,23 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto cparams = llama_context_default_params();
 
-    cparams.n_ctx           = params.n_ctx;
-    cparams.n_batch         = params.n_batch;
-    cparams.n_threads       = params.n_threads;
-    cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
-    cparams.mul_mat_q       = params.mul_mat_q;
-    cparams.seed            = params.seed;
-    cparams.f16_kv          = params.memory_f16;
-    cparams.logits_all      = params.logits_all;
-    cparams.embedding       = params.embedding;
-    cparams.rope_freq_base  = params.rope_freq_base;
-    cparams.rope_freq_scale = params.rope_freq_scale;
+    cparams.n_ctx             = params.n_ctx;
+    cparams.n_batch           = params.n_batch;
+    cparams.n_threads         = params.n_threads;
+    cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    cparams.mul_mat_q         = params.mul_mat_q;
+    cparams.seed              = params.seed;
+    cparams.f16_kv            = params.memory_f16;
+    cparams.logits_all        = params.logits_all;
+    cparams.embedding         = params.embedding;
+    cparams.rope_scaling_type = params.rope_scaling_type;
+    cparams.rope_freq_base    = params.rope_freq_base;
+    cparams.rope_freq_scale   = params.rope_freq_scale;
+    cparams.yarn_ext_factor   = params.yarn_ext_factor;
+    cparams.yarn_attn_factor  = params.yarn_attn_factor;
+    cparams.yarn_beta_fast    = params.yarn_beta_fast;
+    cparams.yarn_beta_slow    = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
 
     return cparams;
 }
index 343b272177c7ec601904ee46882d10f39e69e862..7be69f925bc2bec959036f59e330453d587009ea 100644 (file)
@@ -9,6 +9,7 @@
 #define LOG_NO_FILE_LINE_FUNCTION
 #include "log.h"
 
+#include <cmath>
 #include <string>
 #include <vector>
 #include <random>
@@ -54,6 +55,12 @@ struct gpt_params {
     int32_t n_beams                         = 0;    // if non-zero then use beam search of given width.
     float   rope_freq_base                  = 0.0f; // RoPE base frequency
     float   rope_freq_scale                 = 0.0f; // RoPE frequency scaling factor
+    float   yarn_ext_factor                 = NAN;  // YaRN extrapolation mix factor
+    float   yarn_attn_factor                = 1.0f; // YaRN magnitude scaling factor
+    float   yarn_beta_fast                  = 32.0f;// YaRN low correction dim
+    float   yarn_beta_slow                  = 1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx                   = 0;    // YaRN original context length
+    int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED;
 
     // // sampling parameters
     struct llama_sampling_params sparams;
index 5ee99be73134e6ef807c322c8e64e7f9ef263dcd..67ccbe99f132af8bc4a7179d21ccbfa4112f5dc2 100755 (executable)
@@ -163,7 +163,8 @@ gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
 if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
     if "type" in hparams["rope_scaling"]:
         if hparams["rope_scaling"]["type"] == "linear":
-            gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
+            gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
 
 
 # TOKENIZATION
index bfbfab283f6ae775bba3b13caeb0c9d6fa1edf26..9110f15806c6bc0962014f1b6919197eb013233c 100755 (executable)
@@ -151,8 +151,11 @@ class Params:
     n_head_kv:  int
     f_norm_eps: float
 
+    rope_scaling_type: gguf.RopeScalingType | None = None
     f_rope_freq_base: float | None = None
     f_rope_scale: float | None = None
+    n_orig_ctx: int | None = None
+    rope_finetuned: bool | None = None
 
     ftype: GGMLFileType | None = None
 
@@ -198,20 +201,20 @@ class Params:
     def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))
 
-        n_vocab          = config["vocab_size"]
-        n_embd           = config["hidden_size"]
-        n_layer          = config["num_hidden_layers"]
-        n_ff             = config["intermediate_size"]
-        n_head           = config["num_attention_heads"]
-        n_head_kv        = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
-        f_norm_eps       = config["rms_norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
+        rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
         rope_scaling = config.get("rope_scaling")
-        if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
-            f_rope_scale = config["rope_scaling"].get("factor")
-        else:
-            f_rope_scale = None
+
+        if rope_scaling is not None and (typ := rope_scaling.get("type")):
+            rope_factor = rope_scaling.get("factor")
+            f_rope_scale = rope_factor
+            if typ == "linear":
+                rope_scaling_type = gguf.RopeScalingType.LINEAR
+            elif typ == "yarn":
+                rope_scaling_type = gguf.RopeScalingType.YARN
+                n_orig_ctx = rope_scaling['original_max_position_embeddings']
+                rope_finetuned = rope_scaling['finetuned']
+            else:
+                raise NotImplementedError(f'Unknown rope scaling type: {typ}')
 
         if "max_sequence_length" in config:
             n_ctx = config["max_sequence_length"]
@@ -222,16 +225,19 @@ class Params:
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
 
         return Params(
-            n_vocab          = n_vocab,
-            n_embd           = n_embd,
-            n_layer          = n_layer,
-            n_ctx            = n_ctx,
-            n_ff             = n_ff,
-            n_head           = n_head,
-            n_head_kv        = n_head_kv,
-            f_norm_eps       = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
-            f_rope_scale     = f_rope_scale,
+            n_vocab           = config["vocab_size"],
+            n_embd            = config["hidden_size"],
+            n_layer           = config["num_hidden_layers"],
+            n_ctx             = n_ctx,
+            n_ff              = config["intermediate_size"],
+            n_head            = (n_head := config["num_attention_heads"]),
+            n_head_kv         = config.get("num_key_value_heads", n_head),
+            f_norm_eps        = config["rms_norm_eps"],
+            f_rope_freq_base  = config.get("rope_theta"),
+            rope_scaling_type = rope_scaling_type,
+            f_rope_scale      = f_rope_scale,
+            n_orig_ctx        = n_orig_ctx,
+            rope_finetuned    = rope_finetuned,
         )
 
     # LLaMA v2 70B params.json
@@ -240,17 +246,8 @@ class Params:
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))
 
-        n_vocab          = config["vocab_size"] if "vocab_size" in config else -1
-        n_embd           = config["dim"]
-        n_layer          = config["n_layers"]
-        n_ff             = -1
-        n_head           = config["n_heads"]
-        n_head_kv        = config["n_kv_heads"] if "n_kv_heads" in config else n_head
-        f_norm_eps       = config["norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
-        if f_rope_freq_base == 1000000:
+        if config.get("rope_theta") == 1000000:
             # CodeLlama
             n_ctx = 16384
         elif config["norm_eps"] == 1e-05:
@@ -260,22 +257,16 @@ class Params:
             # LLaMA v1
             n_ctx = 2048
 
-        if n_vocab == -1:
-            n_vocab = model["tok_embeddings.weight"].shape[0]
-
-        if n_ff == -1:
-            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
-
         return Params(
-            n_vocab          = n_vocab,
-            n_embd           = n_embd,
-            n_layer          = n_layer,
+            n_vocab          = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
+            n_embd           = config["dim"],
+            n_layer          = config["n_layers"],
             n_ctx            = n_ctx,
-            n_ff             = n_ff,
-            n_head           = n_head,
-            n_head_kv        = n_head_kv,
-            f_norm_eps       = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
+            n_ff             = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_head           = (n_head := config["n_heads"]),
+            n_head_kv        = config.get("n_kv_heads", n_head),
+            f_norm_eps       = config["norm_eps"],
+            f_rope_freq_base = config.get("rope_theta"),
         )
 
     @staticmethod
@@ -831,8 +822,16 @@ class OutputFile:
         if params.f_rope_freq_base is not None:
             self.gguf.add_rope_freq_base(params.f_rope_freq_base)
 
-        if params.f_rope_scale is not None:
-            self.gguf.add_rope_scale_linear(params.f_rope_scale)
+        if params.rope_scaling_type:
+            assert params.f_rope_scale is not None
+            self.gguf.add_rope_scaling_type(params.rope_scaling_type)
+            self.gguf.add_rope_scaling_factor(params.f_rope_scale)
+
+        if params.n_orig_ctx is not None:
+            self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
+
+        if params.rope_finetuned is not None:
+            self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)
 
         if params.ftype is not None:
             self.gguf.add_file_type(params.ftype)
index 60c7faa797028af87ac3056de8bc1bcd490f2763..649a3b7c1941e5fb4ab4a40752764b0bcc211b11 100644 (file)
@@ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
         const int rope_mode = 0;
 
         return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx,
-            rope_freq_base, rope_freq_scale);
+            t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
+            rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f
+        );
     };
 
     set_name(tokens_input, "tokens_input");
index 47ae0d55856cf8d8784eab3df65f807923987757..84b04d5a0493a67a828b27f0dc226f5959dfbf94 100644 (file)
@@ -1755,12 +1755,18 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("options:\n");
     printf("  -h, --help                show this help message and exit\n");
     printf("  -v, --verbose             verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
-    printf("  -t N,  --threads N        number of threads to use during computation (default: %d)\n", params.n_threads);
+    printf("  -t N, --threads N         number of threads to use during computation (default: %d)\n", params.n_threads);
     printf("  -tb N, --threads-batch N  number of threads to use during batch and prompt processing (default: same as --threads)\n");
-    printf("  -c N,  --ctx-size N       size of the prompt context (default: %d)\n", params.n_ctx);
+    printf("  -c N, --ctx-size N        size of the prompt context (default: %d)\n", params.n_ctx);
+    printf("  --rope-scaling {none,linear,yarn}\n");
+    printf("                            RoPE frequency scaling method, defaults to linear unless specified by the model\n");
     printf("  --rope-freq-base N        RoPE base frequency (default: loaded from model)\n");
-    printf("  --rope-freq-scale N       RoPE frequency scaling factor (default: loaded from model)\n");
-    printf("  -b N,  --batch-size N     batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  --rope-freq-scale N       RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-ext-factor N       YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+    printf("  --yarn-attn-factor N      YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+    printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+    printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
     if (llama_mlock_supported())
@@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_ctx = std::stoi(argv[i]);
         }
+        else if (arg == "--rope-scaling")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+            else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            else { invalid_param = true; break; }
+        }
         else if (arg == "--rope-freq-base")
         {
             if (++i >= argc)
@@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.rope_freq_scale = std::stof(argv[i]);
         }
+        else if (arg == "--yarn-ext-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_ext_factor = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-attn-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_attn_factor = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-beta-fast")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_fast = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-beta-slow")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_slow = std::stof(argv[i]);
+        }
         else if (arg == "--memory-f32" || arg == "--memory_f32")
         {
             params.memory_f16 = false;
index 1ce6cef29cfd06952d6e2ff32c850ba365dae4dd..2a257e63215e3cd7a6f2adbcddc7dd18b8412001 100644 (file)
@@ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs(
         // not capturing these, to silcence warnings
         const int rope_mode = 0;
 
-        return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx,
-            rope_freq_base, rope_freq_scale);
+        return ggml_rope_custom(
+            ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+        );
     };
 
     set_name(tokens_input, "tokens_input");
index 4e6e7cd94892b1d1f7197d49c1254ac206e3127a..12ee10e3d9bdcdbc59ec7c14cfa842da89ee2ccd 100644 (file)
@@ -4493,11 +4493,41 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-// rope == RoPE == rotary positional embedding
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+struct rope_corr_dims {
+    float v[4];
+};
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// rope == RoPE == rotary positional embedding
 template<typename T, bool has_pos>
-static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                            const int p_delta_rows, const float theta_scale) {
+static __global__ void rope(
+    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4509,10 +4539,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
     const int i2 = row/p_delta_rows;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float p0 = p*freq_scale;
-    const float theta = p0*powf(theta_scale, col/2);
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
+    const float theta_base = p*powf(freq_base, -col/ncols);
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + 1];
@@ -4522,8 +4552,10 @@ static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t
 }
 
 template<typename T, bool has_pos>
-static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                 const int p_delta_rows, const float theta_scale) {
+static __global__ void rope_neox(
+    const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims
+) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (col >= ncols) {
@@ -4534,11 +4566,14 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;
 
+    // simplified from `(row * ncols + col) * (-1 / ncols)`
+    const float cur_rot = -col/ncols - row;
+
     const int p = has_pos ? pos[i2] : 0;
-    const float p0 = p*freq_scale;
-    const float theta = p0*powf(theta_scale, col/2);
-    const float sin_theta = sinf(theta);
-    const float cos_theta = cosf(theta);
+    const float theta_base = p*powf(freq_base, cur_rot);
+
+    float cos_theta, sin_theta;
+    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + ncols/2];
@@ -4547,8 +4582,10 @@ static __global__ void rope_neox(const T * x, T * dst, const int ncols, const in
     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                    const int p_delta_rows, const float theta_scale, const int n_ctx) {
+static __global__ void rope_glm_f32(
+    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+    int n_ctx
+) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
 
@@ -4560,7 +4597,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     const int i = row*ncols + col;
     const int i2 = row/p_delta_rows;
 
-    const float col_theta_scale = powf(theta_scale, col);
+    const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
      // FIXME: this is likely wrong
     const int p = pos != nullptr ? pos[i2] : 0;
 
@@ -5584,40 +5621,54 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
 }
 
 template<typename T>
-static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_cuda(
+    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     if (pos == nullptr) {
-        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     } else {
-        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     }
 }
 
 template<typename T>
-static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                          const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+static void rope_neox_cuda(
+    const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     if (pos == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+        );
     }
 }
 
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
-                              const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+static void rope_glm_f32_cuda(
+    const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, int n_ctx, cudaStream_t stream
+) {
     GGML_ASSERT(ncols % 4 == 0);
     const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
     const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale, n_ctx);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
 }
 
 static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -6477,17 +6528,20 @@ inline void ggml_cuda_op_rope(
     const int64_t ne2 = dst->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    // RoPE alteration for extended context
-
-    float freq_base, freq_scale;
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    //const int n_past      = ((int32_t *) dst->op_params)[0];
+    const int n_dims      = ((int32_t *) dst->op_params)[1];
+    const int mode        = ((int32_t *) dst->op_params)[2];
+    const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx  = ((int32_t *) dst->op_params)[4];
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    // RoPE alteration for extended context
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
     const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
@@ -6499,24 +6553,39 @@ inline void ggml_cuda_op_rope(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    rope_corr_dims corr_dims;
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+
     // compute
     if (is_glm) {
         GGML_ASSERT(false);
-        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
+        rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
         if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_neox_cuda(
+                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_neox_cuda(
+                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else {
             GGML_ASSERT(false);
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_cuda(
+                (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+            rope_cuda(
+                (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, main_stream
+            );
         } else {
             GGML_ASSERT(false);
         }
index 1f034150788e266e6f6a824b725ca6975e3d1ee2..611d5e173681eb5a6e8d967beb84488ee933cb4d 100644 (file)
@@ -1400,14 +1400,18 @@ void ggml_metal_graph_compute(
 
                             const int nth = MIN(1024, ne00);
 
-                            const int n_past = ((int32_t *) dst->op_params)[0];
-                            const int n_dims = ((int32_t *) dst->op_params)[1];
-                            const int mode   = ((int32_t *) dst->op_params)[2];
-
-                            float freq_base;
-                            float freq_scale;
-                            memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-                            memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+                            const int n_past     = ((int32_t *) dst->op_params)[0];
+                            const int n_dims     = ((int32_t *) dst->op_params)[1];
+                            const int mode       = ((int32_t *) dst->op_params)[2];
+                            const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+
+                            float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                            memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+                            memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+                            memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+                            memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+                            memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+                            memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
                             switch (src0->type) {
                                 case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
@@ -1439,6 +1443,10 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&mode    length:sizeof(     int) atIndex:21];
                             [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22];
                             [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
+                            [encoder setBytes:&ext_factor  length:sizeof(float) atIndex:24];
+                            [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
+                            [encoder setBytes:&beta_fast   length:sizeof(float) atIndex:26];
+                            [encoder setBytes:&beta_slow   length:sizeof(float) atIndex:27];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
index f3152778ae48c34d472d2150761a0773b0cbf2ce..471d7d390f8138bc657e9a342a28407380659de7 100644 (file)
@@ -1061,6 +1061,45 @@ kernel void kernel_alibi_f32(
     }
 }
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 typedef void (rope_t)(
         device const    void * src0,
         device const int32_t * src1,
@@ -1116,6 +1155,10 @@ kernel void kernel_rope(
         constant         int & mode,
         constant       float & freq_base,
         constant       float & freq_scale,
+        constant       float & ext_factor,
+        constant       float & attn_factor,
+        constant       float & beta_fast,
+        constant       float & beta_slow,
         uint  tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
         uint3 tgpig[[threadgroup_position_in_grid]]) {
@@ -1125,19 +1168,22 @@ kernel void kernel_rope(
 
     const bool is_neox = mode & 2;
 
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
     device const int32_t * pos = src1;
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = freq_scale * (float)p;
+    const float theta_0 = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
 
             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
-            const float cos_theta = cos(theta);
-            const float sin_theta = sin(theta);
+            float cos_theta, sin_theta;
+            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
             device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
             device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -1152,9 +1198,12 @@ kernel void kernel_rope(
         for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
             for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
 
-                const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
-                const float cos_theta = cos(theta);
-                const float sin_theta = sin(theta);
+                // simplified from `(ib * n_dims + ic) * inv_ndims`
+                const float cur_rot = inv_ndims*ic - ib;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot);
+                float cos_theta, sin_theta;
+                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 const int64_t i0 = ib*n_dims + ic/2;
 
diff --git a/ggml.c b/ggml.c
index 80d682255328c321089995747b2fbdad361310ea..2c7fe476b176d587e212ad71b5670f4e1b88440e 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1,4 +1,5 @@
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
 
 #include "ggml-impl.h"
 #include "ggml-quants.h"
@@ -4845,8 +4846,13 @@ static struct ggml_tensor * ggml_rope_impl(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
         float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow,
         float                 xpos_base,
         bool                  xpos_down,
         bool                  inplace) {
@@ -4862,11 +4868,15 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
-    memcpy(params + 4, &freq_base,  sizeof(float));
-    memcpy(params + 5, &freq_scale, sizeof(float));
-    memcpy(params + 6, &xpos_base,  sizeof(float));
-    memcpy(params + 7, &xpos_down,  sizeof(bool));
+    int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(params + 11, &xpos_base,    sizeof(float));
+    memcpy(params + 12, &xpos_down,    sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_ROPE;
@@ -4884,7 +4894,9 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+    );
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -4894,7 +4906,9 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+    );
 }
 
 struct ggml_tensor * ggml_rope_custom(
@@ -4904,9 +4918,17 @@ struct ggml_tensor * ggml_rope_custom(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -4916,9 +4938,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
+        int                   n_orig_ctx,
         float                 freq_base,
-        float                 freq_scale) {
-    return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+    );
 }
 
 struct ggml_tensor * ggml_rope_xpos_inplace(
@@ -4928,7 +4958,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace(
         int                   n_dims,
         float                 base,
         bool                  down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true);
+    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
 }
 
 // ggml_rope_back
@@ -10901,6 +10931,45 @@ static void ggml_compute_forward_clamp(
 
 // ggml_compute_forward_rope
 
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+    return 1 - MIN(1, MAX(0, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) {
+    return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(
+    int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = MAX(0,         floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)));
+    dims[1] = MIN(n_dims - 1, ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)));
+}
+
 static void ggml_compute_forward_rope_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10910,21 +10979,26 @@ static void ggml_compute_forward_rope_f32(
         return;
     }
 
-    float freq_base;
-    float freq_scale;
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
     // these two only relevant for xPos RoPE:
     float xpos_base;
     bool  xpos_down;
 
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&xpos_down,  (int32_t *) dst->op_params + 7, sizeof(bool));
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&xpos_base,   (int32_t *) dst->op_params + 11, sizeof(float));
+    memcpy(&xpos_down,   (int32_t *) dst->op_params + 12, sizeof(bool));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -10952,6 +11026,9 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.f/n_dims;
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -10965,18 +11042,18 @@ static void ggml_compute_forward_rope_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = freq_scale * (float)p;
+                float theta_base = (float)p;
 
                 if (is_glm) {
-                    theta = MIN(p, n_ctx - 2);
+                    theta_base = MIN(p, n_ctx - 2);
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
                         const float cos_block_theta = cosf(block_theta);
                         const float sin_block_theta = sinf(block_theta);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
                         block_theta *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -10994,13 +11071,16 @@ static void ggml_compute_forward_rope_f32(
                     }
                 } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        float cos_theta, sin_theta;
+                        rope_yarn(
+                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+                        );
+
                         // zeta scaling for xPos only:
                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                         if (xpos_down) zeta = 1.0f / zeta;
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11014,12 +11094,19 @@ static void ggml_compute_forward_rope_f32(
                 } else {
                     // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    theta_base *= freq_scale;
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            // simplified from `(ib * n_dims + ic) * inv_ndims`
+                            float cur_rot = inv_ndims * ic - ib;
+
+                            float cos_theta, sin_theta;
+                            rope_yarn(
+                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                &cos_theta, &sin_theta
+                            );
 
-                            theta *= theta_scale;
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -11048,15 +11135,19 @@ static void ggml_compute_forward_rope_f16(
         return;
     }
 
-    float freq_base;
-    float freq_scale;
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_dims = ((int32_t *) dst->op_params)[1];
-    const int mode   = ((int32_t *) dst->op_params)[2];
-    const int n_ctx  = ((int32_t *) dst->op_params)[3];
-    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
@@ -11084,6 +11175,9 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float inv_ndims = -1.f/n_dims;
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -11097,18 +11191,18 @@ static void ggml_compute_forward_rope_f16(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = freq_scale * (float)p;
+                float theta_base = (float)p;
 
                 if (is_glm) {
-                    theta = MIN(p, n_ctx - 2);
+                    theta_base = MIN(p, n_ctx - 2);
                     float block_theta = MAX(p - (n_ctx - 2), 0);
                     for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
                         const float cos_block_theta = cosf(block_theta);
                         const float sin_block_theta = sinf(block_theta);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
                         block_theta *= theta_scale;
 
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -11126,10 +11220,12 @@ static void ggml_compute_forward_rope_f16(
                     }
                 } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        float cos_theta, sin_theta;
+                        rope_yarn(
+                            theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
+                        );
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11143,12 +11239,19 @@ static void ggml_compute_forward_rope_f16(
                 } else {
                     // TODO: this might be wrong for ne0 != n_dims - need double check
                     // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    theta_base *= freq_scale;
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            // simplified from `(ib * n_dims + ic) * inv_ndims`
+                            float cur_rot = inv_ndims * ic - ib;
 
-                            theta *= theta_scale;
+                            float cos_theta, sin_theta;
+                            rope_yarn(
+                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                &cos_theta, &sin_theta
+                            );
+
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -11256,17 +11359,18 @@ static void ggml_compute_forward_rope_back_f32(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = freq_scale * (float)p;
+                float theta_base = freq_scale * (float)p;
 
                 if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
+
                         // zeta scaling for xPos only:
                         float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
                         if (xpos_down) zeta = 1.0f / zeta;
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11280,10 +11384,10 @@ static void ggml_compute_forward_rope_back_f32(
                 } else {
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            const float cos_theta = cosf(theta_base);
+                            const float sin_theta = sinf(theta_base);
 
-                            theta *= theta_scale;
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -11356,14 +11460,14 @@ static void ggml_compute_forward_rope_back_f16(
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                float theta = (float)p;
+                float theta_base = (float)p;
 
                 if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
-                        const float cos_theta = cosf(theta);
-                        const float sin_theta = sinf(theta);
+                        const float cos_theta = cosf(theta_base);
+                        const float sin_theta = sinf(theta_base);
 
-                        theta *= theta_scale;
+                        theta_base *= theta_scale;
 
                         const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                               ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
@@ -11377,10 +11481,10 @@ static void ggml_compute_forward_rope_back_f16(
                 } else {
                     for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                         for (int64_t ic = 0; ic < n_dims; ic += 2) {
-                            const float cos_theta = cosf(theta);
-                            const float sin_theta = sinf(theta);
+                            const float cos_theta = cosf(theta_base);
+                            const float sin_theta = sinf(theta_base);
 
-                            theta *= theta_scale;
+                            theta_base *= theta_scale;
 
                             const int64_t i0 = ib*n_dims + ic/2;
 
@@ -15505,9 +15609,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src1,
                                 n_dims,
                                 mode,
+                                0,
                                 n_ctx,
                                 freq_base,
                                 freq_scale,
+                                0.0f,
+                                1.0f,
+                                0.0f,
+                                0.0f,
                                 xpos_base,
                                 xpos_down,
                                 false),
diff --git a/ggml.h b/ggml.h
index 9d16c5a72fda0e3439d28376a292932b024fed8d..70eb25a6bf3afc70124ad090508b70d06085f000 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_CONTEXTS      64
 #define GGML_MAX_SRC           6
 #define GGML_MAX_NAME          64
-#define GGML_MAX_OP_PARAMS     32
+#define GGML_MAX_OP_PARAMS     64
 #define GGML_DEFAULT_N_THREADS 4
 
 #if UINTPTR_MAX == 0xFFFFFFFF
@@ -1326,8 +1326,13 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
@@ -1337,8 +1342,17 @@ extern "C" {
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
+            int                   n_orig_ctx,
             float                 freq_base,
-            float                 freq_scale);
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // compute correction dims for YaRN RoPE scaling
+    void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // xPos RoPE, in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
index 6b7d654294a3e0daae3bf79e10e257449afeecb5..727b4e55495a76e4cf34d658ff258218bab5e0e9 100644 (file)
@@ -7,7 +7,7 @@ import shutil
 import struct
 import sys
 import tempfile
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from io import BufferedWriter
 from pathlib import Path
 from typing import IO, Any, BinaryIO, Callable, Sequence
@@ -53,9 +53,12 @@ KEY_ATTENTION_LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
 KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
 
 # RoPE
-KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count"
-KEY_ROPE_FREQ_BASE       = "{arch}.rope.freq_base"
-KEY_ROPE_SCALE_LINEAR    = "{arch}.rope.scale_linear"
+KEY_ROPE_DIMENSION_COUNT         = "{arch}.rope.dimension_count"
+KEY_ROPE_FREQ_BASE               = "{arch}.rope.freq_base"
+KEY_ROPE_SCALING_TYPE            = "{arch}.rope.scaling.type"
+KEY_ROPE_SCALING_FACTOR          = "{arch}.rope.scaling.factor"
+KEY_ROPE_SCALING_ORIG_CTX_LEN    = "{arch}.rope.scaling.original_context_length"
+KEY_ROPE_SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned"
 
 # tokenization
 KEY_TOKENIZER_MODEL      = "tokenizer.ggml.model"
@@ -577,6 +580,11 @@ class TokenType(IntEnum):
     UNUSED       = 5
     BYTE         = 6
 
+class RopeScalingType(Enum):
+    NONE   = 'none'
+    LINEAR = 'linear'
+    YARN   = 'yarn'
+
 #
 # implementation
 #
@@ -948,8 +956,17 @@ class GGUFWriter:
     def add_rope_freq_base(self, value: float):
         self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
 
-    def add_rope_scale_linear(self, value: float):
-        self.add_float32(KEY_ROPE_SCALE_LINEAR.format(arch=self.arch), value)
+    def add_rope_scaling_type(self, value: RopeScalingType):
+        self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
+
+    def add_rope_scaling_factor(self, value: float):
+        self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_orig_ctx_len(self, value: int):
+        self.add_uint32(KEY_ROPE_SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
+
+    def add_rope_scaling_finetuned(self, value: bool):
+        self.add_bool(KEY_ROPE_SCALING_FINETUNED.format(arch=self.arch), value)
 
     def add_tokenizer_model(self, model: str):
         self.add_string(KEY_TOKENIZER_MODEL, model)
index 1c6d482f8fe1bca61ccf79ce7e92497a9da3d030..685882c201921d657a9d1887eb3ceae35b2ce87d 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -54,6 +54,7 @@
 #include <cassert>
 #include <cinttypes>
 #include <climits>
+#include <cmath>
 #include <cstdarg>
 #include <cstddef>
 #include <cstdint>
@@ -235,6 +236,10 @@ enum llm_kv {
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
+    LLM_KV_ROPE_SCALING_TYPE,
+    LLM_KV_ROPE_SCALING_FACTOR,
+    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
+    LLM_KV_ROPE_SCALING_FINETUNED,
 
     LLM_KV_TOKENIZER_MODEL,
     LLM_KV_TOKENIZER_LIST,
@@ -276,9 +281,13 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_LAYERNORM_EPS,       "%s.attention.layer_norm_epsilon"     },
     { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,   "%s.attention.layer_norm_rms_epsilon" },
 
-    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count" },
-    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"       },
-    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"    },
+    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
+    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
+    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
+    { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
+    { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
+    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
+    { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               },
 
     { LLM_KV_TOKENIZER_MODEL,               "tokenizer.ggml.model"              },
     { LLM_KV_TOKENIZER_LIST,                "tokenizer.ggml.tokens"             },
@@ -552,6 +561,22 @@ do { \
     } \
 } while (0)
 
+static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
+    { LLAMA_ROPE_SCALING_NONE,   "none"   },
+    { LLAMA_ROPE_SCALING_LINEAR, "linear" },
+    { LLAMA_ROPE_SCALING_YARN,   "yarn"   },
+};
+
+static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
+    for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+
+    return LLAMA_ROPE_SCALING_UNSPECIFIED;
+}
+
 //
 // ggml helpers
 //
@@ -1035,8 +1060,11 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
-    float rope_freq_base_train;
-    float rope_freq_scale_train;
+    float    rope_freq_base_train;
+    float    rope_freq_scale_train;
+    uint32_t n_yarn_orig_ctx;
+    int8_t   rope_scaling_type_train : 3;
+    bool     rope_finetuned : 1;
 
     float f_clamp_kqv;
     float f_max_alibi_bias;
@@ -1051,6 +1079,8 @@ struct llama_hparams {
         if (this->n_layer     != other.n_layer)     return true;
         if (this->n_rot       != other.n_rot)       return true;
         if (this->n_ff        != other.n_ff)        return true;
+        if (this->rope_finetuned  != other.rope_finetuned)  return true;
+        if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
 
         const float EPSILON = 1e-9;
 
@@ -1081,8 +1111,16 @@ struct llama_cparams {
     uint32_t n_threads;       // number of threads to use for generation
     uint32_t n_threads_batch; // number of threads to use for batch processing
 
-    float rope_freq_base;
-    float rope_freq_scale;
+    float    rope_freq_base;
+    float    rope_freq_scale;
+
+    uint32_t n_yarn_orig_ctx;
+    // These hyperparameters are not exposed in GGUF, because all
+    // existing YaRN models use the same values for them.
+    float yarn_ext_factor;
+    float yarn_attn_factor;
+    float yarn_beta_fast;
+    float yarn_beta_slow;
 
     bool mul_mat_q;
 };
@@ -2014,14 +2052,30 @@ static void llm_load_hparams(
     hparams.n_head_kv = hparams.n_head;
     GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
 
+    hparams.rope_finetuned = false;
+    GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
+                 kv(LLM_KV_ROPE_SCALING_FINETUNED));
+
+    hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
+    GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
+                 kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
+
     // rope_freq_base (optional)
     hparams.rope_freq_base_train = 10000.0f;
     GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
 
+    std::string rope_scaling("linear");
+    GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
+    hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
+    GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
+
     // rope_freq_scale (inverse of the kv) is optional
-    float ropescale = 1.0f;
-    GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
-    hparams.rope_freq_scale_train = 1.0f/ropescale;
+    float ropescale = 0.0f;
+    GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
+    if (ropescale == 0.0f) { // try the old key name
+        GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    }
+    hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
     // sanity check for n_rot (optional)
     {
@@ -2371,6 +2425,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
+    const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
@@ -2389,8 +2445,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
     LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
+    LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str());
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+    LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",     __func__, hparams.n_yarn_orig_ctx);
+    LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
@@ -3047,21 +3106,11 @@ static void llm_load_tensors(
     model.t_load_us = ggml_time_us() - model.t_start_us;
 }
 
-static bool llama_model_load(
-        const std::string & fname,
-        llama_model & model,
-        int n_gpu_layers,
-        int main_gpu,
-        const float * tensor_split,
-        bool use_mmap,
-        bool use_mlock,
-        bool vocab_only,
-        llama_progress_callback progress_callback,
-        void *progress_callback_user_data) {
+static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
     try {
-        llama_model_loader ml(fname, use_mmap);
+        llama_model_loader ml(fname, params.use_mmap);
 
-        model.hparams.vocab_only = vocab_only;
+        model.hparams.vocab_only = params.vocab_only;
 
         llm_load_arch   (ml, model);
         llm_load_hparams(ml, model);
@@ -3073,15 +3122,15 @@ static bool llama_model_load(
             throw std::runtime_error("vocab size mismatch");
         }
 
-        if (vocab_only) {
+        if (params.vocab_only) {
             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
             return true;
         }
 
         llm_load_tensors(
-                ml, model, n_gpu_layers,
-                main_gpu, tensor_split,
-                use_mlock, progress_callback, progress_callback_user_data);
+            ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
+            params.progress_callback, params.progress_callback_user_data
+        );
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
         return false;
@@ -3150,6 +3199,7 @@ static struct ggml_tensor * llm_build_inp_embd(
 static void llm_build_k_shift(
       struct ggml_context * ctx,
       const llama_hparams & hparams,
+      const llama_cparams & cparams,
      const llama_kv_cache & kv,
        struct ggml_cgraph * graph,
             llm_rope_type   type,
@@ -3162,6 +3212,11 @@ static void llm_build_k_shift(
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
     const int64_t n_embd_head = hparams.n_embd_head();
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
+    const float   ext_factor  = cparams.yarn_ext_factor;
+    const float   attn_factor = cparams.yarn_attn_factor;
+    const float   beta_fast   = cparams.yarn_beta_fast;
+    const float   beta_slow   = cparams.yarn_beta_slow;
 
     GGML_ASSERT(n_embd_head % n_rot == 0);
 
@@ -3185,7 +3240,8 @@ static void llm_build_k_shift(
                         ggml_element_size(kv.k)*n_embd_head,
                         ggml_element_size(kv.k)*n_embd_gqa,
                         ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
-                    K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
+                    K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
         cb(tmp, "K_shifted", il);
         ggml_build_forward_expand(graph, tmp);
     }
@@ -3442,12 +3498,17 @@ struct llm_build_context {
 
     const float freq_base;
     const float freq_scale;
+    const float ext_factor;
+    const float attn_factor;
+    const float beta_fast;
+    const float beta_slow;
     const float norm_eps;
     const float norm_rms_eps;
 
     const int32_t n_tokens;
     const int32_t n_kv;     // size of KV cache to consider (n_kv <= n_ctx)
     const int32_t kv_head;  // index of where we store new KV data in the cache
+    const int32_t n_orig_ctx;
 
     const bool do_rope_shift;
 
@@ -3477,11 +3538,16 @@ struct llm_build_context {
         n_embd_gqa    (hparams.n_embd_gqa()),
         freq_base     (cparams.rope_freq_base),
         freq_scale    (cparams.rope_freq_scale),
+        ext_factor    (cparams.yarn_ext_factor),
+        attn_factor   (cparams.yarn_attn_factor),
+        beta_fast     (cparams.yarn_beta_fast),
+        beta_slow     (cparams.yarn_beta_slow),
         norm_eps      (hparams.f_norm_eps),
         norm_rms_eps  (hparams.f_norm_rms_eps),
         n_tokens      (batch.n_tokens),
         n_kv          (worst_case ? n_ctx            : kv_self.n),
         kv_head       (worst_case ? n_ctx - n_tokens : kv_self.head),
+        n_orig_ctx    (cparams.n_yarn_orig_ctx),
         do_rope_shift (worst_case || kv_self.has_shift),
         cb            (cb),
         buf_compute   (lctx.buf_compute) {
@@ -3532,7 +3598,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -3556,10 +3622,18 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                    n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Kcur, "Kcur", il);
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@@ -3634,7 +3708,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -3658,8 +3732,16 @@ struct llm_build_context {
 
                 switch (model.type) {
                     case MODEL_7B:
-                        Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens),    inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-                        Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                        Qcur = ggml_rope_custom(
+                            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                            n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+                        Kcur = ggml_rope_custom(
+                            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                            n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                        );
                         break;
                     case MODEL_13B:
                         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
@@ -3746,7 +3828,7 @@ struct llm_build_context {
 
         // shift the entire K-cache if needed
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -3786,10 +3868,16 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 // using mode = 2 for neox mode
-                Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+                Qcur = ggml_rope_custom(
+                    ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+                Kcur = ggml_rope_custom(
+                    ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(Kcur, "Kcur", il);
 
                 llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@@ -3960,7 +4048,7 @@ struct llm_build_context {
         cb(KQ_mask, "KQ_mask", -1);
 
         if (do_rope_shift) {
-            llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+            llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
         }
 
         for (int il = 0; il < n_layer; ++il) {
@@ -4053,13 +4141,15 @@ struct llm_build_context {
                 cb(kpass, "kpass", il);
 
                 struct ggml_tensor * qrotated = ggml_rope_custom(
-                        ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
-                        );
+                    ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(qrotated, "qrotated", il);
 
                 struct ggml_tensor * krotated = ggml_rope_custom(
-                        ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
-                        );
+                    ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
                 cb(krotated, "krotated", il);
 
                 // ggml currently only supports concatenation on dim=2
@@ -7883,8 +7973,13 @@ struct llama_context_params llama_context_default_params() {
         /*.n_batch                     =*/ 512,
         /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
+        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
+        /*.yarn_ext_factor             =*/ NAN,
+        /*.yarn_attn_factor            =*/ 1.0f,
+        /*.yarn_beta_fast              =*/ 32.0f,
+        /*.yarn_beta_slow              =*/ 1.0f,
         /*.mul_mat_q                   =*/ true,
         /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
@@ -7971,10 +8066,7 @@ struct llama_model * llama_load_model_from_file(
         };
     }
 
-    if (!llama_model_load(path_model, *model, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split,
-                params.use_mmap, params.use_mlock, params.vocab_only,
-                params.progress_callback, params.progress_callback_user_data)) {
+    if (!llama_model_load(path_model, *model, params)) {
         LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
         delete model;
         return nullptr;
@@ -8000,13 +8092,35 @@ struct llama_context * llama_new_context_with_model(
     const auto & hparams = model->hparams;
     auto       & cparams = ctx->cparams;
 
-    cparams.n_batch         = params.n_batch;
-    cparams.n_ctx           = params.n_ctx == 0           ? hparams.n_ctx_train           : params.n_ctx;
-    cparams.rope_freq_base  = params.rope_freq_base == 0  ? hparams.rope_freq_base_train  : params.rope_freq_base;
-    cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
-    cparams.n_threads       = params.n_threads;
-    cparams.n_threads_batch = params.n_threads_batch;
-    cparams.mul_mat_q       = params.mul_mat_q;
+    cparams.n_batch          = params.n_batch;
+    cparams.n_threads        = params.n_threads;
+    cparams.n_threads_batch  = params.n_threads_batch;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.mul_mat_q        = params.mul_mat_q;
+
+    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
+    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
+    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+
+    cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
+                               hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
+                                                              hparams.n_ctx_train;
+
+    auto rope_scaling_type = params.rope_scaling_type;
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
+        rope_scaling_type = hparams.rope_scaling_type_train;
+    }
+
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
+        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
+    }
+
+    if (std::isnan(cparams.yarn_ext_factor)) { // NaN indicates 'not set'
+        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
+    }
 
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
diff --git a/llama.h b/llama.h
index 75fe391ef2e733a40d3af651ec0f4b143e4210b0..3f1becd7616885fc20ebdd21f832bb2fc1ae57f6 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -106,6 +106,14 @@ extern "C" {
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
 
+    enum llama_rope_scaling_type {
+        LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
+        LLAMA_ROPE_SCALING_NONE        = 0,
+        LLAMA_ROPE_SCALING_LINEAR      = 1,
+        LLAMA_ROPE_SCALING_YARN        = 2,
+        LLAMA_ROPE_SCALING_MAX_VALUE   = LLAMA_ROPE_SCALING_YARN,
+    };
+
     typedef struct llama_token_data {
         llama_token id; // token id
         float logit;    // log-odds of the token
@@ -172,10 +180,16 @@ extern "C" {
         uint32_t n_batch;         // prompt processing maximum batch size
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
+        int8_t   rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float rope_freq_base;  // RoPE base frequency, 0 = from model
-        float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
+        float    rope_freq_base;   // RoPE base frequency, 0 = from model
+        float    rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
+        float    yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
+        float    yarn_attn_factor; // YaRN magnitude scaling factor
+        float    yarn_beta_fast;   // YaRN low correction dim
+        float    yarn_beta_slow;   // YaRN high correction dim
+        uint32_t yarn_orig_ctx;    // YaRN original context size
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q;  // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)