int n_ctx,
float freq_base,
float freq_scale,
+ float xpos_base,
+ bool xpos_downscale,
bool inplace) {
GGML_ASSERT(n_past >= 0);
bool is_node = false;
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- int32_t params[6] = { n_past, n_dims, mode, n_ctx };
+ int32_t params[8] = { n_past, n_dims, mode, n_ctx };
memcpy(params + 4, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float));
+ memcpy(params + 6, &xpos_base, sizeof(float));
+ memcpy(params + 7, &xpos_downscale, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE;
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
}
struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
}
struct ggml_tensor * ggml_rope_custom(
int n_ctx,
float freq_base,
float freq_scale) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
}
struct ggml_tensor * ggml_rope_custom_inplace(
int n_ctx,
float freq_base,
float freq_scale) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+}
+
+struct ggml_tensor * ggml_rope_xpos_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ float scale_base,
+ bool downscale) {
+ return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, scale_base, downscale, true);
}
// ggml_rope_back
int n_past,
int n_dims,
int mode,
- int n_ctx) {
+ int n_ctx,
+ float freq_base,
+ float freq_scale,
+ float xpos_base,
+ bool xpos_downscale) {
GGML_ASSERT(n_past >= 0);
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
- int32_t params[] = { n_past, n_dims, mode, n_ctx };
+ int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+ memcpy(params + 4, &freq_base, sizeof(float));
+ memcpy(params + 5, &freq_scale, sizeof(float));
+ memcpy(params + 6, &xpos_base, sizeof(float));
+ memcpy(params + 7, &xpos_downscale, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE_BACK;
}
}
-
// ggml_compute_forward_clamp
static void ggml_compute_forward_clamp_f32(
float freq_base;
float freq_scale;
+ // these two only relevant for xPos RoPE:
+ float xpos_base;
+ bool xpos_downscale;
+
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));
assert(n_past >= 0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
+ // zeta scaling for xPos only:
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+ if (xpos_downscale) zeta = 1.0f / zeta;
theta *= theta_scale;
const float x0 = src[0];
const float x1 = src[1];
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[1] = x0*sin_theta + x1*cos_theta;
+ dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
+ dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}
} else {
// TODO: this is probably wrong, but I can't figure it out ..
// dx = rope_back(dy, src1)
// src0 is dy, src1 contains options
+ float freq_base;
+ float freq_scale;
+
+ // these two only relevant for xPos RoPE:
+ float xpos_base;
+ bool xpos_downscale;
+
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_ctx = ((int32_t *) dst->op_params)[3];
+ memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));
assert(n_past >= 0);
// row index used to determine which thread to use
int ir = 0;
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_neox = mode & 2;
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = (float)p;
+ float theta = freq_scale * (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
+ // zeta scaling for xPos only:
+ float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+ if (xpos_downscale) zeta = 1.0f / zeta;
theta *= theta_scale;
const float dy0 = dy[0];
const float dy1 = dy[1];
- dx[0] = dy0*cos_theta + dy1*sin_theta;
- dx[1] = - dy0*sin_theta + dy1*cos_theta;
+ dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta;
+ dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
+ float freq_base, freq_scale, xpos_base;
+ bool xpos_downscale;
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
+ memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));
+
src0->grad = ggml_add_impl(ctx,
src0->grad,
ggml_rope_back(ctx,
n_past,
n_dims,
mode,
- n_ctx),
+ n_ctx,
+ freq_base,
+ freq_scale,
+ xpos_base,
+ xpos_downscale),
inplace);
}
} break;
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3];
+ float freq_base, freq_scale, xpos_base;
+ bool xpos_downscale;
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
+ memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));
+
src0->grad = ggml_add_impl(ctx,
src0->grad,
- ggml_rope(ctx,
+ ggml_rope_impl(ctx,
tensor->grad,
n_past,
n_dims,
mode,
- n_ctx),
+ n_ctx,
+ freq_base,
+ freq_scale,
+ xpos_base,
+ xpos_downscale,
+ false),
inplace);
}
} break;
--- /dev/null
+#include "ggml/ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
// Returns true when a and b differ by strictly less than epsilon.
// Uses fabsf (not fabs) so the comparison stays in single precision
// instead of silently promoting through double.
bool is_close(float a, float b, float epsilon) {
    return fabsf(a - b) < epsilon;
}
+
+int main(int argc, char ** argv) {
+ const int n_threads = 1;
+ const int n_embd_head = 4; // aka head_dim
+ const int n_head = 1;
+ const int N = 8;
+
+ struct ggml_init_params params = {
+ .mem_size = 16*1024*1024,
+ .mem_buffer = NULL,
+ };
+
+ // memory allocation happens here
+ struct ggml_context * ctx = ggml_init(params);
+
+ struct ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
+ struct ggml_tensor * K = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
+
+ for (int i = 0; i < ggml_nelements(Q); i++) {
+ ((float*) Q->data)[i] = 2.0f;
+ ((float*) K->data)[i] = 2.0f;
+ }
+
+ struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, 1, n_embd_head, 512.0f, false);
+ struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, 1, n_embd_head, 512.0f, true);
+
+ struct ggml_cgraph gf = ggml_build_forward(Qx);
+ ggml_build_forward_expand(&gf, Kx);
+ ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
+
+ // expected output for Qx:
+ // -0.6009 2.7568 1.9782 2.0182
+ // -2.6379 0.9815 1.9562 2.0361
+ // -2.2457 -1.6853 1.9341 2.0538
+ // 0.2043 -2.7934 1.9118 2.0712
+ // 2.4550 -1.3341 1.8894 2.0884
+ // 2.4430 1.3417 1.8668 2.1054
+ // 0.1905 2.7739 1.8440 2.1221
+ // -2.2257 1.6550 1.8212 2.1386
+
+ for (int i = 0; i < ggml_nelements(Q); i++) {
+ if (((float*) Qx->data)[i] > 0) printf(" ");
+ printf("%.4f ", ((float*) Qx->data)[i]);
+ if ((i+1) % n_embd_head == 0) printf("\n");
+ }
+ printf("\n");
+
+ GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 0], -2.2257f, 0.0001f));
+ GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 1], 1.6550f, 0.0001f));
+ GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 2], 1.8212f, 0.0001f));
+ GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 3], 2.1386f, 0.0001f));
+
+ // expected output for Kx:
+ // -0.6038 2.7703 1.9816 2.0216
+ // -2.6639 0.9911 1.9630 2.0431
+ // -2.2789 -1.7103 1.9441 2.0644
+ // 0.2083 -2.8486 1.9251 2.0856
+ // 2.5158 -1.3671 1.9057 2.1065
+ // 2.5158 1.3816 1.8862 2.1273
+ // 0.1972 2.8705 1.8665 2.1479
+ // -2.3146 1.7211 1.8465 2.1684
+
+ for (int i = 0; i < ggml_nelements(K); i++) {
+ if (((float*) Kx->data)[i] > 0) printf(" ");
+ printf("%.4f ", ((float*) Kx->data)[i]);
+ if ((i+1) % n_embd_head == 0) printf("\n");
+ }
+ printf("\n");
+
+ GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 0], -2.3146f, 0.0001f));
+ GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 1], 1.7211f, 0.0001f));
+ GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 2], 1.8465f, 0.0001f));
+ GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 3], 2.1684f, 0.0001f));
+
+ ggml_free(ctx);
+
+ return 0;
+}