]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix Alibi implementation (#351)
authorDaulet Zhanguzin <redacted>
Tue, 11 Jul 2023 17:26:22 +0000 (10:26 -0700)
committerGitHub <redacted>
Tue, 11 Jul 2023 17:26:22 +0000 (20:26 +0300)
* correct Alibi implementation

* update f16 too

src/ggml.c

index 8dc30a372e1ae45a5efd80d84b51015100eba198..487ad9483d5862da856ff33c35392b8d5827da09 100644 (file)
@@ -11717,7 +11717,7 @@ static void ggml_compute_forward_alibi_f32(
 
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    const int ne2 = src0->ne[2]; // n_head -> this is k
     //const int ne3 = src0->ne[3]; // 1 -> bsz
 
     const int n  = ggml_nrows(src0);
@@ -11728,8 +11728,9 @@ static void ggml_compute_forward_alibi_f32(
     const int nb2 = src0->nb[2];
     //const int nb3 = src0->nb[3];
 
-    assert(nb0 == sizeof(float));
-    assert(ne1 + n_past == ne0); (void) n_past;
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(ne1 + n_past == ne0);
+    GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11753,7 +11754,7 @@ static void ggml_compute_forward_alibi_f32(
                     m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
                 }
 
-                pdst[0] = (i-ne0+1) * m_k + src[0];
+                pdst[0] = i * m_k + src[0];
 
             }
         }
@@ -11782,7 +11783,7 @@ static void ggml_compute_forward_alibi_f16(
 
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    const int ne2 = src0->ne[2]; // n_head -> this is k
     //const int ne3 = src0->ne[3]; // 1 -> bsz
 
     const int n  = ggml_nrows(src0);
@@ -11793,8 +11794,9 @@ static void ggml_compute_forward_alibi_f16(
     const int nb2 = src0->nb[2];
     //const int nb3 = src0->nb[3];
 
-    assert(nb0 == sizeof(ggml_fp16_t));
-    assert(ne1 + n_past == ne0); (void) n_past;
+    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11819,7 +11821,7 @@ static void ggml_compute_forward_alibi_f16(
                 }
 
                 // we return F32
-                pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+                pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
             }
         }
     }