]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix logic for clearing padding with -ngl 0 (#13320)
authorJohannes Gäßler <redacted>
Mon, 5 May 2025 20:32:13 +0000 (22:32 +0200)
committerGitHub <redacted>
Mon, 5 May 2025 20:32:13 +0000 (22:32 +0200)
ggml/include/ggml-backend.h
ggml/src/ggml-backend.cpp
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmvq.cu
ggml/src/ggml-cuda/quantize.cu

index 64671495b3802c48434c4804d1f37d027a11c455..ea2c1a402cca102f5c4b3efd9e8a9b0e7593b443 100644 (file)
@@ -38,7 +38,7 @@ extern "C" {
     GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer  (ggml_backend_buffer_type_t buft, size_t size);
     GGML_API size_t                ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
     GGML_API size_t                ggml_backend_buft_get_max_size  (ggml_backend_buffer_type_t buft);
-    GGML_API size_t                ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API size_t                ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
     GGML_API bool                  ggml_backend_buft_is_host       (ggml_backend_buffer_type_t buft);
     GGML_API ggml_backend_dev_t    ggml_backend_buft_get_device    (ggml_backend_buffer_type_t buft);
 
@@ -59,7 +59,7 @@ extern "C" {
     GGML_API enum ggml_status               ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
     GGML_API size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
     GGML_API size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
-    GGML_API size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
     GGML_API void                           ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
     GGML_API bool                           ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
     GGML_API void                           ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
index 273075f4e5455af66b4ee834ba563b71f9a20604..c36b5abfb74224cfdfcd6d3ec9a1c4eb612aee78 100644 (file)
@@ -56,7 +56,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
     return SIZE_MAX;
 }
 
-size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     // get_alloc_size is optional, defaults to ggml_nbytes
     if (buft->iface.get_alloc_size) {
         size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -152,7 +152,7 @@ size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
     return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
 }
 
-size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) {
     return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
 }
 
index 9fb2134f98d3d9b1dc3288b11ec832d196048fcd..0d9ee0a26ddd0067b1bb98375e4c6c8e66e73f82 100644 (file)
@@ -555,8 +555,8 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer
 
     if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
         // initialize padding to 0 to avoid possible NaN values
-        size_t original_size = ggml_nbytes(tensor);
-        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+        const size_t original_size = ggml_nbytes(tensor);
+        const size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
 
         if (padded_size > original_size) {
             ggml_cuda_set_device(ctx->device);
@@ -679,6 +679,7 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
 
     if (ggml_is_quantized(tensor->type)) {
         if (ne0 % MATRIX_ROW_PADDING != 0) {
+            GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor));
             size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
         }
     }
@@ -800,6 +801,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff
 
 static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
     ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
@@ -851,6 +853,7 @@ static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buff
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
     GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
 
@@ -889,6 +892,7 @@ static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buff
     // split tensors must always be set in their entirety at once
     GGML_ASSERT(offset == 0);
     GGML_ASSERT(size == ggml_nbytes(tensor));
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
 
@@ -970,6 +974,7 @@ static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buf
 
 static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
     ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 
     size_t total_size = 0;
 
@@ -2065,6 +2070,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src0_slice.ne[2] = 1;
         src0_slice.nb[3] = src0_slice.nb[2];
         src0_slice.data  = (char *) src0->data + i02*nb02;
+        GGML_ASSERT(!ggml_cuda_should_use_mmq(src0->type, cc, ne11) || ne00 % MATRIX_ROW_PADDING == 0);
 
         ggml_tensor src1_slice;
         memset(&src1_slice, 0, sizeof(src1_slice));
index f397a7e038469648ef8e40e1fddad8bfb97a1378..4ccda88630a85881eb6054ee933d88deb42fd545 100644 (file)
@@ -89,6 +89,16 @@ void ggml_cuda_mul_mat_q(
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
+    // If src0 is a temporary compute buffer, clear any potential padding.
+    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+        const size_t size_data  = ggml_nbytes(src0);
+        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
+        if (size_alloc > size_data) {
+            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
+        }
+    }
+
     const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
     const int64_t s01 = src0->nb[1] / ts_src0;
index 132c466fd1aa6bc2d22311b2b7ee611aa9f8aebd..4bb51d27e434acf012d4985835b7105c9d52e25d 100644 (file)
@@ -513,6 +513,16 @@ void ggml_cuda_mul_mat_vec_q(
     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
     float         *  dst_d =       (float         *)  dst->data;
 
+    // If src0 is a temporary compute buffer, clear any potential padding.
+    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+        const size_t size_data  = ggml_nbytes(src0);
+        const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
+        if (size_alloc > size_data) {
+            CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
+        }
+    }
+
     const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
     ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
     {
index 931a45ad347dcca02717191a2cd37d0e799c73c1..cb93181455d47f57ac9a10232b7a1e60d2c74380 100644 (file)
@@ -163,6 +163,7 @@ void quantize_mmq_q8_1_cuda(
         const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
         const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
         const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    GGML_ASSERT(ne00 % 4 == 0);
     GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
     const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);