]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix repack work size for mul_mat_id (llama/14292)
authorGeorgi Gerganov <redacted>
Fri, 20 Jun 2025 08:19:15 +0000 (11:19 +0300)
committerGeorgi Gerganov <redacted>
Sat, 21 Jun 2025 04:34:17 +0000 (07:34 +0300)
ggml-ci

ggml/src/ggml-cpu/repack.cpp

index 5c6715d5c01ea1b0e075fa913dd8c91678dfbaee..2907192904a726994eea9038e25fa2b308b874b6 100644 (file)
@@ -1163,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
         // not realy a GGML_TYPE_Q8_0 but same size.
         switch (op->op) {
             case GGML_OP_MUL_MAT:
-                size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
-                return true;
+                {
+                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+                    return true;
+                }
             case GGML_OP_MUL_MAT_ID:
-                size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
-                size = GGML_PAD(size, sizeof(int64_t));  // + padding for next bloc.
-                size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
-                return true;
+                {
+                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
+                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+
+                    const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
+                    const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
+
+                    const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
+
+                    size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
+
+                    return true;
+                }
             default:
                 // GGML_ABORT("fatal error");
                 break;
@@ -1305,14 +1316,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             int32_t i2;
         };
 
-        GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
-                                      n_as * ne12 * sizeof(mmid_row_mapping)));
+        GGML_ASSERT(params->wsize >=
+                (GGML_PAD(nbw3, sizeof(int64_t)) +
+                 n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
+                );
 
-        auto * wdata             = (char *)     params->wdata;
-        auto * wdata_src1_end    = (char *)     wdata + GGML_PAD(nbw3, sizeof(int64_t));
-        auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+        auto * wdata          = (char *)params->wdata;
+        auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
 
-        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as);  // [n_as][ne12]
+        // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
+        auto * matrix_row_counts = (int64_t *) (wdata_src1_end);                                        // [n_as]
+        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
 
         // src1: float32 => param type
         for (int64_t i12 = 0; i12 < ne12; ++i12) {