#define GGML_GELU_FP16
#define GGML_SOFT_MAX_UNROLL 4
-#define GGML_VEC_DOT_UNROLL 4
+#define GGML_VEC_DOT_UNROLL 2
#ifdef GGML_USE_ACCELERATE
// uncomment to use vDSP for soft max computation
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
- const ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL] = { xv };
+ ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
- for (int i = 1; i < GGML_VEC_DOT_UNROLL; ++i) {
+ for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
}
S[i] = -INFINITY;
}
- // looks like unrolling here does not help
-#if 1
- for (int ic = 0; ic < nek1; ++ic) {
- // k indices
- const int ik3 = iq3;
- const int ik2 = iq2;
- const int ik1 = ic;
-
- // S indices
- const int i1 = ik1;
-
- ggml_vec_dot_f16(neq0,
- S + i1,
- (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
- }
-#else
- GGML_ASSERT(nek1 % GGML_VEC_DOT_UNROLL == 0);
-
- for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
- // k indices
- const int ik3 = iq3;
- const int ik2 = iq2;
- const int ik1 = ic;
+ if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { // row-at-a-time path: taken when the unroll factor exceeds 2 or nek1 is not a multiple of it
+ for (int ic = 0; ic < nek1; ++ic) {
+ // k indices: dims 3 and 2 follow the query (iq3, iq2); dim 1 is the loop counter
+ const int ik3 = iq3;
+ const int ik2 = iq2;
+ const int ik1 = ic;
- // S indices
- const int i1 = ik1;
+ // S indices: the destination passed for k row ik1 is S + i1
+ const int i1 = ik1;
- ggml_vec_dot_f16_unroll(neq0, nbk1,
- S + i1,
- ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
- (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ ggml_vec_dot_f16(neq0,
+ S + i1,
+ (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ }
+ } else { // unrolled path: unroll factor is at most 2 and nek1 is a multiple of it
+ for (int ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { // ic advances by GGML_VEC_DOT_UNROLL k rows per call
+ // k indices: dims 3 and 2 follow the query (iq3, iq2); dim 1 is the first row of this group
+ const int ik3 = iq3;
+ const int ik2 = iq2;
+ const int ik1 = ic;
+
+ // S indices: the destination passed for the group starting at k row ik1 is S + i1
+ const int i1 = ik1;
+
+ ggml_vec_dot_f16_unroll(neq0, nbk1, // nbk1 is passed as xs, the byte offset between consecutive k rows
+ S + i1,
+ ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+ (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+ }
}
-#endif
// scale
ggml_vec_scale_f32(nek1, S, scale);
S16[i] = GGML_FP32_TO_FP16(S[i]);
}
- GGML_ASSERT(nev1 % GGML_VEC_DOT_UNROLL == 0);
+ if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { // row-at-a-time path: taken when the unroll factor is 1 or nev1 is not a multiple of it. NOTE(review): the nek1 loop guard tests `> 2` instead of `== 1` -- confirm the asymmetry is intended
+ for (int ic = 0; ic < nev1; ++ic) {
+ // dst indices: dims 1..3 follow the query (iq1, iq2, iq3); ic selects dim 0 of dst in the call below
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
- for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ ggml_vec_dot_f16(nek1,
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ S16);
+ }
+ } else { // unrolled path: unroll factor is not 1 and nev1 is a multiple of it
+ for (int ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { // ic advances by GGML_VEC_DOT_UNROLL v rows per call
+ // dst indices: dims 1..3 follow the query (iq1, iq2, iq3); ic selects dim 0 of dst in the call below
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
- ggml_vec_dot_f16_unroll(nek1, nbv1,
- (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
- S16);
+ ggml_vec_dot_f16_unroll(nek1, nbv1, // nbv1 is passed as xs, the byte offset between consecutive v rows
+ (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
+ S16);
+ }
}
}
}