From: Daulet Zhanguzin Date: Tue, 11 Jul 2023 17:26:22 +0000 (-0700) Subject: ggml : fix Alibi implementation (#351) X-Git-Tag: upstream/0.0.1642~1338 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=2c1536bdf27616b9cf966ec85e7f4df615ef3efe;p=pkg%2Fggml%2Fsources%2Fggml ggml : fix Alibi implementation (#351) * correct Alibi implementation * update f16 too --- diff --git a/src/ggml.c b/src/ggml.c index 8dc30a37..487ad948 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -11717,7 +11717,7 @@ static void ggml_compute_forward_alibi_f32( const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne1 = src0->ne[1]; // seq_len_without_past - //const int ne2 = src0->ne[2]; // n_head -> this is k + const int ne2 = src0->ne[2]; // n_head -> this is k //const int ne3 = src0->ne[3]; // 1 -> bsz const int n = ggml_nrows(src0); @@ -11728,8 +11728,9 @@ static void ggml_compute_forward_alibi_f32( const int nb2 = src0->nb[2]; //const int nb3 = src0->nb[3]; - assert(nb0 == sizeof(float)); - assert(ne1 + n_past == ne0); (void) n_past; + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(ne1 + n_past == ne0); + GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); @@ -11753,7 +11754,7 @@ static void ggml_compute_forward_alibi_f32( m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); } - pdst[0] = (i-ne0+1) * m_k + src[0]; + pdst[0] = i * m_k + src[0]; } } @@ -11782,7 +11783,7 @@ static void ggml_compute_forward_alibi_f16( const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne1 = src0->ne[1]; // seq_len_without_past - //const int ne2 = src0->ne[2]; // n_head -> this is k + const int ne2 = src0->ne[2]; // n_head -> this is k //const int ne3 = src0->ne[3]; // 1 -> bsz const int n = ggml_nrows(src0); @@ -11793,8 +11794,9 @@ static void ggml_compute_forward_alibi_f16( const int nb2 = src0->nb[2]; //const int nb3 = src0->nb[3]; - assert(nb0 == sizeof(ggml_fp16_t)); - assert(ne1 + n_past == ne0); (void) n_past; + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; + GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); @@ -11819,7 +11821,7 @@ static void ggml_compute_forward_alibi_f16( } // we return F32 - pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]); + pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); } } }