]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fix MMQ for non-contiguous src0, add tests (llama/10021)
authorJohannes Gäßler <redacted>
Thu, 24 Oct 2024 09:09:36 +0000 (11:09 +0200)
committerGeorgi Gerganov <redacted>
Fri, 1 Nov 2024 08:19:05 +0000 (10:19 +0200)
* CUDA: fix MMQ for non-contiguous src0, add tests

* revise test code

ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml.c

index fa280b529bcb8d9546358cb00a024b412ceafad8..4a0329a639624174fdf5ae954e8ae988d867faed 100644 (file)
@@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
-    char * src_ptr = (char *) src->data;
-    char * dst_ptr = (char *) dst;
+    const char * src_ptr = (const char *) src->data;
+    char       * dst_ptr = (char       *) dst;
 
     const int64_t ne0 = src->ne[0];
     const int64_t nb0 = src->nb[0];
@@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     const enum ggml_type type = src->type;
     const int64_t ts = ggml_type_size(type);
     const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
+    const int64_t i1_diff = i1_high - i1_low;
 
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
@@ -1479,13 +1479,17 @@ static void ggml_cuda_op_mul_mat(
         if (src0_is_contiguous) {
             dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
         } else {
-            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
+            // If src0 is not contiguous it will be copied to a temporary buffer, it may then be necessary to clear padding.
+            const size_t nbytes_data    = ggml_nbytes(src0);
+            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
         }
 
-        // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
+        // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
         if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
-            const int64_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
-            const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            const size_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
+            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
             CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
         }
 
index 4935f8818679fde3fa0e7d59812bc5c21078728b..ae5c68ab3512901a902defdb8d019f7c3e4d916e 100644 (file)
@@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(
 
     const int64_t ne00 = src0->ne[0];
 
-    const int64_t nb01 = src0->nb[1];
-
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
@@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t ne0 = dst->ne[0];
 
     const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = nb01 / ggml_type_size(src0->type);
+    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
index 9e428001a36cbec083d656fbbf2a519ee5d2ef19..a4359e7dd05affff77b7578b037d33804925e369 100644 (file)
@@ -3464,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     size_t nbytes;
-    size_t blck_size = ggml_blck_size(tensor->type);
+    const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
         nbytes = ggml_type_size(tensor->type);
         for (int i = 0; i < GGML_MAX_DIMS; ++i) {