]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: Implemented row flattening for non-glm RoPE (#2468)
authorJohannes Gäßler <redacted>
Mon, 31 Jul 2023 12:32:30 +0000 (14:32 +0200)
committerGitHub <redacted>
Mon, 31 Jul 2023 12:32:30 +0000 (14:32 +0200)
ggml-cuda.cu

index 3f111565ae4448ae6aa36b5edae7febd07244041..bcdff3640c13dfb9e4f95000ba97a989ecf93254 100644 (file)
@@ -3150,7 +3150,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 // rope == RoPE == rotary positional embedding
-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
+static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
+                                const float p_delta, const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (col >= ncols) {
@@ -3160,7 +3161,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
     const int i = row*ncols + col;
 
-    const float theta = p*powf(theta_scale, col/2);
+    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
 
@@ -3764,12 +3765,13 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
+static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
     const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
+    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -4465,6 +4467,7 @@ inline void ggml_cuda_op_rope(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
     const int64_t i01_diff = i01_high - i01_low;
 
     const int n_past = ((int32_t *) dst->op_params)[0];
@@ -4478,17 +4481,18 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
 
-    bool is_glm = mode & 4;
+    const bool is_glm = mode & 4;
 
     // compute
     if (is_glm) {
+        const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
         const float id_p = min(p, n_ctx - 2.f);
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else {
-        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     }
 
     (void) src1;
@@ -5103,7 +5107,10 @@ void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml
 
 void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results
+
+    const int mode = ((int32_t *) dst->op_params)[2];
+    const bool is_glm = mode & 4;
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {