]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : fix rope + add tests (llama/7452)
authorGeorgi Gerganov <redacted>
Wed, 22 May 2024 08:01:35 +0000 (11:01 +0300)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
* cuda : fix rope pos data

ggml-ci

* ggml : drop mode & 1 == 1 support for ggml_rope

ggml-ci

* ggml : support freq_factors for f16 rope (CPU)

ggml-ci

* tests : add rope tests using frequency factors

ggml-ci

include/ggml/ggml.h
src/ggml-cuda/rope.cu
src/ggml.c
tests/test-backend-ops.cpp

index 35ac9110ceb17034602c2c79c8404c65ed36782c..08835042c0bfdbaa0f1f003760d80612c4a4fcfd 100644 (file)
@@ -1460,7 +1460,7 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // rotary position embedding
-    // if mode & 1 == 1, skip n_past elements (DEPRECATED)
+    // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
     // if mode & 2 == 1, GPT-NeoX style
     // if mode & 4 == 1, ChatGLM style
     //
index 4a558f4b3757e49dd3b2849a3fd4be6d8300ac55..50f2cf415ef6098be04202e1b0a30c37d8e4494d 100644 (file)
@@ -283,9 +283,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
-    if (is_neox) {
-        pos = (const int32_t *) src1_d;
+    pos = (const int32_t *) src1_d;
 
+    if (is_neox) {
         if (src2 != nullptr) {
             freq_factors = (const float *) src2->data;
         }
index 37b16b7a9ce7f2bd7a0dd323d41c5c5b6a787d6e..d316e3d316806516ba44ec86e0a2d0162e08da4e 100644 (file)
@@ -6245,6 +6245,8 @@ static struct ggml_tensor * ggml_rope_impl(
         float                 xpos_base,
         bool                  xpos_down,
         bool                  inplace) {
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
@@ -14413,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
             freq_factors = (const float *) src2->data;
         }
     } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
     }
 
     // backward process uses inverse rotation by cos and sin.
@@ -14529,6 +14531,7 @@ static void ggml_compute_forward_rope_f32(
     }
 }
 
+// TODO: deduplicate f16/f32 code
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst,
@@ -14536,6 +14539,7 @@ static void ggml_compute_forward_rope_f16(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14588,6 +14592,17 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14660,10 +14675,11 @@ static void ggml_compute_forward_rope_f16(
 
                             // simplified from `(ib * n_dims + ic) * inv_ndims`
                             float cur_rot = inv_ndims * ic - ib;
+                            float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
                             float cos_theta, sin_theta;
                             rope_yarn(
-                                theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
                             sin_theta *= sin_sign;
index 1493a7ca7c405e4f125325dcb110d28df2f03015..de74585da29dd40905f4b2cc9a0aec176e5c5358 100644 (file)
@@ -1142,20 +1142,22 @@ struct test_rope : public test_case {
     int n_dims;
     int mode;
     int n_ctx;
+    bool ff;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
+        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
-        ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
+        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
         return out;
     }
 
@@ -1169,7 +1171,12 @@ struct test_rope : public test_case {
                 }
                 ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
             } else {
-                init_tensor_uniform(t);
+                if (t->ne[0] == n_dims/2) {
+                    // frequency factors in the range [0.9f, 1.1f]
+                    init_tensor_uniform(t, 0.9f, 1.1f);
+                } else {
+                    init_tensor_uniform(t);
+                }
             }
         }
     }
@@ -2188,16 +2195,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512)); // llama 65B
-        test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512)); // neox (stablelm)
-        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512)); // neox (phi-2)
+        // TODO: ff not supported yet for !neox
+        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, false)); // llama 7B
+        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, false)); // llama 13B
+        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, false)); // llama 30B
+        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, false)); // llama 65B
+
+        for (bool ff : {false, true}) { // freq_factors
+            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+        }
     }
 
     test_cases.emplace_back(new test_concat(GGML_TYPE_F32));