From: Jan Ploski
Date: Tue, 22 Aug 2023 08:45:20 +0000 (+0200)
Subject: ggml : implementation of xPos RoPE (#441); also extends ggml_rope_back with additiona...
X-Git-Tag: upstream/0.0.1642~1273
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=896b089e63449efcc37c52bd84913dd2295c82f4;p=pkg%2Fggml%2Fsources%2Fggml

ggml : implementation of xPos RoPE (#441); also extends ggml_rope_back
with additional parameters (breaking API change); does not include CUDA
version (#442)
---

diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h
index 58598ebf..5c838689 100644
--- a/include/ggml/ggml.h
+++ b/include/ggml/ggml.h
@@ -1217,6 +1217,15 @@ extern "C" {
             float                 freq_base,
             float                 freq_scale);
 
+    // xPos RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            float                 scale_base,
+            bool                  downscale);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1225,7 +1234,11 @@ extern "C" {
             int                   n_past,
             int                   n_dims,
             int                   mode,
-            int                   n_ctx);
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 xpos_base,
+            bool                  xpos_downscale);
 
     // alibi position embedding
     // in-place, returns view(a)
diff --git a/src/ggml.c b/src/ggml.c
index 7b5922e3..6a4b6ed4 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -6715,6 +6715,8 @@ static struct ggml_tensor * ggml_rope_impl(
         int n_ctx,
         float freq_base,
         float freq_scale,
+        float xpos_base,
+        bool xpos_downscale,
         bool inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6725,9 +6727,11 @@ static struct ggml_tensor * ggml_rope_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[6] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
     memcpy(params + 4, &freq_base, sizeof(float));
     memcpy(params + 5, &freq_scale, sizeof(float));
+    memcpy(params + 6, &xpos_base, sizeof(float));
+    memcpy(params + 7, &xpos_downscale, sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_ROPE;
@@ -6744,7 +6748,7 @@ struct ggml_tensor * ggml_rope(
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -6754,7 +6758,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true);
 }
 
 struct ggml_tensor * ggml_rope_custom(
@@ -6766,7 +6770,7 @@ struct ggml_tensor * ggml_rope_custom(
         int n_ctx,
         float freq_base,
         float freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -6778,7 +6782,17 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int n_ctx,
         float freq_base,
         float freq_scale) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true);
+}
+
+struct ggml_tensor * ggml_rope_xpos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int n_past,
+        int n_dims,
+        float scale_base,
+        bool downscale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, scale_base, downscale, true);
 }
 
 // ggml_rope_back
@@ -6789,7 +6803,11 @@ struct ggml_tensor * ggml_rope_back(
         int n_past,
         int n_dims,
         int mode,
-        int n_ctx) {
+        int n_ctx,
+        float freq_base,
+        float freq_scale,
+        float xpos_base,
+        bool xpos_downscale) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -6801,7 +6819,11 @@ struct ggml_tensor * ggml_rope_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
-    int32_t params[] = { n_past, n_dims, mode, n_ctx };
+    int32_t params[8] = { n_past, n_dims, mode, n_ctx };
+    memcpy(params + 4, &freq_base, sizeof(float));
+    memcpy(params + 5, &freq_scale, sizeof(float));
+    memcpy(params + 6, &xpos_base, sizeof(float));
+    memcpy(params + 7, &xpos_downscale, sizeof(bool));
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op = GGML_OP_ROPE_BACK;
@@ -12065,7 +12087,6 @@ static void ggml_compute_forward_alibi(
     }
 }
 
-
 // ggml_compute_forward_clamp
 
 static void ggml_compute_forward_clamp_f32(
@@ -12154,12 +12175,18 @@ static void ggml_compute_forward_rope_f32(
     float freq_base;
     float freq_scale;
 
+    // these two only relevant for xPos RoPE:
+    float xpos_base;
+    bool xpos_downscale;
+
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
     const int n_ctx  = ((int32_t *) dst->op_params)[3];
     memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));
 
     assert(n_past >= 0);
@@ -12231,6 +12258,9 @@ static void ggml_compute_forward_rope_f32(
             for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                 const float cos_theta = cosf(theta);
                 const float sin_theta = sinf(theta);
+                // zeta scaling for xPos only:
+                float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                if (xpos_downscale) zeta = 1.0f / zeta;
 
                 theta *= theta_scale;
@@ -12240,8 +12270,8 @@ static void ggml_compute_forward_rope_f32(
                 const float x0 = src[0];
                 const float x1 = src[1];
 
-                dst_data[0] = x0*cos_theta - x1*sin_theta;
-                dst_data[1] = x0*sin_theta + x1*cos_theta;
+                dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta;
+                dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
             }
         } else {
             // TODO: this is probably wrong, but I can't figure it out ..
@@ -12435,9 +12465,21 @@ static void ggml_compute_forward_rope_back_f32(
     // dx = rope_back(dy, src1)
    // src0 is dy, src1 contains options
 
+    float freq_base;
+    float freq_scale;
+
+    // these two only relevant for xPos RoPE:
+    float xpos_base;
+    bool xpos_downscale;
+
     const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_dims = ((int32_t *) dst->op_params)[1];
     const int mode   = ((int32_t *) dst->op_params)[2];
+    const int n_ctx  = ((int32_t *) dst->op_params)[3];
+    memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&xpos_base,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&xpos_downscale, (int32_t *) dst->op_params + 7, sizeof(bool));
 
     assert(n_past >= 0);
@@ -12463,7 +12505,7 @@ static void ggml_compute_forward_rope_back_f32(
     // row index used to determine which thread to use
     int ir = 0;
 
-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     const bool is_neox = mode & 2;
@@ -12474,12 +12516,15 @@ static void ggml_compute_forward_rope_back_f32(
             if (ir++ < ir0) continue;
             if (ir > ir1) break;
 
-            float theta = (float)p;
+            float theta = freq_scale * (float)p;
 
             if (!is_neox) {
                 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                     const float cos_theta = cosf(theta);
                     const float sin_theta = sinf(theta);
+                    // zeta scaling for xPos only:
+                    float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f;
+                    if (xpos_downscale) zeta = 1.0f / zeta;
 
                     theta *= theta_scale;
@@ -12489,8 +12534,8 @@ static void ggml_compute_forward_rope_back_f32(
                     const float dy0 = dy[0];
                     const float dy1 = dy[1];
 
-                    dx[0] =   dy0*cos_theta + dy1*sin_theta;
-                    dx[1] = - dy0*sin_theta + dy1*cos_theta;
+                    dx[0] =   dy0*cos_theta*zeta + dy1*sin_theta*zeta;
+                    dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
                 }
             } else {
                 for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
@@ -15967,6 +16012,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 const int n_dims = ((int32_t *) tensor->op_params)[1];
                 const int mode   = ((int32_t *) tensor->op_params)[2];
                 const int n_ctx  = ((int32_t *) tensor->op_params)[3];
+                float freq_base, freq_scale, xpos_base;
+                bool xpos_downscale;
+                memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));
+
                 src0->grad = ggml_add_impl(ctx,
                         src0->grad,
                         ggml_rope_back(ctx,
@@ -15974,7 +16026,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_past,
                                 n_dims,
                                 mode,
-                                n_ctx),
+                                n_ctx,
+                                freq_base,
+                                freq_scale,
+                                xpos_base,
+                                xpos_downscale),
                         inplace);
             }
         } break;
@@ -15985,14 +16041,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 const int n_dims = ((int32_t *) tensor->op_params)[1];
                 const int mode   = ((int32_t *) tensor->op_params)[2];
                 const int n_ctx  = ((int32_t *) tensor->op_params)[3];
+                float freq_base, freq_scale, xpos_base;
+                bool xpos_downscale;
+                memcpy(&freq_base,  (int32_t *) tensor->op_params + 4, sizeof(float));
+                memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
+                memcpy(&xpos_base,  (int32_t *) tensor->op_params + 6, sizeof(float));
+                memcpy(&xpos_downscale, (int32_t *) tensor->op_params + 7, sizeof(bool));
+
                 src0->grad = ggml_add_impl(ctx,
                         src0->grad,
-                        ggml_rope(ctx,
+                        ggml_rope_impl(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
                                 mode,
-                                n_ctx),
+                                n_ctx,
+                                freq_base,
+                                freq_scale,
+                                xpos_base,
+                                xpos_downscale,
+                                false),
                         inplace);
             }
         } break;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 33eda08f..8d9d1f79 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -325,3 +325,12 @@ if (MSVC)
 endif()
 add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
 set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
+
+#
+# test-xpos
+
+set(TEST_TARGET test-xpos)
+add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
+target_link_libraries(${TEST_TARGET} PRIVATE ggml)
+add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
+set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw")
diff --git a/tests/test-xpos.c b/tests/test-xpos.c
new file mode 100644
index 00000000..a8c64e55
--- /dev/null
+++ b/tests/test-xpos.c
@@ -0,0 +1,87 @@
+#include "ggml/ggml.h"
+
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+bool is_close(float a, float b, float epsilon) {
+    return fabs(a - b) < epsilon;
+}
+
+int main(int argc, char ** argv) {
+    const int n_threads = 1;
+    const int n_embd_head = 4; // aka head_dim
+    const int n_head = 1;
+    const int N = 8;
+
+    struct ggml_init_params params = {
+        .mem_size   = 16*1024*1024,
+        .mem_buffer = NULL,
+    };
+
+    // memory allocation happens here
+    struct ggml_context * ctx = ggml_init(params);
+
+    struct ggml_tensor * Q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
+    struct ggml_tensor * K = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, N);
+
+    for (int i = 0; i < ggml_nelements(Q); i++) {
+        ((float*) Q->data)[i] = 2.0f;
+        ((float*) K->data)[i] = 2.0f;
+    }
+
+    struct ggml_tensor * Qx = ggml_rope_xpos_inplace(ctx, Q, 1, n_embd_head, 512.0f, false);
+    struct ggml_tensor * Kx = ggml_rope_xpos_inplace(ctx, K, 1, n_embd_head, 512.0f, true);
+
+    struct ggml_cgraph gf = ggml_build_forward(Qx);
+    ggml_build_forward_expand(&gf, Kx);
+    ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
+
+    // expected output for Qx:
+    // -0.6009  2.7568  1.9782  2.0182
+    // -2.6379  0.9815  1.9562  2.0361
+    // -2.2457 -1.6853  1.9341  2.0538
+    //  0.2043 -2.7934  1.9118  2.0712
+    //  2.4550 -1.3341  1.8894  2.0884
+    //  2.4430  1.3417  1.8668  2.1054
+    //  0.1905  2.7739  1.8440  2.1221
+    // -2.2257  1.6550  1.8212  2.1386
+
+    for (int i = 0; i < ggml_nelements(Q); i++) {
+        if (((float*) Qx->data)[i] > 0) printf(" ");
+        printf("%.4f ", ((float*) Qx->data)[i]);
+        if ((i+1) % n_embd_head == 0) printf("\n");
+    }
+    printf("\n");
+
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 0], -2.2257f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 1],  1.6550f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 2],  1.8212f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Qx->data)[7 * n_embd_head + 3],  2.1386f, 0.0001f));
+
+    // expected output for Kx:
+    // -0.6038  2.7703  1.9816  2.0216
+    // -2.6639  0.9911  1.9630  2.0431
+    // -2.2789 -1.7103  1.9441  2.0644
+    //  0.2083 -2.8486  1.9251  2.0856
+    //  2.5158 -1.3671  1.9057  2.1065
+    //  2.5158  1.3816  1.8862  2.1273
+    //  0.1972  2.8705  1.8665  2.1479
+    // -2.3146  1.7211  1.8465  2.1684
+
+    for (int i = 0; i < ggml_nelements(K); i++) {
+        if (((float*) Kx->data)[i] > 0) printf(" ");
+        printf("%.4f ", ((float*) Kx->data)[i]);
+        if ((i+1) % n_embd_head == 0) printf("\n");
+    }
+    printf("\n");
+
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 0], -2.3146f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 1],  1.7211f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 2],  1.8465f, 0.0001f));
+    GGML_ASSERT(is_close(((float*) Kx->data)[7 * n_embd_head + 3],  2.1684f, 0.0001f));
+
+    ggml_free(ctx);
+
+    return 0;
+}
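
Note on the math: the zeta factor introduced in ggml_compute_forward_rope_f32
above is the xPos scale from Sun et al., "A Length-Extrapolatable Transformer"
(2022). For the rotary pair starting at dimension i0 of a head of width ne0,
at absolute position p = n_past + i2, the kernel computes

    zeta = ((i0 + 0.4*ne0) / (1.4*ne0)) ^ (p / xpos_base)

and the downscale flag inverts it. Q is scaled by zeta and K by 1/zeta, so
their dot product picks up a factor that depends only on the distance between
the two positions. A minimal standalone sketch of that property (xpos_zeta is
a hypothetical helper written for illustration; it is not part of the patch):

#include <math.h>
#include <stdbool.h>
#include <stdio.h>

// hypothetical helper mirroring the kernel: scale factor for the rotary
// pair starting at dimension i0 of a head of width ne0, at absolute
// position p; downscale = true inverts the factor (the K side)
static float xpos_zeta(int i0, int ne0, int p, float xpos_base, bool downscale) {
    float zeta = powf((i0 + 0.4f*ne0) / (1.4f*ne0), p / xpos_base);
    return downscale ? 1.0f/zeta : zeta;
}

int main(void) {
    const int   ne0  = 4;      // head_dim, matching the test above
    const float base = 512.0f; // scale_base, matching the test above

    // Q at position m carries zeta(m), K at position n carries 1/zeta(n),
    // so the combined factor depends only on m - n: both lines print the
    // same value
    printf("%f\n", xpos_zeta(0, ne0, 7, base, false) * xpos_zeta(0, ne0, 3, base, true));
    printf("%f\n", xpos_zeta(0, ne0, 9, base, false) * xpos_zeta(0, ne0, 5, base, true));

    return 0;
}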
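
Note on the breaking change: every existing caller of ggml_rope_back needs the
four new trailing arguments. A migration sketch, assuming ctx and dy are the
caller's existing context and incoming gradient; 10000.0f and 1.0f are the
defaults ggml_rope itself passes down to ggml_rope_impl, and xpos_base = 0.0f
with xpos_downscale = false leaves the xPos path disabled (zeta stays 1.0f):

// before this commit:
//   dx = ggml_rope_back(ctx, dy, n_past, n_dims, mode, n_ctx);

// after this commit, preserving the old behaviour:
struct ggml_tensor * dx = ggml_rope_back(ctx, dy, n_past, n_dims, mode, n_ctx,
                                         10000.0f, 1.0f, 0.0f, false);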