git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add phi3 128K model support (#7225)
author liuwei-git <redacted>
Tue, 21 May 2024 20:28:32 +0000 (04:28 +0800)
committer GitHub <redacted>
Tue, 21 May 2024 20:28:32 +0000 (23:28 +0300)
* add phi3 128k support in convert-hf-to-gguf

* add phi3 128k support in cuda

* address build warnings on llama.cpp

* adjust index value in cuda long rope freq factors

* add long rope support in ggml cpu backend

* make freq factors only depend on ctx size

* remove unused rope scaling type 'su' from gguf converter

* fix lint warnings on convert-hf-to-gguf.py

* set to the short freq factor when context size is smaller than trained context size

* add one line of comments

* metal : support rope freq_factors

* ggml : update ggml_rope_ext API to support freq. factors

* backends : add dev messages to support rope freq. factors

* minor : style

* tests : update to use new rope API

* backends : fix pragma semicolons

* minor : cleanup

* llama : move rope factors from KV header to tensors

* llama : remove tmp assert

* cuda : fix compile warning

* convert : read/write n_head_kv

* llama : fix uninitialized tensors

---------

Co-authored-by: Georgi Gerganov <redacted>
15 files changed:
convert-hf-to-gguf.py
examples/finetune/finetune.cpp
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-cuda/rope.cu
ggml-kompute.cpp
ggml-metal.m
ggml-metal.metal
ggml-sycl.cpp
ggml-vulkan.cpp
ggml.c
ggml.h
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
llama.cpp
tests/test-backend-ops.cpp

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 6357d40348b34f989d2fe17c7b032478da49450b..daad1c4fc725512b8a3592007e74e790a2930b8b 100755 (executable)
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -14,6 +14,7 @@ from pathlib import Path
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
 
+import math
 import numpy as np
 import torch
 
@@ -1784,23 +1785,59 @@ class Phi3MiniModel(Model):
     def set_gguf_parameters(self):
         block_count = self.find_hparam(["num_hidden_layers", "n_layer"])
 
-        rot_pct = 1.0
         n_embd = self.find_hparam(["hidden_size", "n_embd"])
         n_head = self.find_hparam(["num_attention_heads", "n_head"])
+        n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
         rms_eps = self.find_hparam(["rms_norm_eps"])
+        max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
+        orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
+        rope_dims = n_embd // n_head
 
         self.gguf_writer.add_name("Phi3")
-        self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"]))
-
+        self.gguf_writer.add_context_length(max_pos_embds)
+        self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds)
         self.gguf_writer.add_embedding_length(n_embd)
-        self.gguf_writer.add_feed_forward_length(8192)
+        self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"]))
         self.gguf_writer.add_block_count(block_count)
         self.gguf_writer.add_head_count(n_head)
-        self.gguf_writer.add_head_count_kv(n_head)
+        self.gguf_writer.add_head_count_kv(n_head_kv)
         self.gguf_writer.add_layer_norm_rms_eps(rms_eps)
-        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
+        self.gguf_writer.add_rope_dimension_count(rope_dims)
+        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
         self.gguf_writer.add_file_type(self.ftype)
 
+        # write rope scaling for long context (128k) model
+        rope_scaling = self.find_hparam(['rope_scaling'], True)
+        if (rope_scaling is None):
+            return
+
+        scale = max_pos_embds / orig_max_pos_embds
+
+        rope_scaling_type = rope_scaling.get('type', '').lower()
+        if len(rope_scaling_type) == 0:
+            raise KeyError('Missing the required key rope_scaling.type')
+
+        if rope_scaling_type == 'su':
+            attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0
+        elif rope_scaling_type == 'yarn':
+            attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0
+        else:
+            raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet')
+
+        self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
+
+        long_factors = rope_scaling.get('long_factor', None)
+        short_factors = rope_scaling.get('short_factor', None)
+
+        if long_factors is None or short_factors is None:
+            raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
+
+        if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
+            raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
+
+        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG]  + ".weight", np.array(long_factors, dtype=np.float32))
+        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
+
 
 @Model.register("PlamoForCausalLM")
 class PlamoModel(Model):
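
A quick sanity check of the conversion logic above, using the hyperparameters that ship with the Phi-3-mini-128k config (assumed here for illustration, not taken from this commit: max_position_embeddings = 131072, original_max_position_embeddings = 4096, hidden_size = 3072, num_attention_heads = 32):

    scale        = 131072 / 4096 = 32
    attn_factor  = sqrt(1 + ln(32) / ln(4096)) ≈ sqrt(1.417) ≈ 1.19    ('su' branch)
    rope_dims/2  = (3072 / 32) / 2 = 48                                 (expected length of long_factor and short_factor)

The exact numbers depend on the model's config.json; this only illustrates how the written keys and tensors are derived.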
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 22743b1bf02fd01411c3e47a3ebde9b9c017a8e0..992426c1b69e272921c7605592f1790f8973d654 100644 (file)
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -563,8 +563,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
         // not capturing these, to silcence warnings
         const int rope_mode = 0;
 
-        return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
+        return ggml_rope_ext(ctx,
+            t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0,
             rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 587418cc739640ae0a1260951c4edc82d70b6fed..45bdfa8f5d80ce9d83abe648aaeb4a2c97b3b0a1 100644 (file)
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -301,8 +301,8 @@ static struct ggml_tensor * llama_build_train_graphs(
         // not capturing these, to silcence warnings
         const int rope_mode = 0;
 
-        return ggml_rope_custom(
-            ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+        return ggml_rope_ext(
+            ctx, t, KQ_pos, nullptr, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
         );
     };
 
diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu
index 4b0d2e5adbbc58fbd94f5837f02b10fa09cd4ec1..4a558f4b3757e49dd3b2849a3fd4be6d8300ac55 100644 (file)
--- a/ggml-cuda/rope.cu
+++ b/ggml-cuda/rope.cu
@@ -58,10 +58,10 @@ static __global__ void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -88,7 +88,9 @@ static __global__ void rope_neox(
     float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
+    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
+    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -164,7 +166,7 @@ static void rope_cuda(
 template<typename T>
 static void rope_neox_cuda(
     const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
@@ -175,15 +177,29 @@ static void rope_neox_cuda(
     const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
-        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
+        if (freq_factors == nullptr) {
+            rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+                );
+        } else {
+            rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+                );
+        }
     } else {
-        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-            theta_scale, inv_ndims
-        );
+        if (freq_factors == nullptr) {
+            rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+                );
+        } else {
+            rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, inv_ndims, freq_factors
+                );
+        }
     }
 }
 
@@ -214,24 +230,27 @@ static void rope_cuda_f32(
 
 static void rope_neox_cuda_f16(
     const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
 
-    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 static void rope_neox_cuda_f32(
     const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
 ) {
 
-    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+    rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
 }
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
+
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
@@ -241,7 +260,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    const int64_t ne2 = dst->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
     //const int n_past      = ((int32_t *) dst->op_params)[0];
@@ -259,16 +277,22 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
+    const float * freq_factors = nullptr;
     const int32_t * pos = nullptr;
-    if ((mode & 1) == 0) {
-        GGML_ASSERT(src1->type == GGML_TYPE_I32);
-        GGML_ASSERT(src1->ne[0] == ne2);
-        pos = (const int32_t *) src1_d;
-    }
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    if (is_neox) {
+        pos = (const int32_t *) src1_d;
+
+        if (src2 != nullptr) {
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
 
@@ -280,12 +304,12 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_cuda_f32(
                 (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_cuda_f16(
                 (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, stream
+                attn_factor, corr_dims, freq_factors, stream
             );
         } else {
             GGML_ASSERT(false);
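
For reference, the effect of the new freq_factors argument in the NeoX path above can be condensed into a small CPU-side sketch (illustrative only; the standalone function and its name are assumptions, not code from this commit):

    #include <math.h>

    // per-dimension NeoX rotary angle before the YaRN correction;
    // when frequency factors are provided, the angle is divided by the matching factor
    static float rope_neox_theta(int p, int ic, int n_dims, float freq_base,
                                 float freq_scale, const float * freq_factors) {
        const float theta = p * freq_scale * powf(freq_base, -(float) ic / n_dims);
        const float ff    = freq_factors ? freq_factors[ic/2] : 1.0f;
        return theta / ff;
    }

This mirrors the kernel's theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor line above, written out with the base frequency explicit.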
diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp
index 3f033d58be481544d3da127a8a23f6164cbfb465..6c6058b2a95b12f3eed9237bdf698e32460eb7fd 100644 (file)
--- a/ggml-kompute.cpp
+++ b/ggml-kompute.cpp
@@ -1677,6 +1677,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     } break;
                 case GGML_OP_ROPE:
                     {
+#pragma message("TODO: implement phi3 frequency factors support")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
+                        GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+
                         GGML_ASSERT(ne10 == ne02);
                         GGML_ASSERT(src0t == dstt);
                         // const int n_past = ((int32_t *) dst->op_params)[0];
diff --git a/ggml-metal.m b/ggml-metal.m
index b0b16dbf7716096a9c0f7cab1bee092500525289..5d5ad20ada7884c95e01ce10d8c908593349c369 100644 (file)
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute(
             const int64_t  ne10 = src1 ? src1->ne[0] : 0;
             const int64_t  ne11 = src1 ? src1->ne[1] : 0;
             const int64_t  ne12 = src1 ? src1->ne[2] : 0;
-            const int64_t  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+            const int64_t  ne13 = src1 ? src1->ne[3] : 0;
 
             const uint64_t nb10 = src1 ? src1->nb[0] : 0;
             const uint64_t nb11 = src1 ? src1->nb[1] : 0;
             const uint64_t nb12 = src1 ? src1->nb[2] : 0;
-            const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+            const uint64_t nb13 = src1 ? src1->nb[3] : 0;
 
-            const int64_t  ne0  = dst ? dst->ne[0] : 0;
-            const int64_t  ne1  = dst ? dst->ne[1] : 0;
-            const int64_t  ne2  = dst ? dst->ne[2] : 0;
-            const int64_t  ne3  = dst ? dst->ne[3] : 0;
+            const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+            const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+            const int64_t  ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
+            const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
 
-            const uint64_t nb0  = dst ? dst->nb[0] : 0;
-            const uint64_t nb1  = dst ? dst->nb[1] : 0;
-            const uint64_t nb2  = dst ? dst->nb[2] : 0;
-            const uint64_t nb3  = dst ? dst->nb[3] : 0;
+            const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+            const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+            const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+            const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
+            const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
+            const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
+            const int64_t  ne2  =  dst ?  dst->ne[2] : 0;
+            const int64_t  ne3  =  dst ?  dst->ne[3] : 0;
+
+            const uint64_t nb0  =  dst ?  dst->nb[0] : 0;
+            const uint64_t nb1  =  dst ?  dst->nb[1] : 0;
+            const uint64_t nb2  =  dst ?  dst->nb[2] : 0;
+            const uint64_t nb3  =  dst ?  dst->nb[3] : 0;
 
             const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
             const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int n_as = src0->ne[2];
 
                         // src2 = ids
-                        const int64_t  ne20 = src2->ne[0];
-                        const int64_t  ne21 = src2->ne[1];
-                        const int64_t  ne22 = src2->ne[2]; GGML_UNUSED(ne22);
-                        const int64_t  ne23 = src2->ne[3]; GGML_UNUSED(ne23);
-
-                        const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
-                        const uint64_t nb21 = src2->nb[1];
-                        const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
-                        const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
-
                         const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
 
                         GGML_ASSERT(src2t == GGML_TYPE_I32);
@@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute(
                         // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
                         const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
 
-                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                        float freq_base;
+                        float freq_scale;
+                        float ext_factor;
+                        float attn_factor;
+                        float beta_fast;
+                        float beta_slow;
+
                         memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
                         memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
                         memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -2252,6 +2258,15 @@ static enum ggml_status ggml_metal_graph_compute(
                         memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
                         memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
+                        const bool is_neox = mode & 2;
+                        const bool is_glm  = mode & 4;
+
+                        GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+
+                        if (!is_neox) {
+                            GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+                        }
+
                         id<MTLComputePipelineState> pipeline = nil;
 
                         switch (src0->type) {
@@ -2263,33 +2278,38 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
                         [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
-                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:2];
-                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:4];
-                        [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:5];
-                        [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:6];
-                        [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:7];
-                        [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:8];
-                        [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:10];
-                        [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:11];
-                        [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:12];
-                        [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:14];
-                        [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:15];
-                        [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:16];
-                        [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:17];
-                        [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:18];
-                        [encoder setBytes:&n_past      length:sizeof(     int) atIndex:19];
-                        [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:20];
-                        [encoder setBytes:&mode        length:sizeof(     int) atIndex:21];
-                        [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:22];
-                        [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
-                        [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
-                        [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
-                        [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
-                        [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
-                        [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
+                        if (id_src2 != nil) {
+                            [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
+                        } else {
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:2];
+                        }
+                        [encoder setBuffer:id_dst      offset:offs_dst         atIndex:3];
+                        [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:6];
+                        [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:7];
+                        [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:10];
+                        [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:11];
+                        [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:14];
+                        [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:15];
+                        [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:17];
+                        [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:18];
+                        [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
+                        [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
+                        [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
+                        [encoder setBytes:&mode        length:sizeof(     int) atIndex:22];
+                        [encoder setBytes:&n_orig_ctx  length:sizeof(     int) atIndex:23];
+                        [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:24];
+                        [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:25];
+                        [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:26];
+                        [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:27];
+                        [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:28];
+                        [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:29];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
@@ -2535,11 +2555,6 @@ static enum ggml_status ggml_metal_graph_compute(
                         GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
                                 "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
-                        const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-                        const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-                        const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-                        const uint64_t nb23 = src2 ? src2->nb[3] : 0;
-
                         const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
                       //const int64_t  ne31 = src3 ? src3->ne[1] : 0;
                         const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
diff --git a/ggml-metal.metal b/ggml-metal.metal
index cf262e8349874257357bac973b504f892f2e82f9..c5eb2528083779b4ac11e900ca92ea6e1043101a 100644 (file)
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims(
 typedef void (rope_t)(
         device const    void * src0,
         device const int32_t * src1,
+        device const   float * src2,
         device         float * dst,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -1675,6 +1676,7 @@ template<typename T>
 kernel void kernel_rope(
         device const    void * src0,
         device const int32_t * src1,
+        device const   float * src2,
         device         float * dst,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -1744,8 +1746,10 @@ kernel void kernel_rope(
 
                 // simplified from `(ib * n_dims + ic) * inv_ndims`
                 const float cur_rot = inv_ndims*ic - ib;
+                const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+
+                const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
 
-                const float theta = theta_0 * pow(freq_base, cur_rot);
                 float cos_theta, sin_theta;
                 rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
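
One detail worth noting in the Metal changes (an inference from the hunks above, not something stated in the commit message): Metal requires every declared kernel buffer to be bound, so when no frequency-factor tensor exists the host encoder re-binds src0 at buffer index 2, and the shader detects the missing factors by testing src2 != src0 before dividing the rotary angle.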
 
diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index eac8f557967355a00e7efb75dbc92e2fbe63bb60..f486b6c0a5a3b34a6565818c3f81726212c8d585 100644 (file)
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -14454,6 +14454,9 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
                               ggml_tensor *dst, const float *src0_dd,
                               const float *src1_dd, float *dst_dd,
                               const dpct::queue_ptr &main_stream) {
+#pragma message("TODO: implement phi3 frequency factors support")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
+    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index aff451b6354e59ec81576a968497471e3a86eaee..16287a28089a0c20f81d7db7e4e322bc07353d9b 100644 (file)
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -4238,6 +4238,10 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+#pragma message("TODO: implement phi3 frequency factors support")
+#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
+    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     const int mode          = ((int32_t *) dst->op_params)[2];
     // const int n_ctx         = ((int32_t *) dst->op_params)[3];
diff --git a/ggml.c b/ggml.c
index 4bd911528586bb5d54ebdc8dae2bee137a757c92..37b16b7a9ce7f2bd7a0dd323d41c5c5b6a787d6e 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -6231,6 +6231,7 @@ static struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6248,6 +6249,11 @@ static struct ggml_tensor * ggml_rope_impl(
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
 
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6271,6 +6277,7 @@ static struct ggml_tensor * ggml_rope_impl(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = b;
+    result->src[2] = c;
 
     return result;
 }
@@ -6283,7 +6290,7 @@ struct ggml_tensor * ggml_rope(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
     );
 }
 
@@ -6295,14 +6302,15 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   mode,
         int                   n_ctx) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+        ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
     );
 }
 
-struct ggml_tensor * ggml_rope_custom(
+struct ggml_tensor * ggml_rope_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6314,15 +6322,16 @@ struct ggml_tensor * ggml_rope_custom(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
     );
 }
 
-struct ggml_tensor * ggml_rope_custom_inplace(
+struct ggml_tensor * ggml_rope_ext_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6334,19 +6343,49 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         float                 beta_fast,
         float                 beta_slow) {
     return ggml_rope_impl(
-        ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
         ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
     );
 }
 
-struct ggml_tensor * ggml_rope_xpos_inplace(
+struct ggml_tensor * ggml_rope_custom(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         int                   n_dims,
-        float                 base,
-        bool                  down) {
-    return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx,
+        int                   n_orig_ctx,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+    );
 }
 
 // ggml_rope_back
@@ -6355,6 +6394,7 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
         int                   n_dims,
         int                   mode,
         int                   n_ctx,
@@ -6370,6 +6410,7 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
+    GGML_ASSERT(c == NULL && "freq factors not implemented yet");
 
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -14304,6 +14345,7 @@ static void ggml_compute_forward_rope_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14363,6 +14405,17 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14439,10 +14492,11 @@ static void ggml_compute_forward_rope_f32(
 
                             // simplified from `(ib * n_dims + ic) * inv_ndims`
                             float cur_rot = inv_ndims * ic - ib;
+                            float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
                             float cos_theta, sin_theta;
                             rope_yarn(
-                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
                             sin_theta *= sin_sign;
@@ -18387,6 +18441,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg
 static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
+    struct ggml_tensor * src2 = tensor->src[2];
 
     switch (tensor->op) {
         case GGML_OP_DUP:
@@ -18918,6 +18973,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_rope_back(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -18957,6 +19013,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             ggml_rope_impl(ctx,
                                 tensor->grad,
                                 src1,
+                                src2,
                                 n_dims,
                                 mode,
                                 n_ctx,
@@ -19038,7 +19095,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             masked);
                 }
 
-                struct ggml_tensor * src2 = tensor->src[2];
                 const int64_t elem_q = ggml_nelements(src0);
                 const int64_t elem_k = ggml_nelements(src1);
                 const int64_t elem_v = ggml_nelements(src2);
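
At the call sites, the renamed API simply threads one extra optional tensor through to the kernels. A hedged usage sketch in the style of the llama.cpp changes below (variable names are placeholders, not code from this commit):

    // rope_factors: optional GGML_TYPE_F32 tensor with n_rot/2 entries, or NULL for the old behaviour
    struct ggml_tensor * q_rot = ggml_rope_ext(
        ctx, Qcur, inp_pos, rope_factors,
        n_rot, /* mode: 2 = NeoX */ 2, /* n_ctx */ 0, n_orig_ctx,
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);

Passing NULL for the factors keeps ggml_rope_ext equivalent to the former ggml_rope_custom.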
diff --git a/ggml.h b/ggml.h
index 77475710129d7986b1684e56fc45fbd28cb82d5d..35ac9110ceb17034602c2c79c8404c65ed36782c 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1465,6 +1465,7 @@ extern "C" {
     // if mode & 4 == 1, ChatGLM style
     //
     // b is an int32 vector with size a->ne[2], it contains the positions
+    // c is freq factors (e.g. phi3-128k), (optional)
     GGML_API struct ggml_tensor * ggml_rope(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1483,10 +1484,11 @@ extern "C" {
             int                   n_ctx);
 
     // custom RoPE
-    GGML_API struct ggml_tensor * ggml_rope_custom(
+    GGML_API struct ggml_tensor * ggml_rope_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
@@ -1499,10 +1501,11 @@ extern "C" {
             float                 beta_slow);
 
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+    GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
@@ -1514,18 +1517,41 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
-    // compute correction dims for YaRN RoPE scaling
-    GGML_CALL void ggml_rope_yarn_corr_dims(
-        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx,
+            int                   n_orig_ctx,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext instead");
 
-    // xPos RoPE, in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             int                   n_dims,
-            float                 base,
-            bool                  down);
+            int                   mode,
+            int                   n_ctx,
+            int                   n_orig_ctx,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext_inplace instead");
+
+    // compute correction dims for YaRN RoPE scaling
+    GGML_CALL void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1533,6 +1559,7 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
             int                   n_dims,
             int                   mode,
             int                   n_ctx,
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 692120f4d64b0af1384b1f7e805384e37b8871da..42df2e4d00604ce988489b0f61c4350e7ca49d15 100644 (file)
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -57,12 +57,13 @@ class Keys:
         CAUSAL            = "{arch}.attention.causal"
 
     class Rope:
-        DIMENSION_COUNT      = "{arch}.rope.dimension_count"
-        FREQ_BASE            = "{arch}.rope.freq_base"
-        SCALING_TYPE         = "{arch}.rope.scaling.type"
-        SCALING_FACTOR       = "{arch}.rope.scaling.factor"
-        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-        SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
+        DIMENSION_COUNT         = "{arch}.rope.dimension_count"
+        FREQ_BASE               = "{arch}.rope.freq_base"
+        SCALING_TYPE            = "{arch}.rope.scaling.type"
+        SCALING_FACTOR          = "{arch}.rope.scaling.factor"
+        SCALING_ATTN_FACTOR     = "{arch}.rope.scaling.attn_factor"
+        SCALING_ORIG_CTX_LEN    = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED       = "{arch}.rope.scaling.finetuned"
 
     class SSM:
         CONV_KERNEL    = "{arch}.ssm.conv_kernel"
@@ -148,6 +149,8 @@ class MODEL_TENSOR(IntEnum):
     OUTPUT             = auto()
     OUTPUT_NORM        = auto()
     ROPE_FREQS         = auto()
+    ROPE_FACTORS_LONG  = auto()
+    ROPE_FACTORS_SHORT = auto()
     ATTN_Q             = auto()
     ATTN_K             = auto()
     ATTN_V             = auto()
@@ -225,6 +228,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.OUTPUT_NORM:        "output_norm",
     MODEL_TENSOR.OUTPUT:             "output",
     MODEL_TENSOR.ROPE_FREQS:         "rope_freqs",
+    MODEL_TENSOR.ROPE_FACTORS_LONG:  "rope_factors_long",
+    MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
     MODEL_TENSOR.ATTN_NORM:          "blk.{bid}.attn_norm",
     MODEL_TENSOR.ATTN_NORM_2:        "blk.{bid}.attn_norm_2",
     MODEL_TENSOR.ATTN_QKV:           "blk.{bid}.attn_qkv",
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index d5e323a52ef14005d0798777d3803739c8088ed0..8b41b54eaa5a67654875ff1099088c14dbd31ac1 100644 (file)
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -433,6 +433,9 @@ class GGUFWriter:
     def add_rope_scaling_factor(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_FACTOR.format(arch=self.arch), value)
 
+    def add_rope_scaling_attn_factors(self, value: Sequence[float]) -> None:
+        self.add_float32(Keys.Rope.SCALING_ATTN_FACTOR.format(arch=self.arch), value)
+
     def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
         self.add_uint32(Keys.Rope.SCALING_ORIG_CTX_LEN.format(arch=self.arch), value)
 
index d26fe559a2051623ed7f4c5ba180890dab08f3c8..abff8c1c03e7a24ad48ab64a2c0eed5025e5ea9b 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -304,6 +304,7 @@ enum llm_kv {
     LLM_KV_ROPE_SCALE_LINEAR,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
+    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
 
@@ -381,6 +382,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
     { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
     { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
+    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,      "%s.rope.scaling.attn_factor"             },
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               },
 
@@ -436,6 +438,8 @@ enum llm_tensor {
     LLM_TENSOR_OUTPUT,
     LLM_TENSOR_OUTPUT_NORM,
     LLM_TENSOR_ROPE_FREQS,
+    LLM_TENSOR_ROPE_FACTORS_LONG,
+    LLM_TENSOR_ROPE_FACTORS_SHORT,
     LLM_TENSOR_ATTN_Q,
     LLM_TENSOR_ATTN_K,
     LLM_TENSOR_ATTN_V,
@@ -803,18 +807,20 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
     {
         LLM_ARCH_PHI3,
         {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
+            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,           "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
         },
     },
     {
@@ -1750,6 +1756,7 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
     float    rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
@@ -1798,6 +1805,7 @@ struct llama_hparams {
 
         if (!is_float_close(this->f_norm_eps,            other.f_norm_eps,            EPSILON)) return true;
         if (!is_float_close(this->f_norm_rms_eps,        other.f_norm_rms_eps,        EPSILON)) return true;
+        if (!is_float_close(this->rope_attn_factor,      other.rope_attn_factor,      EPSILON)) return true;
         if (!is_float_close(this->rope_freq_base_train,  other.rope_freq_base_train,  EPSILON)) return true;
         if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
 
@@ -2103,6 +2111,10 @@ struct llama_model {
     struct ggml_tensor * output;
     struct ggml_tensor * output_b;
 
+    // long rope factors
+    struct ggml_tensor * rope_long  = nullptr;
+    struct ggml_tensor * rope_short = nullptr;
+
     std::vector<llama_layer> layers;
 
     llama_split_mode split_mode;
@@ -3306,6 +3318,39 @@ struct llama_model_loader {
         return get_arr_n(llm_kv(kid), result, required);
     }
 
+    template<typename T>
+    bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
+        const int kid = gguf_find_key(meta, key.c_str());
+
+        if (kid < 0) {
+            if (required) {
+                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+            }
+            return false;
+        }
+
+        struct GGUFMeta::ArrayInfo arr_info =
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
+
+        if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) {
+            throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str()));
+        }
+
+        // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
+        GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value));
+        GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32   || std::is_same<T, int>::value));
+
+        result.resize(arr_info.length);
+        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+
+        return true;
+    }
+
+    template<typename T>
+    bool get_arr(const enum llm_kv kid, T& result, const bool required = true) {
+        return get_arr(llm_kv(kid), result, required);
+    }
+
     template<typename T>
     bool get_key(const std::string & key, T & result, const bool required = true) {
         auto it = kv_overrides.find(key);
@@ -3849,6 +3894,8 @@ static void llm_load_hparams(
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
+    ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
+
     // sanity check for n_rot (optional)
     {
         hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head;
@@ -4880,6 +4927,7 @@ static bool llm_load_tensors(
     // create tensors for the weights
     {
         const int64_t n_embd       = hparams.n_embd;
+        const int64_t n_embd_head  = n_embd / hparams.n_head;
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
         const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
         const int64_t n_embd_gqa   = n_embd_v_gqa;
@@ -5591,6 +5639,9 @@ static bool llm_load_tensors(
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
 
+                    model.rope_long  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight"), { n_embd_head/2 }, false);
+                    model.rope_short = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, false);
+
                     // output
                     {
                         model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
@@ -5601,12 +5652,12 @@ static bool llm_load_tensors(
                         ggml_context* ctx_layer = ctx_for_layer(i);
                         ggml_context* ctx_split = ctx_for_layer_split(i);
 
-                        auto& layer = model.layers[i];
+                        auto & layer = model.layers[i];
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd });
 
                         layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false);
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd });
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd });
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd });
 
@@ -6821,17 +6872,20 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
+        struct ggml_tensor * rope_factors = build_rope_factors();
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
                 // we rotate only the first n_rot dimensions
-                ggml_rope_custom_inplace(ctx0,
+                ggml_rope_ext_inplace(ctx0,
                         ggml_view_3d(ctx0, kv_self.k_l[il],
                             n_embd_head_k, n_head_kv, n_ctx,
                             ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                             ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                             0),
-                        lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
+
             cb(tmp, "K_shifted", il);
             ggml_build_forward_expand(gf, tmp);
         }
@@ -6934,6 +6988,17 @@ struct llm_build_context {
         return lctx.inp_pos;
     }
 
+    struct ggml_tensor * build_rope_factors() {
+        // choose long/short freq factors based on the context size
+        const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
+
+        if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
+            return model.rope_long;
+        }
+
+        return model.rope_short;
+    }
+
     struct ggml_tensor * build_inp_out_ids() {
         lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
         cb(lctx.inp_out_ids, "inp_out_ids", -1);
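
A quick worked example of build_rope_factors above (hypothetical values; the 4096-token trained context is the Phi-3-mini-128k figure assumed earlier): with cparams.n_ctx = 32768 and cparams.n_seq_max = 1, n_ctx_pre_seq = 32768 > 4096 and model.rope_long is returned; with cparams.n_ctx = 8192 and n_seq_max = 2, n_ctx_pre_seq = 4096, which is not greater than 4096, so model.rope_short is used, matching the commit note that the short factors apply whenever the per-sequence context does not exceed the trained context.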
@@ -7041,15 +7106,15 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -7171,13 +7236,13 @@ struct llm_build_context {
 
                 switch (model.type) {
                     case MODEL_7B:
-                        Qcur = ggml_rope_custom(
-                            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                        Qcur = ggml_rope_ext(
+                            ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                             n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
-                        Kcur = ggml_rope_custom(
-                            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                        Kcur = ggml_rope_ext(
+                            ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                             n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
@@ -7283,15 +7348,15 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -7404,14 +7469,14 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 // using mode = 2 for neox mode
-                Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
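The Falcon-style graphs keep the "mode = 2 for neox mode" convention through the migration; only the nullptr freq-factors argument is new. As a reading aid, a toy illustration of what the NeoX flag changes, assuming the usual contrast between interleaved-pair and split-half rotation layouts (the kernel code itself is outside this diff):

    // toy illustration, not ggml source: which dimensions are rotated together
    #include <stdio.h>

    int main(void) {
        const int n_rot = 8;
        for (int i = 0; i < n_rot/2; ++i) {
            printf("default pair: (%d,%d)   neox pair: (%d,%d)\n",
                   2*i, 2*i + 1,      // mode 0: adjacent, interleaved pairs
                   i,   i + n_rot/2); // mode 2: pairs split across the two halves
        }
        return 0;
    }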
@@ -7527,15 +7592,15 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -7679,15 +7744,15 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -8032,15 +8097,15 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -8472,15 +8537,15 @@ struct llm_build_context {
                 }
 
 
-                Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -8592,14 +8657,14 @@ struct llm_build_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
                 // using mode = 2 for neox mode
-                Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -8703,15 +8768,15 @@ struct llm_build_context {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -8817,15 +8882,15 @@ struct llm_build_context {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -8969,8 +9034,8 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -8980,8 +9045,8 @@ struct llm_build_context {
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
@@ -9052,6 +9117,9 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // rope freq factors for 128k context
+        struct ggml_tensor * rope_factors = build_rope_factors();
+
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
 
@@ -9088,8 +9156,8 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                Qcur = ggml_rope_custom(
-                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
@@ -9097,8 +9165,8 @@ struct llm_build_context {
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head)));
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
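This Phi-3 block is the one place where the rope_factors tensor from build_rope_factors() is actually passed instead of nullptr. A hedged sketch of how such a per-dimension factor is expected to enter the rotation angle, based on my reading of the LongRoPE-style scaling (the rope kernels themselves are not part of this excerpt):

    // sketch under stated assumptions; rope_theta() is illustrative, not a ggml API
    #include <math.h>

    // approximate angle for dimension pair i at position pos:
    //   theta_i = pos * freq_scale * freq_base^(-2*i/n_rot) / freq_factor_i
    // a factor of 1.0 reproduces the unscaled rope; factors > 1 stretch the
    // wavelength of that dimension for long-context use
    static float rope_theta(int pos, int i, int n_rot,
                            float freq_base, float freq_scale, float freq_factor) {
        const float inv_freq = powf(freq_base, -2.0f*i/n_rot);
        return pos * freq_scale * inv_freq / freq_factor;
    }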
@@ -9204,14 +9272,14 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head,    n_tokens), inp_pos, nullptr,
                         n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
                         n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
@@ -9412,15 +9480,15 @@ struct llm_build_context {
                 cb(tmpk, "tmpk", il);
                 cb(Vcur, "Vcur", il);
 
-                struct ggml_tensor * Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos,
+                struct ggml_tensor * Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                struct ggml_tensor * Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -9528,15 +9596,15 @@ struct llm_build_context {
                 //     cb(Vcur, "Vcur", il);
                 // }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -9645,15 +9713,15 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -9775,15 +9843,15 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -9895,8 +9963,8 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
-                Qcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
                         n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
@@ -9904,8 +9972,8 @@ struct llm_build_context {
                 Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
                 cb(Qcur, "Qcur_scaled", il);
 
-                Kcur = ggml_rope_custom(
-                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
                         n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
@@ -10015,15 +10083,15 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -10305,15 +10373,15 @@ struct llm_build_context {
                     cb(Kcur, "Kcur", il);
                 }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -10436,15 +10504,15 @@ struct llm_build_context {
                     cb(Vcur, "Vcur", il);
                 }
 
-                Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
-                Kcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -15417,6 +15485,7 @@ struct llama_context * llama_new_context_with_model(
         cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
     }
 
+    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
     cparams.causal_attn = hparams.causal_attn;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
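The added line folds the model's rope_attn_factor (loaded from the GGUF header) into the YaRN attention factor when the context is created, so the value ultimately passed as attn_factor to ggml_rope_ext already includes the model-specific scaling. A trivial sketch of the resulting value, grounded only in the multiplication shown above:

    // sketch: effective attention scale after this change (names follow the diff)
    static float effective_attn_factor(float yarn_attn_factor, float rope_attn_factor) {
        return yarn_attn_factor * rope_attn_factor; // forwarded as attn_factor to the rope calls
    }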
index c74e253db4b3bd428bef837e0554490574d684d3..1493a7ca7c405e4f125325dcb110d28df2f03015 100644 (file)
@@ -1763,14 +1763,14 @@ struct test_llama : public test_llm {
                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
 
-                Qcur = ggml_rope_custom(
-                    ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos,
+                Qcur = ggml_rope_ext(
+                    ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos, nullptr,
                     hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
-                Kcur = ggml_rope_custom(
-                    ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
+                Kcur = ggml_rope_ext(
+                    ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
                     hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -1889,13 +1889,13 @@ struct test_falcon : public test_llm {
                 Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
 
                 // using mode = 2 for neox mode
-                Qcur = ggml_rope_custom(
-                    ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                Qcur = ggml_rope_ext(
+                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );
 
-                Kcur = ggml_rope_custom(
-                    ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+                Kcur = ggml_rope_ext(
+                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
                     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                 );