]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : support GLM-style RoPE (#383)
authorJiahao Li <redacted>
Fri, 14 Jul 2023 12:10:59 +0000 (20:10 +0800)
committerGitHub <redacted>
Fri, 14 Jul 2023 12:10:59 +0000 (15:10 +0300)
src/ggml-cuda.cu

index dc4b773a66a4451dab6f209cc44751dc852adaeb..25d4af0e69f6b4f4ff994c061fde5852cdbe9629 100644 (file)
@@ -1668,6 +1668,40 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+    const int half_n_dims = ncols/4;
+
+    if (col >= half_n_dims) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const float col_theta_scale = powf(theta_scale, col);
+
+    const float theta = p*col_theta_scale;
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + half_n_dims];
+
+    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
+    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
+
+    const float block_theta = block_p*col_theta_scale;
+    const float sin_block_theta = sinf(block_theta);
+    const float cos_block_theta = cosf(block_theta);
+
+    const float x2 = x[i + half_n_dims * 2];
+    const float x3 = x[i + half_n_dims * 3];
+
+    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
+    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -2065,6 +2099,14 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
 }
 
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
+    GGML_ASSERT(nrows % 4 == 0);
+    const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -2619,13 +2661,21 @@ inline void ggml_cuda_op_rope(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
-    GGML_ASSERT(mode == 0);
+    const int n_ctx  = ((int32_t *) src1->data)[3];
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
     const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
 
+    bool is_glm = mode & 4;
+
     // compute
-    rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    if (is_glm) {
+        const float id_p = min(p, n_ctx - 2.f);
+        const float block_p = max(p - (n_ctx - 2.f), 0.f);
+        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else {
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    }
 
     (void) dst;
     (void) src0_ddq_i;