const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
- //const int ne2 = src0->ne[2]; // n_head -> this is k
+ const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0);
const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
- assert(nb0 == sizeof(float));
- assert(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(ne1 + n_past == ne0);
+ GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}
- pdst[0] = (i-ne0+1) * m_k + src[0];
+ pdst[0] = i * m_k + src[0];
}
}
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
- //const int ne2 = src0->ne[2]; // n_head -> this is k
+ const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0);
const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
- assert(nb0 == sizeof(ggml_fp16_t));
- assert(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
}
// we return F32
- pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+ pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
}
}
}