]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : various fixes (#1450)
authorGeorgi Gerganov <redacted>
Sun, 14 May 2023 15:22:50 +0000 (18:22 +0300)
committerGitHub <redacted>
Sun, 14 May 2023 15:22:50 +0000 (18:22 +0300)
- `ggml_rope()`
- `ggml_diag_mask_inf()` multi-threaded
- compatibility with scratch buffers

ggml.c
ggml.h

diff --git a/ggml.c b/ggml.c
index 8ef1bb24498d189157179ebf49f214280d25c427..da3d914e4ef47837fb5455676e61e95a3aef2579 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -3923,6 +3923,20 @@ size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch)
     return result;
 }
 
+// IMPORTANT:
+// when creating "opt" tensors, always save and load the scratch buffer
+// this is an error prone process, but it is necessary to support inplace
+// operators when using scratch buffers
+// TODO: implement a better way
+void ggml_scratch_save(struct ggml_context * ctx) {
+    ctx->scratch_save = ctx->scratch;
+    ctx->scratch.data = NULL;
+}
+
+void ggml_scratch_load(struct ggml_context * ctx) {
+    ctx->scratch = ctx->scratch_save;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_tensor * ggml_new_tensor_impl(
@@ -4094,12 +4108,11 @@ struct ggml_tensor * ggml_new_tensor_4d(
 }
 
 struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
-    ctx->scratch_save = ctx->scratch;
-    ctx->scratch.data = NULL;
+    ggml_scratch_save(ctx);
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
 
-    ctx->scratch = ctx->scratch_save;
+    ggml_scratch_load(ctx);
 
     ggml_set_i32(result, value);
 
@@ -4107,12 +4120,11 @@ struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
 }
 
 struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
-    ctx->scratch_save = ctx->scratch;
-    ctx->scratch.data = NULL;
+    ggml_scratch_save(ctx);
 
     struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
 
-    ctx->scratch = ctx->scratch_save;
+    ggml_scratch_load(ctx);
 
     ggml_set_f32(result, value);
 
@@ -4541,13 +4553,19 @@ struct ggml_tensor * ggml_acc_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
+
     ((int32_t *) c->data)[0] = nb1;
     ((int32_t *) c->data)[1] = nb2;
     ((int32_t *) c->data)[2] = nb3;
     ((int32_t *) c->data)[3] = offset;
     ((int32_t *) c->data)[4] = inplace ? 1 : 0;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_ACC;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -5344,13 +5362,19 @@ struct ggml_tensor * ggml_set_impl(
 
     // make a view of the destination
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5);
+
     (( int32_t * ) c->data)[0] = nb1;
     (( int32_t * ) c->data)[1] = nb2;
     (( int32_t * ) c->data)[2] = nb3;
     (( int32_t * ) c->data)[3] = offset;
     (( int32_t * ) c->data)[4] = inplace ? 1 : 0;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_SET;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -5954,10 +5978,16 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = inplace ? 1 : 0;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_DIAG_MASK_INF;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -5995,11 +6025,17 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
     ggml_set_name(b, "n_past, inplace");
+
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = inplace ? 1 : 0;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_DIAG_MASK_ZERO;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -6074,11 +6110,16 @@ struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_ROPE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -6123,11 +6164,16 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    ggml_set_name(b, "n_past, n_dims, mode");
+
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
-    ggml_set_name(b, "n_past, n_dims, mode");
+
+    ggml_scratch_load(ctx);
 
     result->op   = GGML_OP_ROPE_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -6156,10 +6202,15 @@ struct ggml_tensor * ggml_alibi(
     //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_head;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_ALIBI;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -10450,19 +10501,33 @@ static void ggml_compute_forward_diag_mask_f32(
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 2);
 
-    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+    const int  n_past  =       ((int32_t *) src1->data)[0];
+    const bool inplace = (bool)((int32_t *) src1->data)[1];
+
+    if (params->type == GGML_TASK_INIT) {
+        // TODO: this hack is not good, need a better way to handle this
+        if (!inplace) {
+            // use the init task to copy src -> dst
+            struct ggml_compute_params params_cpy = *params;
+
+            params_cpy.ith  = 0;
+            params_cpy.nth  = 1;
+            params_cpy.type = GGML_TASK_COMPUTE;
+
+            ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
+        }
+
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int  n_past  =       ((int32_t *) src1->data)[0];
-    const bool inplace = (bool)((int32_t *) src1->data)[1];
-
-    if (!inplace) {
-        ggml_compute_forward_dup_same_cont(params, src0, dst);
-    }
+    assert(n_past >= 0);
 
     // TODO: handle transposed/permuted matrices
 
@@ -10550,7 +10615,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     for (int i1 = ir0; i1 < ir1; i1++) {
         float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *) dst->data + i1*dst->nb[1]);
+        float *dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -10626,6 +10691,8 @@ static void ggml_compute_forward_alibi_f32(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_head = ((int32_t *) src1->data)[1];
 
+    assert(n_past >= 0);
+
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     //const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -10687,6 +10754,8 @@ static void ggml_compute_forward_alibi_f16(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_head = ((int32_t *) src1->data)[1];
 
+    assert(n_past >= 0);
+
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     //const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -10780,28 +10849,34 @@ static void ggml_compute_forward_rope_f32(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    assert(n_past >= 0);
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 
-    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(dst);
 
-    GGML_ASSERT(n_dims <= nc);
+    GGML_ASSERT(n_dims <= ne0);
     GGML_ASSERT(n_dims % 2 == 0);
 
     // rows per thread
@@ -10820,37 +10895,50 @@ static void ggml_compute_forward_rope_f32(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = src[0];
                         const float x1 = src[1];
 
                         dst_data[0] = x0*cos_theta - x1*sin_theta;
                         dst_data[1] = x0*sin_theta + x1*cos_theta;
-                    } else {
-                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims/2];
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
 
-                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  float * dst_data  = (float *)((char *)  dst->data +  i3*nb3 + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims/2];
+
+                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        }
                     }
                 }
             }
