]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix YARN + add tests + add asserts (llama/7617)
authorGeorgi Gerganov <redacted>
Wed, 29 May 2024 17:17:31 +0000 (20:17 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
* tests : add rope tests

ggml-ci

* ggml : fixes (hopefully)

ggml-ci

* tests : add non-cont tests

ggml-ci

* cuda : add asserts for rope/norm + fix DS2

ggml-ci

* ggml : assert contiguousness

* tests : reduce RoPE tests

ggml-ci

include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-cuda/norm.cu
src/ggml-cuda/rope.cu
src/ggml-kompute.cpp
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-sycl.cpp
src/ggml.c
tests/test-backend-ops.cpp

index f9deac7e8054e628fd5a46448a5d1c53ee1b1533..f38699698b1e9f6e9f8870e3bb968342e3917e81 100644 (file)
@@ -756,7 +756,6 @@ extern "C" {
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
     GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_permuted  (const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_empty     (const struct ggml_tensor * tensor);
     GGML_API           bool ggml_is_scalar    (const struct ggml_tensor * tensor);
@@ -765,6 +764,11 @@ extern "C" {
     GGML_API           bool ggml_is_3d        (const struct ggml_tensor * tensor);
     GGML_API           int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars
 
+    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
index d0a754ee11b67754e6f0f9fcf8dbfe2ad3a32fa9..1172f7b2f8cafda87b915f875abe2187b3982015 100644 (file)
@@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
@@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
index 86f77453449940d0757e10cd46ee7ca2f9a4eba5..30866d51274fb2c6a7f0db4301d216dff2320464 100644 (file)
@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
index 50f2cf415ef6098be04202e1b0a30c37d8e4494d..0dd07977ebab1e263db0419c67da9e7ae0c79cd3 100644 (file)
@@ -61,7 +61,7 @@ static __global__ void rope(
 template<typename T, bool has_pos, bool has_freq_facs>
 static __global__ void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
 ) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
@@ -85,15 +85,13 @@ static __global__ void rope_neox(
     const int i  = row*ncols + ib*n_dims + ic/2;
     const int i2 = row/p_delta_rows;
 
-    float cur_rot = inv_ndims * ic - ib;
-
     const int p = has_pos ? pos[i2] : 0;
     const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
 
-    const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
+    const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
-    rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
     const float x0 = x[i + 0];
     const float x1 = x[i + n_dims/2];
@@ -174,30 +172,29 @@ static void rope_neox_cuda(
     const dim3 block_nums(nrows, num_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
         if (freq_factors == nullptr) {
             rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         } else {
             rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         }
     } else {
         if (freq_factors == nullptr) {
             rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         } else {
             rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
                 x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, inv_ndims, freq_factors
+                theta_scale, freq_factors
                 );
         }
     }
@@ -254,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
index 6c6058b2a95b12f3eed9237bdf698e32460eb7fd..ed59d2be64091fcef6ee46034d39d98c5ea3e45f 100644 (file)
@@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     {
                         GGML_ASSERT(ne00 == ne10);
 
-                        // TODO: assert that dim2 and dim3 are contiguous
+                        ggml_is_contiguous_2(src0);
+                        ggml_is_contiguous_2(src1);
+
                         GGML_ASSERT(ne12 % ne02 == 0);
                         GGML_ASSERT(ne13 % ne03 == 0);
 
index 4ba498e87f9d04064766502dd61ce9a2c5f17eef..a7e13bdcfe07f1c363cf4c1049efdfe9bf8ee48f 100644 (file)
@@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
                     {
                         GGML_ASSERT(ne00 == ne10);
 
-                        // TODO: assert that dim2 and dim3 are contiguous
+                        ggml_is_contiguous_2(src0);
+                        ggml_is_contiguous_2(src1);
+
                         GGML_ASSERT(ne12 % ne02 == 0);
                         GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -2187,6 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_RMS_NORM:
                     {
                         GGML_ASSERT(ne00 % 4 == 0);
+                        GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                         float eps;
                         memcpy(&eps, dst->op_params, sizeof(float));
@@ -2214,6 +2217,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_GROUP_NORM:
                     {
                         GGML_ASSERT(ne00 % 4 == 0);
+                        GGML_ASSERT(ggml_is_contiguous(src0));
 
                         //float eps;
                         //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2247,6 +2251,8 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_NORM:
                     {
+                        GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                         float eps;
                         memcpy(&eps, dst->op_params, sizeof(float));
 
index b16f2b7e0c74f13456632c906f52d9e62a0d7fb4..0cb85e1a5bad4daeabe6e1de00656402556e9f8b 100644 (file)
@@ -1767,13 +1767,13 @@ kernel void kernel_rope(
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = (float)p;
+    const float theta_base = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             float cos_theta, sin_theta;
             rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
@@ -1789,18 +1789,14 @@ kernel void kernel_rope(
     } else {
         for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;
 
-                // simplified from `(ib * n_dims + ic) * inv_ndims`
-                const float cur_rot = inv_ndims*ic - ib;
-                const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
 
-                const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
 
                 float cos_theta, sin_theta;
-                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-                const int64_t i0 = ib*n_dims + ic/2;
+                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
index a73448136a4d8d90b253b283e487d7f6e754281b..5cd97e4ff98df55983c43be14025dea07c609ea8 100644 (file)
@@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
index e6e2397b7848b68ad5ca658d7db822c4387bdfd0..b2b725f65452c8be951817efdd3b7aa167b50015 100644 (file)
@@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
@@ -14358,7 +14370,7 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
@@ -14407,7 +14419,7 @@ static void ggml_compute_forward_rope_f32(
                         const float cos_block_theta = cosf(block_theta);
                         const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                        theta_base *= theta_scale;
+                        theta_base  *= theta_scale;
                         block_theta *= theta_scale;
 
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14442,29 +14454,22 @@ static void ggml_compute_forward_rope_f32(
                         dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                     }
                 } else {
-                    // TODO: this might be wrong for ne0 != n_dims - need double check
-                    //       it seems we have to rope just the first n_dims elements and do nothing with the rest
-                    // ref:  https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                    theta_base *= freq_scale;
+                    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                     for (int64_t ic = 0; ic < ne0; ic += 2) {
                         if (ic < n_dims) {
-                            const int64_t ib = 0;
+                            const int64_t i0 = ic/2;
 
-                            // simplified from `(ib * n_dims + ic) * inv_ndims`
-                            float cur_rot = inv_ndims * ic - ib;
-                            float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                            const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                             float cos_theta, sin_theta;
                             rope_yarn(
-                                theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
-                            sin_theta *= sin_sign;
 
+                            sin_theta  *= sin_sign;
                             theta_base *= theta_scale;
 
-                            const int64_t i0 = ib*n_dims + ic/2;
-
                             const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                                   float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
@@ -14543,7 +14548,7 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
@@ -14592,7 +14597,7 @@ static void ggml_compute_forward_rope_f16(
                         const float cos_block_theta = cosf(block_theta);
                         const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                        theta_base *= theta_scale;
+                        theta_base  *= theta_scale;
                         block_theta *= theta_scale;
 
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14623,29 +14628,22 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                     }
                 } else {
-                    // TODO: this might be wrong for ne0 != n_dims - need double check
-                    //       it seems we have to rope just the first n_dims elements and do nothing with the rest
-                    // ref:  https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                    theta_base *= freq_scale;
+                    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                     for (int64_t ic = 0; ic < ne0; ic += 2) {
                         if (ic < n_dims) {
-                            const int64_t ib = 0;
+                            const int64_t i0 = ic/2;
 
-                            // simplified from `(ib * n_dims + ic) * inv_ndims`
-                            float cur_rot = inv_ndims * ic - ib;
-                            float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                            const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                             float cos_theta, sin_theta;
                             rope_yarn(
-                                theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                                theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                                 &cos_theta, &sin_theta
                             );
-                            sin_theta *= sin_sign;
 
+                            sin_theta  *= sin_sign;
                             theta_base *= theta_scale;
 
-                            const int64_t i0 = ib*n_dims + ic/2;
-
                             const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                                   ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
index 5cde21c660514469f81113ee8145ec5243c71529..72edc64a72c2caba90a0606590865beb5610896e 100644 (file)
@@ -1138,26 +1138,37 @@ struct test_soft_max : public test_case {
 // GGML_OP_ROPE
 struct test_rope : public test_case {
     const ggml_type type;
-    const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> ne_a;
     int n_dims;
     int mode;
     int n_ctx;
+    float fs; // freq_scale
+    float ef; // ext_factor
+    float af; // attn_factor
     bool ff;
+    int v; // view (1 : non-contiguous a)
 
     std::string vars() override {
-        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
+        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
+            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        }
+        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
         ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         return out;
     }
 
@@ -1165,11 +1176,11 @@ struct test_rope : public test_case {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 // pos
-                std::vector<int> data(ne[2]);
-                for (int i = 0; i < ne[2]; i++) {
+                std::vector<int> data(ne_a[2]);
+                for (int i = 0; i < ne_a[2]; i++) {
                     data[i] = rand() % n_ctx;
                 }
-                ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
+                ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
             } else {
                 if (t->ne[0] == n_dims/2) {
                     // frequency factors in the range [0.9f, 1.1f]
@@ -2213,20 +2224,38 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
 
-    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        // TODO: ff not supported yet for !neox
-        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, false)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, false)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, false)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, false)); // llama 65B
-
-        for (bool ff : {false, true}) { // freq_factors
-            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+    {
+        bool all = true;
+
+        for (float v : { 0, 1 }) {
+            for (float fs : { 1.0f, 1.4245f }) {
+                for (float ef : { 0.0f, 0.7465f }) {
+                    for (float af : { 1.0f, 1.4245f }) {
+                        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                            // TODO: ff not supported yet for !neox
+                            test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
+                                test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
+                                test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
+                            }
+
+                            for (bool ff : {false, true}) { // freq_factors
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
+                                }
+
+                                test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+                            }
+                        }
+                        all = false;
+                    }
+                }
+            }
         }
     }