]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : sync ggml (ggml_alibi)
authorGeorgi Gerganov <redacted>
Fri, 28 Apr 2023 17:37:43 +0000 (20:37 +0300)
committerGeorgi Gerganov <redacted>
Fri, 28 Apr 2023 17:51:05 +0000 (20:51 +0300)
ggml.c
ggml.h

diff --git a/ggml.c b/ggml.c
index 44293dac92668a2825125fa9f81e7f0f37ae3442..53796bd97d3abd68c36e7a38d23cce4ec5ac0d5d 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4034,7 +4034,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4082,7 +4082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
+static_assert(GGML_OP_COUNT == 39, "GGML_OP_COUNT != 39");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -6080,6 +6080,37 @@ struct ggml_tensor * ggml_rope(
     return result;
 }
 
+// ggml_alibi
+
+struct ggml_tensor * ggml_alibi(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_head) {
+    GGML_ASSERT(n_past >= 0);
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2);
+    ((int32_t *) b->data)[0] = n_past;
+    ((int32_t *) b->data)[1] = n_head;
+
+    result->op   = GGML_OP_ALIBI;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_conv_1d_1s
 
 struct ggml_tensor * ggml_conv_1d_1s(
@@ -9300,6 +9331,162 @@ static void ggml_compute_forward_soft_max(
     }
 }
 
+// ggml_compute_forward_alibi
+
+static void ggml_compute_forward_alibi_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n  = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(float));
+    assert(ne1+n_past == ne0);
+
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
+
+                // TODO: k*nb2 or k*nb3
+
+                float m_k;
+
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
+
+                pdst[0] = (j+1) * m_k + src[0];
+            }
+        }
+    }
+}
+
+
+static void ggml_compute_forward_alibi_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n_past = ((int32_t *) src1->data)[0];
+    const int n_head = ((int32_t *) src1->data)[1];
+
+    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int ne1 = src0->ne[1]; // seq_len_without_past
+    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    //const int ne3 = src0->ne[3]; // 1 -> bsz
+
+    const int n  = ggml_nrows(src0);
+    const int ne2_ne3 = n/ne1; // ne2*ne3
+
+    const int nb0 = src0->nb[0];
+    const int nb1 = src0->nb[1];
+    const int nb2 = src0->nb[2];
+    //const int nb3 = src0->nb[3];
+
+    assert(nb0 == sizeof(ggml_fp16_t));
+    assert(ne1+n_past == ne0);
+
+    // add alibi to src0 (KQ_scaled)
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+
+    for (int i = 0; i < ne0; i++) {
+        for (int j = 0; j < ne1; j++) {
+            for (int k = 0; k < ne2_ne3; k++) {
+                ggml_fp16_t * const src  = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                      float *      pdst  =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
+
+                // TODO: k*nb2 or k*nb3
+
+                float m_k;
+
+                if (k < n_heads_log2_floor) {
+                    m_k = powf(m0, k + 1);
+                } else {
+                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+                }
+
+                // we return F32
+                pdst[0] = (j+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_alibi(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_alibi_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_alibi_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_rope
 
 static void ggml_compute_forward_rope_f32(
@@ -10938,6 +11125,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_ALIBI:
+            {
+                ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_CONV_1D_1S:
             {
                 ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -11140,6 +11331,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_ALIBI:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -11673,6 +11868,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = n_threads;
                     } break;
+                case GGML_OP_ALIBI:
+                    {
+                        node->n_tasks = 1; //TODO
+                    } break;
                 case GGML_OP_CONV_1D_1S:
                 case GGML_OP_CONV_1D_2S:
                     {
diff --git a/ggml.h b/ggml.h
index 1bbe2db93f5d1e6f79e29e03d8d293a1bdada76f..540901f15a7f1e58e7a830cb5526efb6e1d124b3 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -269,6 +269,7 @@ extern "C" {
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_SOFT_MAX,
         GGML_OP_ROPE,
+        GGML_OP_ALIBI,
         GGML_OP_CONV_1D_1S,
         GGML_OP_CONV_1D_2S,
 
@@ -662,6 +663,14 @@ extern "C" {
             int                   n_dims,
             int                   mode);
 
+    // alibi position embedding
+    // in-place, returns view(a)
+    struct ggml_tensor * ggml_alibi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_head);
+
     // padding = 1
     // TODO: we don't support extra parameters for now
     //       that's why we are hard-coding the stride, padding, and dilation