]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : change ggml_scale to take a float instead of tensor (#4573)
authorGeorgi Gerganov <redacted>
Thu, 21 Dec 2023 21:20:49 +0000 (23:20 +0200)
committerGitHub <redacted>
Thu, 21 Dec 2023 21:20:49 +0000 (23:20 +0200)
* ggml : change ggml_scale to take a float instead of tensor

* ggml : fix CPU implementation

* tests : fix test-grad0

ggml-ci

12 files changed:
examples/baby-llama/baby-llama.cpp
examples/export-lora/export-lora.cpp
examples/finetune/finetune.cpp
examples/llava/clip.cpp
examples/train-text-from-scratch/train-text-from-scratch.cpp
ggml-cuda.cu
ggml-metal.m
ggml.c
ggml.h
llama.cpp
tests/test-backend-ops.cpp
tests/test-grad0.cpp

index 2dc2988d34c81d75e9c56bac8424f945f6df9453..e7d2ad592e4c978bd164d49f664b09151a8bf227 100644 (file)
@@ -575,10 +575,7 @@ static struct ggml_tensor * forward(
 
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             // KQ_scaled shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
 
             // KQ_masked = mask_past(KQ_scaled)
             // KQ_masked shape [n_past + N, N, n_head, 1]
@@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch(
 
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             // KQ_scaled shape [n_past + N, N, n_head, n_batch]
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
             assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
 
             // KQ_masked = mask_past(KQ_scaled)
@@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora(
 
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             // KQ_scaled shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_scaled =
-                ggml_scale(ctx0,
-                        KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
 
             // KQ_masked = mask_past(KQ_scaled)
             // KQ_masked shape [n_past + N, N, n_head, 1]
index c8754ce70f37df5242ee3d017d53eaae5b6f24ad..58fbe204d3bbbde40a8dbe43810bb8e64502bcb2 100644 (file)
@@ -309,7 +309,7 @@ static struct ggml_cgraph * build_graph_lora(
 ) {
     struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
     if (scaling != 1.0f) {
-        ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
+        ab = ggml_scale(ctx, ab, scaling);
     }
     struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);
 
index 6a668d764905ca6dada39d30cc62ab449494fa58..7b1333a9de8887b5940e122b1d17d38f9b059036 100644 (file)
@@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
     float rope_freq_scale = 1.0f;
     GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
     GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
-    GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    GGUF_GET_KEY(ctx, rope_freq_scale,         gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
     if (rope_freq_scale != 1.0f) {
         hparams->rope_freq_scale = 1.0f / rope_freq_scale;
     }
@@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     const int n_rot       = hparams.n_embd_head();
     const int n_embd_head = hparams.n_embd_head();
     const int n_embd_gqa  = hparams.n_embd_gqa();
+
     const float rms_norm_eps    = hparams.f_norm_rms_eps;
     const float rope_freq_base  = hparams.rope_freq_base;
     const float rope_freq_scale = hparams.rope_freq_scale;
@@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
         checkpoints.push_back(t01);
     }
 
-    struct ggml_tensor * kv_scale = NULL;
-    if (!enable_flash_attn) {
-        kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
-    }
+    const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
 
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
@@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
     int n_leafs_before = gb->n_leafs;
     int n_nodes_before = gb->n_nodes;
-    struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
+
     // output tensors
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
     // input gradient
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
     GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
     ggml_allocr_alloc(alloc, t36->grad);
     // KQ_pos
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
 
     // make sure base model tensors data cannot be used in viewable operations
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
     }
 
     // allocating checkpoints in one block to reduce memory fragmentation
index 1124659685034112da88bb7562b8e3d4e2c3d36f..f06ec400d9a6e5bcffb20752df3e518f0fc26729 100644 (file)
@@ -330,12 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
                               ggml_repeat(ctx0, model.pre_ln_b, embeddings));
     }
 
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    ggml_allocr_alloc(ctx->alloc, KQ_scale);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
-    }
-
     // loop over layers
     for (int il = 0; il < n_layer - 1; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
@@ -356,7 +350,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
             struct ggml_tensor * Q =
                 ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
 
-            Q = ggml_scale_inplace(ctx0, Q, KQ_scale);
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
             Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
             Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
             Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
index f7ed63365211b4f8a3eb14811d07a98ca595996d..4a9a2340b4cd4d1866a8025dbb25eecc339ce5fa 100644 (file)
@@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs(
     checkpoints.push_back(t00);
     checkpoints.push_back(t01);
 
-    struct ggml_tensor * kv_scale = NULL;
-    if (!enable_flash_attn) {
-        kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
-    }
+    const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
 
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
@@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs(
         // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
         int n_leafs_before = gb->n_leafs;
         int n_nodes_before = gb->n_nodes;
-        struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
         // output tensors
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
         // input gradient
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
         // KQ_pos
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
         GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
 
         ggml_allocr_alloc(alloc, t36->grad);
index f5e060d32ccbd0f3ef81fd839299ca63e0fc9d85..ac91ee12e3428aea29022330cb7305c965fb493a 100644 (file)
@@ -7700,17 +7700,9 @@ inline void ggml_cuda_op_scale(
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    float scale;
-    // HACK: support for ggml backend interface
-    if (src1->backend == GGML_BACKEND_CPU) {
-        scale = ((float *) src1->data)[0];
-    } else {
-        // TODO: pass pointer to kernel instead of copying to host
-        CUDA_CHECK(cudaMemcpy(&scale, src1->data, sizeof(float), cudaMemcpyDeviceToHost));
-    }
+    const float scale = ((float *) dst->op_params)[0];
 
     scale_f32_cuda(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
     CUDA_CHECK(cudaGetLastError());
@@ -7757,8 +7749,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     const bool src1_on_device = use_src1 && src1->backend == GGML_BACKEND_GPU;
     const bool  dst_on_device =              dst->backend == GGML_BACKEND_GPU;
 
-    const bool src1_stays_on_host = use_src1 && dst->op == GGML_OP_SCALE;
-
     // dd = data device
     float * src0_ddf = nullptr;
     float * src1_ddf = nullptr;
@@ -7779,7 +7769,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
         CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
     }
 
-    if (use_src1 && !src1_stays_on_host) {
+    if (use_src1) {
         if (src1_on_device) {
             src1_ddf = (float *) src1_extra->data_device[g_main_device];
         } else {
index e60b93b36a7decb972f947f9885ffde187afa348..51a72ae335745008aeb1ead8da2221eacb352a3f 100644 (file)
@@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute(
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
 
-                            const float scale = *(const float *) src1->data;
+                            const float scale = *(const float *) dst->op_params;
 
                             int64_t n = ggml_nelements(dst);
 
@@ -1304,8 +1304,8 @@ void ggml_metal_graph_compute(
                                 [encoder setComputePipelineState:ctx->pipeline_scale];
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                            [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
                             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
diff --git a/ggml.c b/ggml.c
index 2361485148a15e77f2d559684e32c529d1fdbd61..f27920a2db0d5205544ef8757715af9592062a00 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4171,23 +4171,23 @@ struct ggml_tensor * ggml_out_prod(
 static struct ggml_tensor * ggml_scale_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
+        float                 s,
         bool inplace) {
-    GGML_ASSERT(ggml_is_scalar(b));
     GGML_ASSERT(ggml_is_padded_1d(a));
 
     bool is_node = false;
 
-    if (a->grad || b->grad) {
+    if (a->grad) {
         is_node = true;
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    ggml_set_op_params(result, &s, sizeof(s));
+
     result->op   = GGML_OP_SCALE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
-    result->src[1] = b;
 
     return result;
 }
@@ -4195,15 +4195,15 @@ static struct ggml_tensor * ggml_scale_impl(
 struct ggml_tensor * ggml_scale(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        struct ggml_tensor * b) {
-    return ggml_scale_impl(ctx, a, b, false);
+        float                s) {
+    return ggml_scale_impl(ctx, a, s, false);
 }
 
 struct ggml_tensor * ggml_scale_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        struct ggml_tensor * b) {
-    return ggml_scale_impl(ctx, a, b, true);
+        float                s) {
+    return ggml_scale_impl(ctx, a, s, true);
 }
 
 // ggml_set
@@ -10325,19 +10325,17 @@ static void ggml_compute_forward_out_prod(
 static void ggml_compute_forward_scale_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_is_scalar(src1));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
     // scale factor
-    const float v = *(float *) src1->data;
+    const float v = *(float *) dst->op_params;
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -10368,12 +10366,11 @@ static void ggml_compute_forward_scale_f32(
 static void ggml_compute_forward_scale(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_scale_f32(params, src0, src1, dst);
+                ggml_compute_forward_scale_f32(params, src0, dst);
             } break;
         default:
             {
@@ -14383,7 +14380,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SCALE:
             {
-                ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_scale(params, tensor->src[0], tensor);
             } break;
         case GGML_OP_SET:
             {
@@ -14839,7 +14836,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg
 
 static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) {
     if (ggml_hash_contains(zero_table, a)) {
-        struct ggml_tensor * a_zero = ggml_scale(ctx, a, ggml_new_f32(ctx, 0));
+        struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
         return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
     } else {
         return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
@@ -14975,7 +14972,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 src0->grad,
                                 ggml_scale(ctx,
                                     ggml_mul(ctx, src0, tensor->grad),
-                                    ggml_new_f32(ctx, 2.0f)),
+                                    2.0f),
                                 zero_table);
                 }
             } break;
@@ -14989,7 +14986,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     ggml_div(ctx,
                                         tensor->grad,
                                         tensor),
-                                    ggml_new_f32(ctx, 0.5f)),
+                                    0.5f),
                                 zero_table);
                 }
             } break;
@@ -15155,17 +15152,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
+                    const float s = ((float *) tensor->op_params)[0];
+
                     src0->grad =
                         ggml_add_or_set(ctx,
                             src0->grad,
-                            ggml_scale_impl(ctx, tensor->grad, src1, false),
-                            zero_table);
-                }
-                if (src1->grad) {
-                    src1->grad =
-                        ggml_add_or_set(ctx,
-                            src1->grad,
-                            ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
+                            ggml_scale_impl(ctx, tensor->grad, s, false),
                             zero_table);
                 }
             } break;
diff --git a/ggml.h b/ggml.h
index b17314897b35c6d3b51e2aea5f1b5d1aeb055a6e..75918502bed2d481f8bd02d3e245bcb68b463d96 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1094,13 +1094,13 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_scale(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            float                 s);
 
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_scale_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            float                 s);
 
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
index ba970ce8d180916575b9900657cafb537579101c..d6c192441fbf04b5fd7eef22db07feb5ce3324d3 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -4032,13 +4032,12 @@ static struct ggml_tensor * llm_build_kqv(
          struct ggml_tensor * wo,
          struct ggml_tensor * wo_b,
          struct ggml_tensor * q_cur,
-         struct ggml_tensor * kq_scale,
          struct ggml_tensor * kq_mask,
                     int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   n_kv,
                     float     max_alibi_bias,
-                    float     scale,
+                    float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
     const int64_t n_embd      = hparams.n_embd;
@@ -4086,7 +4085,7 @@ static struct ggml_tensor * llm_build_kqv(
         kq = ggml_soft_max(ctx, kq);
         cb(kq, "kq_soft_max", il);
     } else {
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, scale);
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
         cb(kq, "kq_soft_max_ext", il);
     }
 
@@ -4231,10 +4230,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -4295,7 +4290,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4416,10 +4411,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -4478,7 +4469,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4536,10 +4527,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -4602,7 +4589,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4659,10 +4646,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -4702,7 +4685,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4759,10 +4742,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -4911,7 +4890,7 @@ struct llm_build_context {
                 // TODO: not tested, could be broken
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Q, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -4965,10 +4944,6 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
         cb(inpL, "inp_embd", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -5002,7 +4977,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5056,10 +5031,6 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
         cb(inpL, "inp_embd", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -5099,7 +5070,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5150,10 +5121,6 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
         cb(inpL, "inp_embd", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -5193,7 +5160,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5253,10 +5220,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -5306,7 +5269,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5366,10 +5329,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -5423,7 +5382,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, NULL,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5482,14 +5441,6 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         cb(inp_pos, "inp_pos", -1);
 
-        // Q_scale
-        struct ggml_tensor * Q_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(Q_scale, "Q_scale", -1);
-
-        // KQ_scale
-        struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-        cb(KQ_scale, "KQ_scale", -1);
-
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
         cb(KQ_mask, "KQ_mask", -1);
@@ -5531,7 +5482,9 @@ struct llm_build_context {
                 );
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, Q_scale);
+                // with phi2, we scale the Q to avoid precision issues
+                // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head)));
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_custom(
@@ -5544,7 +5497,7 @@ struct llm_build_context {
 
                 cur = llm_build_kqv(ctx0, model, hparams, kv_self,
                         model.layers[il].wo, model.layers[il].bo,
-                        Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
+                        Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5681,8 +5634,6 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "pos_embd",                   OFFLOAD_FUNC_NR  },
 
     { "inp_pos",                    OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
-    { "Q_scale",                    OFFLOAD_FUNC_NOP },
-    { "KQ_scale",                   OFFLOAD_FUNC_NOP },
     { "KQ_mask",                    OFFLOAD_FUNC_FRC },
     { "K_shift",                    OFFLOAD_FUNC_FRC },
 
@@ -5784,8 +5735,6 @@ static struct ggml_cgraph * llama_build_graph(
     bool alloc_inp_tokens   = false;
     bool alloc_inp_embd     = false;
     bool alloc_inp_pos      = false;
-    bool alloc_inp_Q_scale  = false;
-    bool alloc_inp_KQ_scale = false;
     bool alloc_inp_KQ_mask  = false;
     bool alloc_inp_K_shift  = false;
 
@@ -5849,37 +5798,6 @@ static struct ggml_cgraph * llama_build_graph(
             alloc_inp_pos = true;
         }
 
-        if (!alloc_inp_Q_scale && strcmp(name, "Q_scale") == 0) {
-            ggml_allocr_alloc(lctx.alloc, cur);
-
-            if (!ggml_allocr_is_measure(lctx.alloc)) {
-                const int64_t n_embd_head = model.hparams.n_embd_head();
-                float f = 1.0f/sqrtf(float(n_embd_head));
-                ggml_backend_tensor_set(cur, &f, 0, sizeof(f));
-            }
-
-            alloc_inp_Q_scale = true;
-        }
-
-        if (!alloc_inp_KQ_scale && strcmp(name, "KQ_scale") == 0) {
-            ggml_allocr_alloc(lctx.alloc, cur);
-
-            if (!ggml_allocr_is_measure(lctx.alloc)) {
-                const int64_t n_embd_head = model.hparams.n_embd_head();
-                float f;
-                if (model.arch == LLM_ARCH_PHI2) {
-                    // with phi2, we scale the Q to avoid precision issues
-                    // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66
-                    f = 1.0f;
-                } else {
-                    f = 1.0f/sqrtf(float(n_embd_head));
-                }
-                ggml_backend_tensor_set(cur, &f, 0, sizeof(f));
-            }
-
-            alloc_inp_KQ_scale = true;
-        }
-
         if (!alloc_inp_KQ_mask && strcmp(name, "KQ_mask") == 0) {
             ggml_allocr_alloc(lctx.alloc, cur);
 
@@ -9054,10 +8972,7 @@ static int llama_apply_lora_from_file_internal(
             ggml_set_name(BA, "BA");
 
             if (scaling != 1.0f) {
-                ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx.get(), scaling);
-                ggml_set_name(scale_tensor, "scale_tensor");
-
-                BA = ggml_scale_inplace(lora_ctx.get(), BA, scale_tensor);
+                BA = ggml_scale_inplace(lora_ctx.get(), BA, scaling);
                 offload_func(BA);
                 ggml_set_name(BA, "BA_scaled");
             }
index f04b9438a6194d2b934bf380f9ee188d102d39bf..f3df8a8c62a9a83961ed94234385d39de664b29a 100644 (file)
@@ -766,18 +766,19 @@ struct test_bin_bcast : public test_case {
 struct test_scale : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
+    float scale;
 
     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        return VARS_TO_STR3(type, ne, scale);
     }
 
     test_scale(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            float scale = 2.0f)
+        : type(type), ne(ne), scale(scale) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * scale = ggml_new_tensor_1d(ctx, type, 1);
         ggml_tensor * out = ggml_scale(ctx, a, scale);
         return out;
     }
index 81c20a89cb586b14cf5b264170f96e0d84f39d56..14914def565d910a2b100cdf747072a39d9047d8 100644 (file)
@@ -881,19 +881,19 @@ int main(int argc, const char ** argv) {
         // scale
         {
             srand(seed);
-            const int nargs = 2;
+            const int nargs = 1;
 
             int64_t ne2[4];
             ne2[0] = 1;
 
             for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
                 x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
 
+                const float s = -1.0f + 2.0f*frand();
+
                 ggml_set_param(ctx0, x[0]);
-                ggml_set_param(ctx0, x[1]);
 
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], x[1]));
+                struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s));
 
                 check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
             }
@@ -1395,7 +1395,7 @@ int main(int argc, const char ** argv) {
                                                 ggml_add1(ctx0,
                                                     ggml_scale(ctx0,
                                                         ggml_soft_max(ctx0, x[0]),
-                                                        ggml_new_f32(ctx0, 1.0f - eps)),
+                                                        1.0f - eps),
                                                     ggml_new_f32(ctx0, eps))));
 
                 check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);