return (1 - MIN(1, MAX(0, y)));
}
-static void rope_cache_init(const float theta_base,
- float freq_scale,
- const float * freq_factors,
- float * corr_dims,
- uint32_t ne0,
- float ext_factor,
- float mscale,
- float * cache,
- float theta_scale) {
+static void rope_cache_init(const float theta_base,
+ const float freq_scale,
+ const float * freq_factors,
+ float * corr_dims,
+ const uint32_t ne0,
+ const float ext_factor,
+ const float mscale,
+ float * cache,
+ const float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta = theta_base;
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
- float theta2 = theta_interp;
+ float theta_final = theta_interp;
+ float mscale_final = mscale;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
- theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+ theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
- cache[i0 + 0] = cosf(theta2) * mscale;
- cache[i0 + 1] = sinf(theta2) * mscale;
+ cache[i0 + 0] = cosf(theta_final) * mscale_final;
+ cache[i0 + 1] = sinf(theta_final) * mscale_final;
theta *= theta_scale;
}
}
static void hvx_calc_rope_neox_f32(const float * restrict src0,
- float * restrict dst,
- const int num_elems,
- const float * restrict theta_cache) {
+ float * restrict dst,
+ const int num_elems,
+ const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
- *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
+ *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
src0_curr += VLEN;
const uint32_t ir1,
int nth,
int ith,
- int opt_path) {
+ const int opt_path) {
struct htp_ops_context * octx = rope_ctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
- const int32_t mode = rope_ctx->mode;
- const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
+ const int32_t mode = rope_ctx->mode;
+ const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
htp_rope_preamble;
freq_factors = (const float *) src2->data;
}
- int ir = 0;
-
+ const uint32_t i1_end = MIN(ir1, ne1);
+ const int32_t half_dims = rope_ctx->n_dims / 2;
+ const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
const int32_t p = pos[i2];
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
- for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads
- if (ir++ < ir0) {
- continue;
- }
- if (ir > ir1) {
- break;
- }
-
+ for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
} else {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
}
+
+ src_loc += rope_ctx->n_dims;
+ dst_data_loc += rope_ctx->n_dims;
} else {
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
const float cos_theta = wp0[i0 + 0];
if (is_neox) {
const float x0 = src_loc[0];
- const float x1 = src_loc[rope_ctx->n_dims/2];
+ const float x1 = src_loc[half_dims];
- dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
- dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
+ dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
+ dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
src_loc += 1;
dst_data_loc += 1;
dst_data_loc += 2;
}
}
- }
-
- for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
- dst_data_loc[0] = src_loc[0];
- dst_data_loc[1] = src_loc[1];
- src_loc += 2;
- dst_data_loc += 2;
+ src_loc += (is_neox ? half_dims : 0);
+ dst_data_loc += (is_neox ? half_dims : 0);
}
+
+ // TODO: use simd to speed up the remaining elements copy
+ memcpy(dst_data_loc, src_loc, remain_bytes);
}
}
}