// if mode & 4 != 0, ChatGLM style
//
// b is an int32 vector of size a->ne[2]; it contains the positions
+ // c is an optional tensor of frequency factors (e.g. phi3-128k)
GGML_API struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
int mode,
int n_ctx);
// custom RoPE
- GGML_API struct ggml_tensor * ggml_rope_custom(
+ GGML_API struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
// in-place, returns view(a)
- GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+ GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow);
- // compute correction dims for YaRN RoPE scaling
- GGML_CALL void ggml_rope_yarn_corr_dims(
- int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use ggml_rope_ext instead");
- // xPos RoPE, in-place, returns view(a)
- GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
- float base,
- bool down);
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow),
+ "use ggml_rope_ext_inplace instead");
+
+ // compute correction dims for YaRN RoPE scaling
+ GGML_CALL void ggml_rope_yarn_corr_dims(
+ int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]);
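// A minimal usage sketch of the extended API (not part of this diff): builds a
// NeoX-mode RoPE node with optional per-pair frequency factors. The helper name
// and the hyperparameter values below are hypothetical placeholders.
static struct ggml_tensor * build_rope_ext_example(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,       // [n_embd_head, n_head, n_tokens]
        struct ggml_tensor  * pos,       // I32, one position per token (ne[0] == cur->ne[2])
        struct ggml_tensor  * factors) { // F32 with >= n_dims/2 entries, or NULL
    const int n_dims     = (int) cur->ne[0];
    const int n_orig_ctx = 4096;         // assumed original training context
    // mode = 2 selects NeoX-style rotation; factors == NULL keeps the old behavior
    return ggml_rope_ext(ctx, cur, pos, factors,
            n_dims, /*mode =*/ 2, /*n_ctx =*/ 0, n_orig_ctx,
            /*freq_base  =*/ 10000.0f, /*freq_scale  =*/ 1.0f,
            /*ext_factor =*/ 0.0f,     /*attn_factor =*/ 1.0f,
            /*beta_fast  =*/ 32.0f,    /*beta_slow   =*/ 1.0f);
}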
// rotary position embedding backward, i.e. compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
- const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
+ const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
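// Host-side reference for the angle computed above (a sketch with a single
// column index, not used by the kernel; powf comes from <math.h>): the base
// angle is scaled by theta_scale^(col/2) and divided by the matching frequency
// factor before the YaRN correction is applied.
static float rope_neox_theta_base_ref(
        int p, int col, float freq_scale, float theta_scale,
        const float * freq_factors /* may be NULL */) {
    const float freq_factor = freq_factors ? freq_factors[col/2] : 1.0f;
    return p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
}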
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) {
- rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
- theta_scale, inv_ndims
- );
+ if (freq_factors == nullptr) {
+ rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ } else {
+ rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ }
} else {
- rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
- x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
- theta_scale, inv_ndims
- );
+ if (freq_factors == nullptr) {
+ rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ } else {
+ rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+ theta_scale, inv_ndims, freq_factors
+ );
+ }
}
}
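// Design note: the host branches above select one of four template
// instantiations (has_pos x has_freq_facs), so the optional-input checks are
// resolved at compile time and the per-thread kernel body stays branch-free.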
static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
- rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+ rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
+ float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
) {
- rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
+ rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
+
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
- const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ const float * freq_factors = nullptr;
const int32_t * pos = nullptr;
- if ((mode & 1) == 0) {
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(src1->ne[0] == ne2);
- pos = (const int32_t *) src1_d;
- }
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
+ if (is_neox) {
+ pos = (const int32_t *) src1_d;
+
+ if (src2 != nullptr) {
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+ }
+
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, stream
+ attn_factor, corr_dims, freq_factors, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
- attn_factor, corr_dims, stream
+ attn_factor, corr_dims, freq_factors, stream
);
} else {
GGML_ASSERT(false);
} break;
case GGML_OP_ROPE:
{
+#pragma message("TODO: implement phi3 frequency factors support")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+
GGML_ASSERT(ne10 == ne02);
GGML_ASSERT(src0t == dstt);
// const int n_past = ((int32_t *) dst->op_params)[0];
const int64_t ne10 = src1 ? src1->ne[0] : 0;
const int64_t ne11 = src1 ? src1->ne[1] : 0;
const int64_t ne12 = src1 ? src1->ne[2] : 0;
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
- const int64_t ne0 = dst ? dst->ne[0] : 0;
- const int64_t ne1 = dst ? dst->ne[1] : 0;
- const int64_t ne2 = dst ? dst->ne[2] : 0;
- const int64_t ne3 = dst ? dst->ne[3] : 0;
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
+
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const int n_as = src0->ne[2];
// src2 = ids
- const int64_t ne20 = src2->ne[0];
- const int64_t ne21 = src2->ne[1];
- const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
- const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
-
- const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
- const uint64_t nb21 = src2->nb[1];
- const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
- const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
-
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
GGML_ASSERT(src2t == GGML_TYPE_I32);
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ const bool is_neox = mode & 2;
+ const bool is_glm = mode & 4;
+
+ GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
+
+ if (!is_neox) {
+ GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
+ }
+
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+ if (id_src2 != nil) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:22];
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
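// Note on the encoder bindings above: Metal requires a non-nil buffer at every
// declared argument index, so when no frequency factors are supplied src0 is
// re-bound as a placeholder at index 2; kernel_rope detects this case via
// `src2 != src0` and falls back to a frequency factor of 1.0f (see the shader
// change below).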
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
-
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
typedef void (rope_t)(
device const void * src0,
device const int32_t * src1,
+ device const float * src2,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
kernel void kernel_rope(
device const void * src0,
device const int32_t * src1,
+ device const float * src2,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
- const float theta = theta_0 * pow(freq_base, cur_rot);
+ const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+
+ const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
ggml_tensor *dst, const float *src0_dd,
const float *src1_dd, float *dst_dd,
const dpct::queue_ptr &main_stream) {
+#pragma message("TODO: implement phi3 frequency factors support")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
}
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+#pragma message("TODO: implement phi3 frequency factors support")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
+ GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
// const int n_ctx = ((int32_t *) dst->op_params)[3];
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
+ if (c) {
+ GGML_ASSERT(c->type == GGML_TYPE_F32);
+ GGML_ASSERT(c->ne[0] >= n_dims / 2);
+ }
+
bool is_node = false;
if (a->grad) {
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
+ result->src[2] = c;
return result;
}
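// Design note: the optional factors ride along as a third graph source
// (result->src[2]); every backend reads them back from dst->src[2], so no
// change to the serialized op_params layout is needed.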
int mode,
int n_ctx) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false
);
}
int mode,
int n_ctx) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
+ ctx, a, b, NULL, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true
);
}
-struct ggml_tensor * ggml_rope_custom(
+struct ggml_tensor * ggml_rope_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
);
}
-struct ggml_tensor * ggml_rope_custom_inplace(
+struct ggml_tensor * ggml_rope_ext_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
int n_orig_ctx,
float freq_base,
float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow) {
return ggml_rope_impl(
- ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ctx, a, b, c, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
);
}
-struct ggml_tensor * ggml_rope_xpos_inplace(
+struct ggml_tensor * ggml_rope_custom(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
- float base,
- bool down) {
- return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false
+ );
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ int n_orig_ctx,
+ float freq_base,
+ float freq_scale,
+ float ext_factor,
+ float attn_factor,
+ float beta_fast,
+ float beta_slow) {
+ return ggml_rope_impl(
+ ctx, a, b, NULL, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true
+ );
}
// ggml_rope_back

struct ggml_tensor * ggml_rope_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
+ struct ggml_tensor * c,
int n_dims,
int mode,
int n_ctx,
GGML_ASSERT(ggml_is_vector(b));
GGML_ASSERT(b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[0]);
+ GGML_ASSERT(c == NULL && "freq factors not implemented yet");
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
+ const struct ggml_tensor * src2 = dst->src[2];
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
return;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
+ const float * freq_factors = NULL;
+ if (is_neox) {
+ if (src2 != NULL) {
+ GGML_ASSERT(src2->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+ freq_factors = (const float *) src2->data;
+ }
+ } else {
+ GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+ }
+
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
+ float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
- theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+ theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
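// Note: the factor divides theta_base before rope_yarn() is called, so the
// YaRN interpolation and the attention magnitude correction both see the
// already-rescaled frequency; this matches the CUDA and Metal paths above.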
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) {
struct ggml_tensor * src0 = tensor->src[0];
struct ggml_tensor * src1 = tensor->src[1];
+ struct ggml_tensor * src2 = tensor->src[2];
switch (tensor->op) {
case GGML_OP_DUP:
ggml_rope_back(ctx,
tensor->grad,
src1,
+ src2,
n_dims,
mode,
n_ctx,
ggml_rope_impl(ctx,
tensor->grad,
src1,
+ src2,
n_dims,
mode,
n_ctx,
masked);
}
- struct ggml_tensor * src2 = tensor->src[2];
const int64_t elem_q = ggml_nelements(src0);
const int64_t elem_k = ggml_nelements(src1);
const int64_t elem_v = ggml_nelements(src2);
struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
- Qcur = ggml_rope_custom(
- ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
+ Qcur = ggml_rope_ext(
+ ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos, nullptr,
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
- Kcur = ggml_rope_custom(
- ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
+ Kcur = ggml_rope_ext(
+ ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
// using mode = 2 for neox mode
- Qcur = ggml_rope_custom(
- ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+ Qcur = ggml_rope_ext(
+ ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
- Kcur = ggml_rope_custom(
- ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
+ Kcur = ggml_rope_ext(
+ ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, 0, hp.n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);