sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}
- float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+
+ float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
- vst1_f32(s, vget_low_f32(sumv2));
+ vst1_f32(s, vget_low_f32 (sumv2));
vst1_f32(s + bs, vget_high_f32(sumv2));
+
return;
}
#endif
// This is the size of the rest of the dimensions of the result
const int64_t nr1 = ne1 * ne2 * ne3;
- // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
- int64_t num_rows_per_vec_dot = vec_dot_num_rows;
- // TODO: currently the mmla kernels support only even numbered rows/cols.
- // this check can be removed once they are extended to support odd numbered rows/cols too
- if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
- num_rows_per_vec_dot = 1;
- }
-
// Now select a reasonable chunk size.
int chunk_size = 16;
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
+ // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
+
+ // TODO: currently the mmla kernels support only even numbered rows/cols.
+ // this check can be removed once they are extended to support odd numbered rows/cols too
+ if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
+ num_rows_per_vec_dot = 1;
+ }
+
ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
if (nth >= nchunk0 * nchunk1) {