return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
}
+static void ggml_rope_cache_init(
+ float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+ float * cache, float sin_sign, float theta_scale
+) {
+ float theta = theta_base;
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+ rope_yarn(
+ theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+ );
+ cache[i0 + 1] *= sin_sign;
+
+ theta *= theta_scale;
+ }
+}
+
void ggml_rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
+
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- float cos_theta, sin_theta;
- rope_yarn(
- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
- );
- sin_theta *= sin_sign;
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
- theta_base *= theta_scale;
-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
+
+ float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+ if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox
+ ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+ }
+
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
- float cos_theta, sin_theta;
- rope_yarn(
- theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
- );
- sin_theta *= sin_sign;
-
- theta_base *= theta_scale;
+ const float cos_theta = cache[i0 + 0];
+ const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
}
} break;
case GGML_OP_SOFT_MAX:
+ case GGML_OP_ROPE:
{
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} break;