]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
whisper: validate get_rows support for cpu extra buffer (#3323)
authorCharles Xu <redacted>
Mon, 14 Jul 2025 12:13:44 +0000 (14:13 +0200)
committerGitHub <redacted>
Mon, 14 Jul 2025 12:13:44 +0000 (15:13 +0300)
src/whisper.cpp

index 347cc178ee721a5c8cb9f50e60487d25af73a668..5c08478aefdfd7c3e823fd6e4d9ddbbec6fe9a04 100644 (file)
@@ -1438,7 +1438,8 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
         op_supported = true;
     } else {
         switch (op) {
-            // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+            // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
+            case GGML_OP_GET_ROWS:
             case GGML_OP_MUL_MAT: {
                 ggml_init_params params = {
                     /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
@@ -1454,9 +1455,15 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
 
                 ggml_tensor * op_tensor = nullptr;
 
-                int64_t n_ctx = hparams.n_audio_ctx;
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
-                op_tensor = ggml_mul_mat(ctx, w, b);
+                if (op == GGML_OP_MUL_MAT) {
+                    int64_t n_ctx = hparams.n_audio_ctx;
+                    ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+                    op_tensor = ggml_mul_mat(ctx, w, b);
+                } else if (op == GGML_OP_GET_ROWS) {
+                    int64_t num_indices = 8;
+                    ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
+                    op_tensor = ggml_get_rows(ctx, w, indices);
+                }
 
                 // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
                 GGML_ASSERT(w->buffer == nullptr);