]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : add support for Qwen2-Audio and SeaLLM-Audio (#13760)
authorXuan-Son Nguyen <redacted>
Sun, 25 May 2025 12:06:32 +0000 (14:06 +0200)
committerGitHub <redacted>
Sun, 25 May 2025 12:06:32 +0000 (14:06 +0200)
* mtmd : add Qwen2-Audio support

* small clean up

* update discussion link

* clarify mtmd_get_output_embd

* clarification in multimodal.md

* fix ultravox bug

* ggml_cont

convert_hf_to_gguf.py
docs/multimodal.md
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp
tools/mtmd/clip.h
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h

index 123083b915412a7d067fc5306bb548a4369df72f..91af508a2fb28df75aee78354e70d91a28be8dae 100755 (executable)
@@ -2643,7 +2643,7 @@ class QwenModel(TextModel):
         self.gguf_writer.add_file_type(self.ftype)
 
 
-@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration")
 class Qwen2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
@@ -2667,8 +2667,9 @@ class Qwen2Model(TextModel):
             name = f"model.{name}"  # map to Qwen2ForCausalLM tensors
         if "language_model." in name:
             name = name.replace("language_model.", "") # for InternVL
-        if name.startswith("mlp") or name.startswith("vision_model"):
-            # skip visual tensors
+        if name.startswith("mlp") or name.startswith("multi_modal_projector") \
+                or name.startswith("vision_model") or name.startswith("audio_tower"):
+            # skip vision and audio tensors
             return []
         yield from super().modify_tensors(data_torch, name, bid)
 
@@ -5993,11 +5994,11 @@ class UltravoxModel(TextModel):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        raise NotImplementedError("Ultravox does not have text decoder. Please use --mmproj argument")
+        raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")
 
 
-@ModelBase.register("UltravoxModel")
-class UltravoxAudioModel(MmprojModel):
+@ModelBase.register("Qwen2AudioForConditionalGeneration")
+class WhisperEncoderModel(MmprojModel):
     has_vision_encoder = False # no vision encoder
     has_audio_encoder = True
 
@@ -6009,10 +6010,9 @@ class UltravoxAudioModel(MmprojModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A)
         self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
         self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
-        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         del bid, new_name, n_dims  # unused
@@ -6023,6 +6023,10 @@ class UltravoxAudioModel(MmprojModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
+        if name.startswith("language_model."):
+            # skip language model tensors
+            return []
+
         # prevent clash naming with vision tensors
         if name.startswith("multi_modal_projector"):
             name = "audio." + name
@@ -6033,6 +6037,16 @@ class UltravoxAudioModel(MmprojModel):
 
         return [(self.map_tensor_name(name), data_torch)]
 
+
+@ModelBase.register("UltravoxModel")
+class UltravoxWhisperEncoderModel(WhisperEncoderModel):
+    has_vision_encoder = False # no vision encoder
+    has_audio_encoder = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
+
 ###### CONVERSION LOGIC ######
 
 
index 2f3c416b65e7d078d41c3c199dad6046f59f3cfb..3a0994a279ae87137698bb6251f12956368bef46 100644 (file)
@@ -93,4 +93,8 @@ NOTE: some models may require large context window, for example: `-c 8192`
 # Ultravox 0.5
 (tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
 (tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
+
+# Qwen2-Audio and SeaLLM-Audio
+# note: no pre-quantized GGUF this model, as they have very poor result
+# ref: https://github.com/ggml-org/llama.cpp/pull/13760
 ```
index 58de45dfddb85b5b9f6c034da613f3fa3bee61f8..c6255d6867a1505a147722e240a1e76884d264ac 100644 (file)
@@ -546,6 +546,7 @@ class MODEL_TENSOR(IntEnum):
     A_ENC_FFN_GATE       = auto()
     A_ENC_FFN_DOWN       = auto()
     A_MMPROJ             = auto()
+    A_MMPROJ_FC          = auto()
     A_MM_NORM_PRE        = auto()
     A_MM_NORM_MID        = auto()
 
@@ -825,6 +826,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.A_ENC_FFN_GATE:            "a.blk.{bid}.ffn_gate",
     MODEL_TENSOR.A_ENC_FFN_DOWN:            "a.blk.{bid}.ffn_down",
     MODEL_TENSOR.A_MMPROJ:                  "mm.a.mlp.{bid}",
+    MODEL_TENSOR.A_MMPROJ_FC:               "mm.a.fc",
     MODEL_TENSOR.A_MM_NORM_PRE:             "mm.a.norm_pre",
     MODEL_TENSOR.A_MM_NORM_MID:             "mm.a.norm_mid",
 }
@@ -885,6 +887,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.A_ENC_FFN_GATE,
         MODEL_TENSOR.A_ENC_FFN_DOWN,
         MODEL_TENSOR.A_MMPROJ,
+        MODEL_TENSOR.A_MMPROJ_FC,
         MODEL_TENSOR.A_MM_NORM_PRE,
         MODEL_TENSOR.A_MM_NORM_MID,
     ],
@@ -2256,6 +2259,7 @@ class VisionProjectorType:
     QWEN25VL = "qwen2.5vl_merger"
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
+    QWEN2A = "qwen2a" # audio
 
 
 # Items here are (block size, type size)
index 91a95ea48b42d3e071c969dc9c21d4a206072232..4a0615b656812304645cb940135f15da6949dad1 100644 (file)
@@ -1165,6 +1165,10 @@ class TensorNameMap:
             "audio.multi_modal_projector.linear_{bid}", # ultravox
         ),
 
+        MODEL_TENSOR.A_MMPROJ_FC: (
+            "audio.multi_modal_projector.linear", # qwen2audio
+        ),
+
         MODEL_TENSOR.A_MM_NORM_PRE: (
             "audio.multi_modal_projector.ln_pre", # ultravox
         ),
index 15ec3db906477cdab5b4645e690d588c518a2cc7..27ce8c43f678ccdea992e4292712792f6c2ef3f3 100644 (file)
 // ultravox
 #define TN_CONV1D       "a.conv1d.%d.%s"
 #define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
+#define TN_MM_AUDIO_FC  "mm.a.fc.%s" // fully connected layer
 #define TN_MM_NORM_PRE  "mm.a.norm_pre.%s"
 #define TN_MM_NORM_MID  "mm.a.norm_mid.%s"
 
@@ -128,6 +129,7 @@ enum projector_type {
     PROJECTOR_TYPE_ULTRAVOX,
     PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_LLAMA4,
+    PROJECTOR_TYPE_QWEN2A,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -145,6 +147,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_ULTRAVOX,  "ultravox"},
     { PROJECTOR_TYPE_INTERNVL,  "internvl"},
     { PROJECTOR_TYPE_LLAMA4,    "llama4"},
+    { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
index 34c4ce84275fa0bc4e24de59794ed55448012f5d..6205dad5ae262b5435fc0df569b1574b9b4a6b77 100644 (file)
@@ -254,7 +254,9 @@ struct clip_vision_model {
     ggml_tensor * post_ln_w;
     ggml_tensor * post_ln_b;
 
-    ggml_tensor * projection;
+    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
+    ggml_tensor * mm_fc_w;
+    ggml_tensor * mm_fc_b;
 
     // LLaVA projection
     ggml_tensor * mm_input_norm_w = nullptr;
@@ -1471,48 +1473,58 @@ struct clip_graph {
 
         cb(cur, "after_transformer", -1);
 
-        // StackAudioFrames
-        // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
-        {
-            int64_t stride = n_embd * hparams.proj_stack_factor;
-            int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
-            int64_t pad = padded_len - ggml_nelements(cur);
-            if (pad > 0) {
-                cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
-                cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+        if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
+            // StackAudioFrames
+            // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
+            {
+                int64_t stride = n_embd * hparams.proj_stack_factor;
+                int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
+                int64_t pad = padded_len - ggml_nelements(cur);
+                if (pad > 0) {
+                    cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
+                    cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+                }
+                cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
+                                    ggml_row_size(cur->type, stride), 0);
             }
-            cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
-                                ggml_row_size(cur->type, stride), 0);
-        }
 
-        cb(cur, "after_stacked", -1);
+            cb(cur, "after_stacked", -1);
 
-        // UltravoxProjector
-        {
-            // pre-norm
-            cur = ggml_rms_norm(ctx0, cur, 1e-6);
-            cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
+            // UltravoxProjector
+            {
+                // pre-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
 
-            // ffn in
-            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+                // ffn in
+                cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
 
-            // swiglu
-            {
-                int64_t split_point = cur->ne[0] / 2;
-                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
-                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+                // swiglu
+                {
+                    int64_t split_point = cur->ne[0] / 2;
+                    ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                    ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                    // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
+                    x1 = ggml_silu(ctx0, x1);
+                    cur = ggml_mul(ctx0, x0, x1);
+                }
 
-                // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
-                x1 = ggml_silu(ctx0, x1);
-                cur = ggml_mul(ctx0, x0, x1);
+                // mid-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
+
+                // ffn out
+                cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
             }
 
-            // mid-norm
-            cur = ggml_rms_norm(ctx0, cur, 1e-6);
-            cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
+        } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+            // projector
+            cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_fc_b);
 
-            // ffn out
-            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+        } else {
+            GGML_ABORT("%s: unknown projector type", __func__);
         }
 
         cb(cur, "projected", -1);
@@ -1655,6 +1667,17 @@ private:
             inpL = cur;
         }
 
+        // TODO @ngxson : find a way to move this outside
+        if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+            ggml_tensor * cur = inpL;
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            inpL = cur;
+        }
+
         // post-layernorm
         if (model.post_ln_w) {
             inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
@@ -1952,6 +1975,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
                 res = graph.build_llama4();
             } break;
         case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_QWEN2A:
             {
                 res = graph.build_whisper_enc();
             } break;
@@ -2186,8 +2210,10 @@ struct clip_model_loader {
                         };
                     } break;
                 case PROJECTOR_TYPE_ULTRAVOX:
+                case PROJECTOR_TYPE_QWEN2A:
                     {
-                        get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor);
+                        bool require_stack = ctx_clip.proj_type == PROJECTOR_TYPE_ULTRAVOX;
+                        get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
                         if (hparams.n_mel_bins != 128) {
                             throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
                         }
@@ -2266,7 +2292,7 @@ struct clip_model_loader {
             return cur;
         };
 
-        auto & vision_model = ctx_clip.vision_model;
+        auto & vision_model = ctx_clip.vision_model; // TODO: rename this to just "model"
 
         vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
 
@@ -2463,6 +2489,15 @@ struct clip_model_loader {
                     vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
                     vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
                 } break;
+            case PROJECTOR_TYPE_QWEN2A:
+                {
+                    vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    vision_model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
+                    vision_model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
+                } break;
             case PROJECTOR_TYPE_INTERNVL:
                 {
                     vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -3450,6 +3485,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
         const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
         n_patches = n_len / proj_stack_factor / 2;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
+        // divide by 2 because of whisper
+        // another divide by 2 because of nn.AvgPool1d(2, stride=2)
+        n_patches = img->nx / 4;
     }
 
     return n_patches;
@@ -3850,6 +3889,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_QWEN2A:
         case PROJECTOR_TYPE_ULTRAVOX:
             {
                 // do nothing
@@ -3910,7 +3950,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int n_tokens_out = embeddings->ne[1];
     const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
     if (n_tokens_out != expected_n_tokens_out) {
-        LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+        LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
         GGML_ABORT("Invalid number of output tokens");
     }
 
@@ -3955,6 +3995,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_3_w->ne[1];
         case PROJECTOR_TYPE_LLAMA4:
             return ctx->vision_model.mm_model_proj->ne[1];
+        case PROJECTOR_TYPE_QWEN2A:
+            return ctx->vision_model.mm_fc_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
@@ -3991,6 +4033,10 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
     return ctx->vision_model.hparams.has_audio;
 }
 
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
+    return ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type == PROJECTOR_TYPE_QWEN2A;
+}
+
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
     clip_image_f32 clip_img;
     clip_img.buf.resize(h * w * 3);
index 94267131628a683eb3b29b17a18fe053c37d164e..5abfcd1a3c418dcefaa68279f0c72741b556d2b6 100644 (file)
@@ -4,6 +4,8 @@
 #include <stddef.h>
 #include <stdint.h>
 
+// !!! Internal header, to be used by mtmd only !!!
+
 struct clip_ctx;
 
 struct clip_image_size {
@@ -99,3 +101,4 @@ void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel
 
 bool clip_has_vision_encoder(const struct clip_ctx * ctx);
 bool clip_has_audio_encoder(const struct clip_ctx * ctx);
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
index d3f3cf3a061de373653a008fb1090e70ae01cd3d..c3be91265f331619304fac53f066e697f518c080 100644 (file)
@@ -146,6 +146,13 @@ struct mtmd_context {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }
 
+        if (llama_model_n_embd(text_model) != clip_n_mmproj_embd(ctx_clip)) {
+            throw std::runtime_error(string_format(
+                "mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
+                "hint: you may be using wrong mmproj\n",
+                llama_model_n_embd(text_model), clip_n_mmproj_embd(ctx_clip)));
+        }
+
         has_vision = clip_has_vision_encoder(ctx_clip);
         has_audio  = clip_has_audio_encoder(ctx_clip);
         use_mrope  = clip_is_qwen2vl(ctx_clip);
@@ -196,7 +203,7 @@ struct mtmd_context {
             ov_img_first      = false; // overview image is last
         }
 
-        if (proj == PROJECTOR_TYPE_ULTRAVOX) {
+        if (clip_has_whisper_encoder(ctx_clip)) {
             // TODO @ngxson : check if model n_mel is 128 or 80
             w_filters = whisper_precalc_filters::get_128_bins();
         }
@@ -208,7 +215,7 @@ struct mtmd_context {
         }
         if (has_audio) {
             LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
-                    "    https://github.com/ggml-org/llama.cpp/pull/13623\n", __func__);
+                    "    https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
         }
     }
 
@@ -327,6 +334,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         marker_modified = "<img>" + ctx->media_marker + "</img>";
         string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
 
+    } else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
+        // <|audio_bos|> ... (embeddings) ... <|audio_eos|>
+        marker_modified = "<|audio_bos|>" + ctx->media_marker + "<|audio_eos|>";
+        string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
+
     }
 
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
index 2c722b012ea053d2daff4d9fd65e23e0e04ad375..b53f215a2fafd450536169716617000f3a97cad6 100644 (file)
@@ -203,6 +203,8 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
                                    const mtmd_input_chunk * chunk);
 
 // get output embeddings from the last encode pass
+// the reading size (in bytes) is equal to:
+// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
 /////////////////////////////////////////