]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : sync (ggml_conv_2d, fix mul_mat bug, CUDA GLM rope)
authorGeorgi Gerganov <redacted>
Fri, 14 Jul 2023 13:36:41 +0000 (16:36 +0300)
committerGeorgi Gerganov <redacted>
Fri, 14 Jul 2023 13:36:41 +0000 (16:36 +0300)
ggml-cuda.cu
ggml.c

index e0d5e9156315ba4ed74e385c819e75fe614d3652..920466aae72db8f01c44e759a3c7c6a0b856cf1f 100644 (file)
@@ -1667,6 +1667,40 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
+    const int col = blockDim.x*blockIdx.x + threadIdx.x;
+    const int half_n_dims = ncols/4;
+
+    if (col >= half_n_dims) {
+        return;
+    }
+
+    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i = row*ncols + col;
+
+    const float col_theta_scale = powf(theta_scale, col);
+
+    const float theta = p*col_theta_scale;
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + half_n_dims];
+
+    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
+    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
+
+    const float block_theta = block_p*col_theta_scale;
+    const float sin_block_theta = sinf(block_theta);
+    const float cos_block_theta = cosf(block_theta);
+
+    const float x2 = x[i + half_n_dims * 2];
+    const float x3 = x[i + half_n_dims * 3];
+
+    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
+    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
+}
+
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -2064,6 +2098,14 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
 }
 
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
+    GGML_ASSERT(nrows % 4 == 0);
+    const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(num_blocks_x, nrows, 1);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
+}
+
 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
     const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -2618,13 +2660,21 @@ inline void ggml_cuda_op_rope(
     const int n_past = ((int32_t *) src1->data)[0];
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
-    GGML_ASSERT(mode == 0);
+    const int n_ctx  = ((int32_t *) src1->data)[3];
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
     const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
 
+    bool is_glm = mode & 4;
+
     // compute
-    rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    if (is_glm) {
+        const float id_p = min(p, n_ctx - 2.f);
+        const float block_p = max(p - (n_ctx - 2.f), 0.f);
+        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else {
+        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
+    }
 
     (void) dst;
     (void) src0_ddq_i;
diff --git a/ggml.c b/ggml.c
index c137ae658df7f7bf97e74f85cdca55c0ab8cf7b5..f5821f1f1aa042d3c04eabc36b9aca9678ea4e87 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -10684,6 +10684,8 @@ static void ggml_compute_forward_mul_mat(
 
     const enum ggml_type type = src0->type;
 
+    const bool src1_cont = ggml_is_contiguous(src1);
+
     ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
@@ -10747,7 +10749,7 @@ static void ggml_compute_forward_mul_mat(
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 if (type != GGML_TYPE_F32) {
-                    float * const wdata = params->wdata;
+                            float * const wdata    = params->wdata;
                     ggml_to_float_t const to_float = type_traits[type].to_float;
 
                     size_t id = 0;
@@ -10805,7 +10807,7 @@ static void ggml_compute_forward_mul_mat(
     // src1 rows
     const int64_t nr1 = ne11*ne12*ne13;
 
-    void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
     for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
@@ -10828,7 +10830,15 @@ static void ggml_compute_forward_mul_mat(
         const int64_t i3 = i13;
 
         const char * src0_row = (const char *) src0->data + (  0 + i02*nb02 + i03*nb03     );
-        const char * src1_col = (const char *)      wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
+
+        // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+        //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+        //       the original src1 data pointer, so we should index using the indices directly
+        // TODO: this is a bit of a hack, we should probably have a better way to handle this
+        const char * src1_col = (const char *) wdata +
+            (src1_cont || src1->type != vec_dot_type
+             ? (i11      + i12*ne11 + i13*ne12*ne11)*row_size
+             : (i11*nb11 + i12*nb12 + i13*nb13));
 
         float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
@@ -12982,12 +12992,13 @@ static void ggml_compute_forward_conv_1d(
     };
 }
 
-// ggml_compute_forward_conv_2d_sk_p0
+// ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
+static void ggml_compute_forward_conv_2d_f16_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
+        const struct ggml_tensor * opt0,
               struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13007,28 +13018,37 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
     // size of the convolution row - the kernel size unrolled across all channels
     const int ew0 = nk0*nk1*ne02;
 
+    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
+    const int32_t s1 = ((const int32_t*)(opt0->data))[1];
+    const int32_t p0 = ((const int32_t*)(opt0->data))[2];
+    const int32_t p1 = ((const int32_t*)(opt0->data))[3];
+    const int32_t d0 = ((const int32_t*)(opt0->data))[4];
+    const int32_t d1 = ((const int32_t*)(opt0->data))[5];
+
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
     if (params->type == GGML_TASK_INIT) {
-        // TODO: fix this memset (wsize is overestimated)
         memset(params->wdata, 0, params->wsize);
 
         // prepare source data (src1)
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int i13 = 0; i13 < ne13; i13++) {
-                for (int i12 = 0; i12 < ne12; i12++) {
-                    const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
-                    ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
+            for (int i12 = 0; i12 < ne12; i12++) {
+                const float * const src = (float *)((char *) src1->data + i12*nb12);
+                ggml_fp16_t * dst_data = wdata;
 
-                    for (int i1 = 0; i1 < ne1; i1++) {
-                        for (int i0 = 0; i0 < ne0; i0++) {
-                            for (int ik1 = 0; ik1 < nk1; ik1++) {
-                                for (int ik0 = 0; ik0 < nk0; ik0++) {
+                for (int i1 = 0; i1 < ne1; i1++) {
+                    for (int i0 = 0; i0 < ne0; i0++) {
+                        for (int ik1 = 0; ik1 < nk1; ik1++) {
+                            for (int ik0 = 0; ik0 < nk0; ik0++) {
+                                const int idx0 = i0*s0 + ik0*d0 - p0;
+                                const int idx1 = i1*s1 + ik1*d1 - p1;
+
+                                if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
                                     dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
-                                        GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
+                                        GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
                                 }
                             }
                         }
@@ -13071,19 +13091,21 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
     }
 }
 
-static void ggml_compute_forward_conv_2d_sk_p0(
+static void ggml_compute_forward_conv_2d(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * opt0,
+        struct ggml_tensor * dst
+        ) {
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
+                ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
             } break;
         case GGML_TYPE_F32:
             {
-                //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
+                //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
                 GGML_ASSERT(false);
             } break;
         default:
@@ -13093,32 +13115,6 @@ static void ggml_compute_forward_conv_2d_sk_p0(
     }
 }
 
-// ggml_compute_forward_conv_2d
-
-static void ggml_compute_forward_conv_2d(
-    const struct ggml_compute_params* params,
-    const struct ggml_tensor* src0,
-    const struct ggml_tensor* src1,
-    const struct ggml_tensor* opt0,
-    struct ggml_tensor* dst) {
-    const int32_t s0 = ((const int32_t*)(opt0->data))[0];
-    const int32_t s1 = ((const int32_t*)(opt0->data))[1];
-    const int32_t p0 = ((const int32_t*)(opt0->data))[2];
-    const int32_t p1 = ((const int32_t*)(opt0->data))[3];
-    const int32_t d0 = ((const int32_t*)(opt0->data))[4];
-    const int32_t d1 = ((const int32_t*)(opt0->data))[5];
-    GGML_ASSERT(d0 == 1); // dilation not supported
-    GGML_ASSERT(d1 == 1);
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(p1 == 0);
-
-    if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
-        ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
-    } else {
-        GGML_ASSERT(false); // only stride equal to kernel size is supported
-    }
-}
-
 // ggml_compute_forward_pool_1d_sk_p0
 
 static void ggml_compute_forward_pool_1d_sk_p0(
@@ -16575,19 +16571,22 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     const int64_t ne11 = node->src[1]->ne[1]; // H
                     const int64_t ne12 = node->src[1]->ne[2]; // C
 
+                    const int64_t ne0 = node->ne[0];
+                    const int64_t ne1 = node->ne[1];
+                    const int64_t ne2 = node->ne[2];
                     const int64_t nk = ne00*ne01;
+                    const int64_t ew0 = nk * ne02;
 
-                    UNUSED(ne02);
                     UNUSED(ne03);
-                    UNUSED(nk);
+                    UNUSED(ne2);
 
                     size_t cur = 0;
 
                     if (node->src[0]->type == GGML_TYPE_F16 &&
-                            node->src[1]->type == GGML_TYPE_F32) {
-                        cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+                        node->src[1]->type == GGML_TYPE_F32) {
+                        cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
                     } else if (node->src[0]->type == GGML_TYPE_F32 &&
-                            node->src[1]->type == GGML_TYPE_F32) {
+                               node->src[1]->type == GGML_TYPE_F32) {
                         cur = sizeof(float)*      (ne10*ne11*ne12);
                     } else {
                         GGML_ASSERT(false);