From: Jiahao Li
Date: Tue, 22 Aug 2023 09:33:27 +0000 (+0800)
Subject: cuda : add alibi op (#457)
X-Git-Tag: upstream/0.0.1642~1271
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=2f36aa45608c8480a55c23eb2acd11eee3622a8d;p=pkg%2Fggml%2Fsources%2Fggml

cuda : add alibi op (#457)
---

diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu
index 5b415c64..c0fb9fb6 100644
--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
@@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
 }
 
+static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
+                                 const int n_heads_log2_floor, const float m0, const float m1) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const int k = row/k_rows;
+
+    float m_k;
+    if (k < n_heads_log2_floor) {
+        m_k = powf(m0, k + 1);
+    } else {
+        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+    }
+
+    dst[i] = col * m_k + x[i];
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
 }
 
+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope(
     (void) i1;
 }
 
+inline void ggml_cuda_op_alibi(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+    GGML_ASSERT(ne01 + n_past == ne00);
+    GGML_ASSERT(n_head == ne02);
+
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // compute
+    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+
+    (void) src1;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
+void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+}
+
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_rope;
             break;
+        case GGML_OP_ALIBI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_alibi;
+            break;
         default:
             return false;
     }
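
Reviewer note, not part of the commit: the alibi_f32 kernel adds the ALiBi
position bias in place. Rows are grouped into heads of k_rows rows each, head
k = row / k_rows gets a slope m_k, and element (row, col) becomes
x[i] + col * m_k. For head counts that are not powers of two, the first
n_heads_log2_floor heads take successive powers of m0 and the remaining heads
take odd powers of m1, as in the ALiBi paper. Below is a minimal CPU sketch
that mirrors the same math for host-side checking of the CUDA output; the
helper name alibi_f32_ref is hypothetical, and it assumes the whole tensor is
processed in one call (nrows = ne01*ne02, k_rows = ne01).

#include <math.h>

// Hypothetical CPU reference for the alibi_f32 kernel above (not part of the
// commit). Head k covers rows [k*k_rows, (k+1)*k_rows); each element gets the
// bias col * m_k added on top of the input value.
static void alibi_f32_ref(const float * x, float * dst, int ncols, int nrows,
                          int k_rows, int n_head, float max_bias) {
    // same slope setup as ggml_cuda_op_alibi
    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
    const float m0 = powf(2.0f, -(max_bias)        / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    for (int row = 0; row < nrows; ++row) {
        const int k = row / k_rows; // head index of this row

        float m_k;
        if (k < n_heads_log2_floor) {
            m_k = powf(m0, k + 1);                            // first heads: powers of m0
        } else {
            m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); // remaining heads: odd powers of m1
        }

        for (int col = 0; col < ncols; ++col) {
            dst[row*ncols + col] = col * m_k + x[row*ncols + col];
        }
    }
}

For example, with n_head = 12 and max_bias = 8.0f: n_heads_log2_floor = 8,
m0 = 2^-1 = 0.5 and m1 = 2^-0.5 ~= 0.707, so heads 0..7 get slopes
0.5^1 .. 0.5^8 and heads 8..11 get 0.707^1, 0.707^3, 0.707^5, 0.707^7,
matching the geometric interpolation the ALiBi paper uses for
non-power-of-two head counts.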