]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-hexagon: respect input size when getting/setting tensor data (#16836)
authorl3utterfly <redacted>
Fri, 31 Oct 2025 04:46:31 +0000 (12:46 +0800)
committerGitHub <redacted>
Fri, 31 Oct 2025 04:46:31 +0000 (21:46 -0700)
* respect input size when getting/setting tensor data

allows partial repacking/copying when get tensor size is smaller than the actual tensor

* Removed duplicate repack_mxfp4_mxfp4x4x2 function

ggml/src/ggml-hexagon/ggml-hexagon.cpp

index 2d376a6025c072f6586c5f65ebb54583551df805..945652263d4810d78968c34ca7ac051b13cbd995 100644 (file)
@@ -676,6 +676,15 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
     size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
     size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 
+    // Ensure we don't try to read more data than is available in the source buffer 'data'
+    // or write more than the tensor can hold.
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+    // Calculate how many full rows and how many remaining bytes we need to process.
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
     void * buf_pd = ggml_aligned_malloc(row_size_pd);
     GGML_ASSERT(buf_pd != NULL);
 
@@ -687,7 +696,8 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
 
     init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
 
-    for (int64_t i = 0; i < nrows; i++) {
+    // 1. Process all the full rows
+    for (int64_t i = 0; i < n_full_rows; i++) {
         const uint8_t * src = (const uint8_t *) data + (i * row_size);
         uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
 
@@ -696,6 +706,25 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size)
         memcpy(dst, buf_rp, row_size);
     }
 
+    // 2. Process the final, potentially partial, row
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
+
+        // re-init the row because we are potentially copying a partial row
+        init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);
+
+        // Copy only the remaining bytes from the source.
+        memcpy(buf_pd, src, n_rem_bytes);
+
+        // Repack the entire buffer
+        repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
+
+        // Write only the corresponding remaining bytes to the destination tensor.
+        memcpy(dst, buf_rp, n_rem_bytes);
+    }
+
     ggml_aligned_free(buf_pd, row_size_pd);
     ggml_aligned_free(buf_rp, row_size_rp);
 }
@@ -708,6 +737,14 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
     size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2));  // extra elements for the pad
     size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 
+    // Ensure we don't try to copy more data than the tensor actually contains.
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+    // Calculate how many full rows and how many remaining bytes we need to process.
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
     void * buf_pd = ggml_aligned_malloc(row_size_pd);
     GGML_ASSERT(buf_pd != NULL);
 
@@ -719,7 +756,8 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
 
     memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
 
-    for (int64_t i = 0; i < nrows; i++) {
+    // 1. Process all the full rows
+    for (int64_t i = 0; i < n_full_rows; i++) {
         const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
         uint8_t *       dst = (uint8_t *) data + (i * row_size);
 
@@ -728,6 +766,20 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size)
         memcpy(dst, buf_rp, row_size);
     }
 
+    // 2. Process the final, potentially partial, row
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) data + (i * row_size);
+
+        // We still need to read and unpack the entire source row because quantization is block-based.
+        memcpy(buf_pd, src, row_size);
+        unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+
+        // But we only copy the remaining number of bytes to the destination.
+        memcpy(dst, buf_rp, n_rem_bytes);
+    }
+
     ggml_aligned_free(buf_pd, row_size_pd);
     ggml_aligned_free(buf_rp, row_size_rp);
 }
@@ -950,6 +1002,15 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
     size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad
     size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 
+    // Ensure we don't try to read more data than is available in the source buffer 'data'
+    // or write more than the tensor can hold.
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+    // Calculate how many full rows and how many remaining bytes we need to process.
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
     void * buf_pd = ggml_aligned_malloc(row_size_pd);
     GGML_ASSERT(buf_pd != NULL);
 
@@ -961,7 +1022,8 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
 
     init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
 
-    for (int64_t i = 0; i < nrows; i++) {
+    // 1. Process all the full rows
+    for (int64_t i = 0; i < n_full_rows; i++) {
         const uint8_t * src = (const uint8_t *) data + (i * row_size);
         uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
 
@@ -970,6 +1032,25 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size)
         memcpy(dst, buf_rp, row_size);
     }
 
+    // 2. Process the final, potentially partial, row
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
+
+        // re-init the row because we are potentially copying a partial row
+        init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);
+
+        // Copy only the remaining bytes from the source.
+        memcpy(buf_pd, src, n_rem_bytes);
+
+        // Repack the entire buffer
+        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
+
+        // Write only the corresponding remaining bytes to the destination tensor.
+        memcpy(dst, buf_rp, n_rem_bytes);
+    }
+
     ggml_aligned_free(buf_pd, row_size_pd);
     ggml_aligned_free(buf_rp, row_size_rp);
 }
@@ -982,6 +1063,14 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
     size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2));  // extra elements for the pad
     size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 
+    // Ensure we don't try to copy more data than the tensor actually contains.
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+    // Calculate how many full rows and how many remaining bytes we need to process.
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
     void * buf_pd = ggml_aligned_malloc(row_size_pd);
     GGML_ASSERT(buf_pd != NULL);
 
@@ -993,7 +1082,8 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
 
     memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
 
-    for (int64_t i = 0; i < nrows; i++) {
+    // 1. Process all the full rows
+    for (int64_t i = 0; i < n_full_rows; i++) {
         const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
         uint8_t *       dst = (uint8_t *) data + (i * row_size);
 
@@ -1002,6 +1092,20 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size)
         memcpy(dst, buf_rp, row_size);
     }
 
