}
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
- mul->src[0]->ne[1] != rms_norm->ne[1]) {
+ !ggml_are_same_shape(mul->src[0], rms_norm)) {
return false;
}
// rms_norm shader assumes contiguous rows
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
if (do_multiply) {
- [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
- data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+ if (ncols > p.ne10) {
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
+ }
+ } else {
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+ }
}
} else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {