    const int32_t * pos = (const int32_t *) src1->data;
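+    // last sequence position for which this thread's sin/cos cache was initialized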
+    int64_t last_i2 = -1;
+
    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!mrope_used) {
-                const int64_t p = pos[i2];
-                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-                if (ir++ < ir0) continue;
+                if (ir++ < ir0) continue; // skip rows mapped to other threads
                if (ir > ir1) break;
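+                // per-thread scratch in wdata; the CACHE_LINE_SIZE_F32 pad keeps each thread's cache on its own cache lines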
+                float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
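+                // the sin/cos cache depends only on i2 (not on i1 or i3), so rebuild it
+                // lazily: only when this thread owns a row at a new sequence position.
+                // threads that have no rows for a given i2 never pay for its init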
+                if (last_i2 != i2) {
+                    if (!mrope_used) {
+                        const int64_t p = pos[i2];
+                        ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+                    }
+                    else {
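+                        // mrope: pos holds 4 sections of ne2 entries each (t, h, w, e)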
+                        const int64_t p_t = pos[i2];
+                        const int64_t p_h = pos[i2 + ne2];
+                        const int64_t p_w = pos[i2 + ne2 * 2];
+                        const int64_t p_e = pos[i2 + ne2 * 3];
+                        ggml_mrope_cache_init(
+                            p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
+                            freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+                    }
+
+                    last_i2 = i2;
+                }
+
                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
                T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);