+    // 2. Process the final, potentially partial, row
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) data + (i * row_size);
+
+        // We still need to read and unpack the entire source row because quantization is block-based.
+        memcpy(buf_pd, src, row_size);
+        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+
+        // But we only copy the remaining number of bytes to the destination.
+        memcpy(dst, buf_rp, n_rem_bytes);
+    }
+
     ggml_aligned_free(buf_pd, row_size_pd);
     ggml_aligned_free(buf_rp, row_size_rp);
 }
@@ -1249,6 +1353,15 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
     size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad
     size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 
+    // Ensure we don't try to read more data than is available in the source buffer 'data'
+    // or write more than the tensor can hold.
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+    // Calculate how many full rows and how many remaining bytes we need to process.
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
     void * buf_pd = ggml_aligned_malloc(row_size_pd);
     GGML_ASSERT(buf_pd != NULL);
 
@@ -1260,7 +1373,8 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
 
     init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);  // init padded buffer to make sure the tail is all zeros
 
-    for (int64_t i = 0; i < nrows; i++) {
+    // 1. Process all the full rows
+    for (int64_t i = 0; i < n_full_rows; i++) {
         const uint8_t * src = (const uint8_t *) data + (i * row_size);
         uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
 
@@ -1269,6 +1383,25 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si
         memcpy(dst, buf_rp, row_size);
     }
 
+    // 2. Process the final, potentially partial, row
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) t->data + (i * row_size);
+
+        // re-init the row because we are potentially copying a partial row
+        init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);
+
+        // Copy only the remaining bytes from the source.
+        memcpy(buf_pd, src, n_rem_bytes);
+
+        // Repack the entire buffer (partial data + zero padding).
+        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
+
+        // Write only the corresponding remaining bytes to the destination tensor.
+        memcpy(dst, buf_rp, n_rem_bytes);
+    }
+
     ggml_aligned_free(buf_pd, row_size_pd);
     ggml_aligned_free(buf_rp, row_size_rp);
 }
@@ -1281,6 +1414,14 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
     size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2));  // extra elements for the pad
     size_t row_size_rp = row_size * 2;  // extra space for tmp pad (if any)
 
+    // Ensure we don't try to copy more data than the tensor actually contains.
+    const size_t total_tensor_size = (size_t)nrows * row_size;
+    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
+
+    // Calculate how many full rows and how many remaining bytes we need to process.
+    const int64_t n_full_rows = n_bytes_to_copy / row_size;
+    const size_t  n_rem_bytes = n_bytes_to_copy % row_size;
+
     void * buf_pd = ggml_aligned_malloc(row_size_pd);
     GGML_ASSERT(buf_pd != NULL);
 
@@ -1292,7 +1433,8 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
 
     memset(buf_pd, 0, row_size_pd);  // clear-out padded buffer to make sure the tail is all zeros
 
-    for (int64_t i = 0; i < nrows; i++) {
+    // 1. Process all the full rows
+    for (int64_t i = 0; i < n_full_rows; i++) {
         const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
         uint8_t *       dst = (uint8_t *) data + (i * row_size);
 
@@ -1301,6 +1443,20 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si
         memcpy(dst, buf_rp, row_size);
     }
 
+    // 2. Process the final, potentially partial, row
+    if (n_rem_bytes > 0) {
+        const int64_t i = n_full_rows;
+        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
+        uint8_t *       dst = (uint8_t *) data + (i * row_size);
+
+        // We still need to read and unpack the entire source row because the format is block-based.
+        memcpy(buf_pd, src, row_size);
+        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
+
+        // But we only copy the remaining number of bytes to the destination to respect the size limit.
+        memcpy(dst, buf_rp, n_rem_bytes);
+    }
+
     ggml_aligned_free(buf_pd, row_size_pd);
     ggml_aligned_free(buf_rp, row_size_rp);
 }
@@ -1319,19 +1475,19 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
     switch (tensor->type) {
         case GGML_TYPE_Q4_0:
             GGML_ASSERT(offset == 0);
-            GGML_ASSERT(size == ggml_nbytes(tensor));
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_q4_0_q4x4x2(tensor, data, size);
             break;
 
         case GGML_TYPE_Q8_0:
             GGML_ASSERT(offset == 0);
-            GGML_ASSERT(size == ggml_nbytes(tensor));
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_q8_0_q8x4x2(tensor, data, size);
             break;
 
         case GGML_TYPE_MXFP4:
             GGML_ASSERT(offset == 0);
-            GGML_ASSERT(size == ggml_nbytes(tensor));
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_mxfp4_mxfp4x4x2(tensor, data, size);
             break;
 
@@ -1355,19 +1511,19 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
     switch (tensor->type) {
         case GGML_TYPE_Q4_0:
             GGML_ASSERT(offset == 0);
-            GGML_ASSERT(size == ggml_nbytes(tensor));
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_q4x4x2_q4_0(data, tensor, size);
             break;
 
         case GGML_TYPE_Q8_0:
             GGML_ASSERT(offset == 0);
-            GGML_ASSERT(size == ggml_nbytes(tensor));
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_q8x4x2_q8_0(data, tensor, size);
             break;
 
         case GGML_TYPE_MXFP4:
             GGML_ASSERT(offset == 0);
-            GGML_ASSERT(size == ggml_nbytes(tensor));
+            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
             repack_mxfp4x4x2_mxfp4(data, tensor, size);
             break;