From: Jiahao Li Date: Fri, 14 Jul 2023 12:10:59 +0000 (+0800) Subject: cuda : support GLM-style RoPE (#383) X-Git-Tag: upstream/0.0.1642~1316 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=da3015e52487d8fdb1a05de0879cdc3fc8976bc2;p=pkg%2Fggml%2Fsources%2Fggml cuda : support GLM-style RoPE (#383) --- diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index dc4b773a..25d4af0e 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -1668,6 +1668,40 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } +static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + const int half_n_dims = ncols/4; + + if (col >= half_n_dims) { + return; + } + + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int i = row*ncols + col; + + const float col_theta_scale = powf(theta_scale, col); + + const float theta = p*col_theta_scale; + const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + half_n_dims]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; + + const float block_theta = block_p*col_theta_scale; + const float sin_block_theta = sinf(block_theta); + const float cos_block_theta = cosf(block_theta); + + const float x2 = x[i + half_n_dims * 2]; + const float x3 = x[i + half_n_dims * 3]; + + dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; + dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; +} + static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int row = blockDim.y*blockIdx.y + threadIdx.y; @@ -2065,6 +2099,14 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i rope_f32<<>>(x, dst, ncols, p, theta_scale); } +static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) { + GGML_ASSERT(nrows % 4 == 0); + const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1); + const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(num_blocks_x, nrows, 1); + rope_glm_f32<<>>(x, dst, ncols, p, block_p, theta_scale); +} + static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1); const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; @@ -2619,13 +2661,21 @@ inline void ggml_cuda_op_rope( const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; - GGML_ASSERT(mode == 0); + const int n_ctx = ((int32_t *) src1->data)[3]; const float theta_scale = powf(10000.0, -2.0f/n_dims); const float p = ((mode & 1) == 0 ? n_past + i02 : i02); + bool is_glm = mode & 4; + // compute - rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); + if (is_glm) { + const float id_p = min(p, n_ctx - 2.f); + const float block_p = max(p - (n_ctx - 2.f), 0.f); + rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); + } else { + rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); + } (void) dst; (void) src0_ddq_i;