]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : add alibi op (#457)
authorJiahao Li <redacted>
Tue, 22 Aug 2023 09:33:27 +0000 (17:33 +0800)
committerGitHub <redacted>
Tue, 22 Aug 2023 09:33:27 +0000 (12:33 +0300)
src/ggml-cuda.cu

index 5b415c646e8c6fc87c6522bfb85a48f9573baa6b..c0fb9fb650e0d22b04a444ef9056b8129f0ac392 100644 (file)
@@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
@@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
 }
 
+static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
+                                 const int n_heads_log2_floor, const float m0, const float m1) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const int k = row/k_rows;
+
+    float m_k;
+    if (k < n_heads_log2_floor) {
+        m_k = powf(m0, k + 1);
+    } else {
+        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+    }
+
+    dst[i] = col * m_k + x[i];
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con
     rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
 }
 
+static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
+                           const int k_rows, const int n_heads_log2_floor, const float m0,
+                           const float m1, cudaStream_t stream) {
+    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope(
     (void) i1;
 }
 
+inline void ggml_cuda_op_alibi(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_head = ((int32_t *) dst->op_params)[1];
+    float max_bias;
+    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
+
+    GGML_ASSERT(ne01 + n_past == ne00);
+    GGML_ASSERT(n_head == ne02);
+
+    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // compute
+    alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main);
+
+    (void) src1;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
 }
 
+void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true);
+}
+
 void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_rope;
             break;
+        case GGML_OP_ALIBI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_alibi;
+            break;
         default:
             return false;
     }