@@ -10874,15 +10962,22 @@ static void ggml_compute_forward_rope_f16(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    assert(n_past >= 0);
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -10892,10 +10987,9 @@ static void ggml_compute_forward_rope_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
-    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(dst);
 
-    GGML_ASSERT(n_dims <= nc);
+    GGML_ASSERT(n_dims <= ne0);
     GGML_ASSERT(n_dims % 2 == 0);
 
     // rows per thread
@@ -10914,37 +11008,50 @@ static void ggml_compute_forward_rope_f16(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float x0 = GGML_FP16_TO_FP32(src[0]);
                         const float x1 = GGML_FP16_TO_FP32(src[1]);
 
                         dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                         dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    } else {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    // TODO: this is probably wrong, but I can't figure it out ..
+                    // ref:  https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float x0 = GGML_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
 
-                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                            dst_data[0]     = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
                     }
                 }
             }
@@ -10995,15 +11102,23 @@ static void ggml_compute_forward_rope_back_f32(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    assert(n_past >= 0);
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -11013,7 +11128,7 @@ static void ggml_compute_forward_rope_back_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
+    const int nr = ggml_nrows(dst);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -11031,37 +11146,48 @@ static void ggml_compute_forward_rope_back_f32(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const float * const dy  = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              float *       dx  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float dy0 = dy[0];
                         const float dy1 = dy[1];
 
                         dx[0] =   dy0*cos_theta + dy1*sin_theta;
                         dx[1] = - dy0*sin_theta + dy1*cos_theta;
-                    } else {
-                        const float * const dy  = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              float *       dx  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float dy0 = dy[0];
-                        const float dy1 = dy[n_dims/2];
+                            theta *= theta_scale;
 
-                        dx[0]        =   dy0*cos_theta + dy1*sin_theta;
-                        dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const float * const dy  = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  float *       dx  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                            const float dy0 = dy[0];
+                            const float dy1 = dy[n_dims/2];
+
+                            dx[0]        =   dy0*cos_theta + dy1*sin_theta;
+                            dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
+                        }
                     }
                 }
             }
@@ -11089,15 +11215,23 @@ static void ggml_compute_forward_rope_back_f16(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
 
-    //const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-    const int64_t ne3 = src0->ne[3];
+    assert(n_past >= 0);
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+    const size_t nb02 = src0->nb[2];
+    const size_t nb03 = src0->nb[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+    const size_t nb2 = dst->nb[2];
+    const size_t nb3 = dst->nb[3];
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    const int nb3 = src0->nb[3];
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
@@ -11107,7 +11241,7 @@ static void ggml_compute_forward_rope_back_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr = ggml_nrows(src0);
+    const int nr = ggml_nrows(dst);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -11125,37 +11259,48 @@ static void ggml_compute_forward_rope_back_f16(
 
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
+            const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
                 float theta = (float)p;
 
-                for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const float cos_theta = cosf(theta);
-                    const float sin_theta = sinf(theta);
+                if (!is_neox) {
+                    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+                        const float cos_theta = cosf(theta);
+                        const float sin_theta = sinf(theta);
 
-                    theta *= theta_scale;
+                        theta *= theta_scale;
 
-                    if (!is_neox) {
-                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         const float dy0 = GGML_FP16_TO_FP32(dy[0]);
                         const float dy1 = GGML_FP16_TO_FP32(dy[1]);
 
                         dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
                         dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
-                    } else {
-                        const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
-                              ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                    }
+                } else {
+                    for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+                        for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                            const float cos_theta = cosf(theta);
+                            const float sin_theta = sinf(theta);
 
-                        const float dy0 = GGML_FP16_TO_FP32(dy[0]);
-                        const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
+                            theta *= theta_scale;
+
+                            const int64_t i0 = ib*n_dims + ic/2;
+
+                            const ggml_fp16_t * const dy  = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                                  ggml_fp16_t *       dx  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-                        dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
-                        dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
+                            const float dy0 = GGML_FP16_TO_FP32(dy[0]);
+                            const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
+
+                            dx[0]        = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
+                            dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
+                        }
                     }
                 }
             }
diff --git a/ggml.h b/ggml.h
index 3b045ad4f1106c528ff71aa2a421d68b52bd5adc..255541d0257e3abea7461e6c8a84a4cb5f1fe583 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -340,7 +340,7 @@ extern "C" {
 
     // n-dimensional tensor
     struct ggml_tensor {
-        enum ggml_type type;
+        enum ggml_type    type;
         enum ggml_backend backend;
 
         int     n_dims;
@@ -372,7 +372,7 @@ extern "C" {
 
         char name[32];
 
-        char padding[9]; // TODO: remove and add padding to name?
+        char padding[16];
     };
 
     // computation graph