default: GGML_ASSERT(false);
};
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:22];
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23];
- [encoder setBytes:&ext_factor length:sizeof(float) atIndex:24];
- [encoder setBytes:&attn_factor length:sizeof(float) atIndex:25];
- [encoder setBytes:&beta_fast length:sizeof(float) atIndex:26];
- [encoder setBytes:&beta_slow length:sizeof(float) atIndex:27];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
// Computes a scaled cosine/sine pair for one rotation angle and stores the
// results through cos_theta and sin_theta.
//   theta_extrap - input angle; used as-is for the extrapolated term of the blend
//   freq_scale   - multiplier producing the interpolated angle (freq_scale * theta_extrap)
//   corr_dims    - two values forwarded to rope_yarn_ramp as corr_dims[0], corr_dims[1]
//   i0           - index forwarded to rope_yarn_ramp
//   ext_factor   - weight applied to the ramp; when 0 the blend and the mscale
//                  correction below are both skipped
//   mscale       - magnitude factor multiplied into both outputs
//   cos_theta, sin_theta - output pointers, written exactly once each
// NOTE(review): this span is diff-formatted; lines starting with `-` look like the
// pre-change text and lines starting with `+` their replacement — confirm against
// the applied file.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
- float * cos_theta, float * sin_theta
+ thread float * cos_theta, thread float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
// Default to the purely interpolated angle; only replaced when ext_factor is non-zero.
float theta = theta_interp;
if (ext_factor != 0.0f) {
// ramp_mix is the weight given to theta_extrap (and 1 - ramp_mix to theta_interp).
- ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
- mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
// Both outputs carry the (possibly corrected) magnitude factor.
- *cos_theta = cosf(theta) * mscale;
- *sin_theta = sinf(theta) * mscale;
+ *cos_theta = cos(theta) * mscale;
+ *sin_theta = sin(theta) * mscale;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
constant int & n_past,
constant int & n_dims,
constant int & mode,
+ constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
+ constant float & ext_factor,
+ constant float & attn_factor,
+ constant float & beta_fast,
+ constant float & beta_slow,
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]);
constant int & n_past,
constant int & n_dims,
constant int & mode,
+ constant int & n_orig_ctx,
constant float & freq_base,
constant float & freq_scale,
constant float & ext_factor,