]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : support audio input (#13714)
authorXuan-Son Nguyen <redacted>
Fri, 23 May 2025 09:03:47 +0000 (11:03 +0200)
committerGitHub <redacted>
Fri, 23 May 2025 09:03:47 +0000 (11:03 +0200)
* server : support audio input

* add audio support on webui

12 files changed:
tools/mtmd/mtmd-helper.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h
tools/server/public/index.html.gz
tools/server/server.cpp
tools/server/tests/unit/test_vision_api.py
tools/server/utils.hpp
tools/server/webui/src/components/ChatInputExtraContextItem.tsx
tools/server/webui/src/components/ChatScreen.tsx
tools/server/webui/src/components/useChatExtraContext.tsx
tools/server/webui/src/utils/misc.ts
tools/server/webui/src/utils/types.ts

index 5254b2821e504546d9533c461bd4b5be5f1737d4..b79094c0a48b61a95fea8f633261300df213e22d 100644 (file)
@@ -12,17 +12,7 @@ size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
     size_t n_tokens = 0;
     for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
         auto chunk = mtmd_input_chunks_get(chunks, i);
-        auto chunk_type = mtmd_input_chunk_get_type(chunk);
-        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            size_t n_tokens_text;
-            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
-            n_tokens += n_tokens_text;
-        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
-            n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
-        } else {
-            GGML_ASSERT(false && "chunk type not supported");
-        }
+        n_tokens += mtmd_input_chunk_get_n_tokens(chunk);
     }
     return n_tokens;
 }
@@ -31,17 +21,7 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
     llama_pos n_pos = 0;
     for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
         auto chunk = mtmd_input_chunks_get(chunks, i);
-        auto chunk_type = mtmd_input_chunk_get_type(chunk);
-        if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            size_t n_tokens_text;
-            mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
-            n_pos += n_tokens_text;
-        } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
-            auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
-            n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
-        } else {
-            GGML_ASSERT(false && "chunk type not supported");
-        }
+        n_pos += mtmd_input_chunk_get_n_pos(chunk);
     }
     return n_pos;
 }
index 344fe0b07dcf7c9155e8d508f31a0bb22f37ed60..d3f3cf3a061de373653a008fb1090e70ae01cd3d 100644 (file)
@@ -751,6 +751,10 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
     return bitmap->data.data();
 }
 
+size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) {
+    return bitmap->data.size();
+}
+
 bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
     return bitmap->is_audio;
 }
index 0f4f9c62b7e9716f308d44e63dc54ec6039ce332..2c722b012ea053d2daff4d9fd65e23e0e04ad375 100644 (file)
@@ -119,11 +119,12 @@ MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
 //     the data is in float format (PCM F32)
 MTMD_API mtmd_bitmap *         mtmd_bitmap_init           (uint32_t nx, uint32_t ny, const unsigned char * data);
 MTMD_API mtmd_bitmap *         mtmd_bitmap_init_from_audio(size_t n_samples,         const float         * data);
-MTMD_API uint32_t              mtmd_bitmap_get_nx  (const mtmd_bitmap * bitmap);
-MTMD_API uint32_t              mtmd_bitmap_get_ny  (const mtmd_bitmap * bitmap);
-MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
-MTMD_API bool                  mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap);
-MTMD_API void                  mtmd_bitmap_free    (mtmd_bitmap * bitmap);
+MTMD_API uint32_t              mtmd_bitmap_get_nx     (const mtmd_bitmap * bitmap);
+MTMD_API uint32_t              mtmd_bitmap_get_ny     (const mtmd_bitmap * bitmap);
+MTMD_API const unsigned char * mtmd_bitmap_get_data   (const mtmd_bitmap * bitmap);
+MTMD_API size_t                mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
+MTMD_API bool                  mtmd_bitmap_is_audio   (const mtmd_bitmap * bitmap);
+MTMD_API void                  mtmd_bitmap_free       (mtmd_bitmap * bitmap);
 // bitmap ID is optional, but useful for KV cache tracking
 // these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
 MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
@@ -322,6 +323,7 @@ struct bitmap {
     uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
     uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
     const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
+    size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
     std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
     void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
 };
index 02fb00339ec8d2c81e56f396549dc7455f42d734..3f1d3f31dcbf924b0e18717ac52a14d9d108d0ed 100644 (file)
Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ
index 1a08e30d28751df7f66d4916613e6fc3bb8d329e..01afeafa0ff57bd3615ebd984d38fa92ec09b31d 100644 (file)
@@ -1891,6 +1891,7 @@ struct server_context {
     float slot_prompt_similarity = 0.0f;
 
     common_chat_templates_ptr chat_templates;
+    oaicompat_parser_options  oai_parser_opt;
 
     ~server_context() {
         mtmd_free(mctx);
@@ -2086,6 +2087,15 @@ struct server_context {
         }
 
         metrics.init();
+
+        oai_parser_opt = {
+            /* use_jinja             */ params_base.use_jinja,
+            /* prefill_assistant     */ params_base.prefill_assistant,
+            /* reasoning_format      */ params_base.reasoning_format,
+            /* common_chat_templates */ chat_templates.get(),
+            /* allow_image           */ mctx ? mtmd_support_vision(mctx) : false,
+            /* allow_audio           */ mctx ? mtmd_support_audio (mctx) : false,
+        };
     }
 
     server_slot * get_slot_by_id(int id) {
@@ -4092,7 +4102,10 @@ int main(int argc, char ** argv) {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model.path },
-            { "modalities",                  json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
+            { "modalities",                  json{
+                {"vision", ctx_server.oai_parser_opt.allow_image},
+                {"audio",  ctx_server.oai_parser_opt.allow_audio},
+            } },
             { "chat_template",               common_chat_templates_source(ctx_server.chat_templates.get()) },
             { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
             { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
@@ -4183,10 +4196,10 @@ int main(int argc, char ** argv) {
                 for (auto & file : files) {
                     mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
                     if (!bmp.ptr) {
-                        throw std::runtime_error("Failed to load image");
+                        throw std::runtime_error("Failed to load image or audio file");
                     }
                     // calculate bitmap hash (for KV caching)
-                    std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
+                    std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
                     bmp.set_id(hash.c_str());
                     bitmaps.entries.push_back(std::move(bmp));
                 }
@@ -4418,7 +4431,7 @@ int main(int argc, char ** argv) {
             OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
         LOG_DBG("request: %s\n", req.body.c_str());
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
@@ -4427,13 +4440,9 @@ int main(int argc, char ** argv) {
 
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files;
-        json data = oaicompat_completion_params_parse(
+        json data = oaicompat_chat_params_parse(
             body,
-            params.use_jinja,
-            params.prefill_assistant,
-            params.reasoning_format,
-            ctx_server.chat_templates.get(),
-            ctx_server.mctx,
+            ctx_server.oai_parser_opt,
             files);
 
         handle_completions_impl(
@@ -4446,16 +4455,12 @@ int main(int argc, char ** argv) {
     };
 
     // same with handle_chat_completions, but without inference part
-    const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
         std::vector<raw_buffer> files; // dummy, unused
-        json data = oaicompat_completion_params_parse(
+        json data = oaicompat_chat_params_parse(
             body,
-            params.use_jinja,
-            params.prefill_assistant,
-            params.reasoning_format,
-            ctx_server.chat_templates.get(),
-            ctx_server.mctx,
+            ctx_server.oai_parser_opt,
             files);
         res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
     };
index 7cc4096f19e0ce5b6cf8b5cbf01b6154469378cb..fc63caa1342939578dadcc5aa1100d41c6754df5 100644 (file)
@@ -30,6 +30,7 @@ def create_server():
         ("What is this:\n", "malformed",              False, None),
         ("What is this:\n", "https://google.com/404", False, None), # non-existent image
         ("What is this:\n", "https://ggml.ai",        False, None), # non-image data
+        # TODO @ngxson : test with multiple images, no images and with audio
     ]
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
index 9c41f9db5ff68881959f67c62edee73343f201cb..bb27b366ea2d673b4833ea958f33db13824c3bc1 100644 (file)
@@ -536,6 +536,7 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 // OAI utils
 //
 
+// used by /completions endpoint
 static json oaicompat_completion_params_parse(const json & body) {
     json llama_params;
 
@@ -580,13 +581,19 @@ static json oaicompat_completion_params_parse(const json & body) {
     return llama_params;
 }
 
-static json oaicompat_completion_params_parse(
+struct oaicompat_parser_options {
+    bool use_jinja;
+    bool prefill_assistant;
+    common_reasoning_format reasoning_format;
+    common_chat_templates * tmpls;
+    bool allow_image;
+    bool allow_audio;
+};
+
+// used by /chat/completions endpoint
+static json oaicompat_chat_params_parse(
     const json & body, /* openai api json semantics */
-    bool use_jinja,
-    bool prefill_assistant,
-    common_reasoning_format reasoning_format,
-    const struct common_chat_templates * tmpls,
-    bool allow_non_text,
+    const oaicompat_parser_options & opt,
     std::vector<raw_buffer> & out_files)
 {
     json llama_params;
@@ -598,11 +605,11 @@ static json oaicompat_completion_params_parse(
         if (stream) {
             throw std::runtime_error("Cannot use tools with stream");
         }
-        if (!use_jinja) {
+        if (!opt.use_jinja) {
             throw std::runtime_error("tools param requires --jinja flag");
         }
     }
-    if (!use_jinja) {
+    if (!opt.use_jinja) {
         if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
             throw std::runtime_error("Unsupported param: tool_choice");
         }
@@ -667,12 +674,12 @@ static json oaicompat_completion_params_parse(
 
         for (auto & p : content) {
             std::string type      = json_value(p, "type", std::string());
-            json        image_url = json_value(p, "image_url", json::object());
             if (type == "image_url") {
-                if (!allow_non_text) {
-                    throw std::runtime_error("image input is not supported by this server");
+                if (!opt.allow_image) {
+                    throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
                 }
 
+                json image_url  = json_value(p, "image_url", json::object());
                 std::string url = json_value(image_url, "url", std::string());
                 if (string_starts_with(url, "http")) {
                     // download remote image
@@ -712,6 +719,29 @@ static json oaicompat_completion_params_parse(
                 p["type"] = "text";
                 p["text"] = mtmd_default_marker();
                 p.erase("image_url");
+
+            } else if (type == "input_audio") {
+                if (!opt.allow_audio) {
+                    throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
+                }
+
+                json input_audio   = json_value(p, "input_audio", json::object());
+                std::string data   = json_value(input_audio, "data", std::string());
+                std::string format = json_value(input_audio, "format", std::string());
+                // while we also support flac, we don't allow it here so we matches the OAI spec
+                if (format != "wav" && format != "mp3") {
+                    throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
+                }
+                auto decoded_data = base64_decode(data); // expected to be base64 encoded
+                out_files.push_back(decoded_data);
+
+                // replace this chunk with a marker
+                p["type"] = "text";
+                p["text"] = mtmd_default_marker();
+                p.erase("input_audio");
+
+            } else if (type != "text") {
+                throw std::runtime_error("unsupported content[].type");
             }
         }
     }
@@ -723,9 +753,9 @@ static json oaicompat_completion_params_parse(
     inputs.json_schema           = json_schema.is_null() ? "" : json_schema.dump();
     inputs.grammar               = grammar;
     inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
-    inputs.use_jinja             = use_jinja;
+    inputs.use_jinja             = opt.use_jinja;
     inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
-    inputs.extract_reasoning     = reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    inputs.extract_reasoning     = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE;
     inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
         throw std::runtime_error("Cannot use custom grammar constraints with tools.");
@@ -733,7 +763,7 @@ static json oaicompat_completion_params_parse(
 
     // if the assistant message appears at the end of list, we do not add end-of-turn token
     // for ex. this can be useful to modify the reasoning process in reasoning models
-    bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant;
+    bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
     common_chat_msg last_message;
     if (prefill_assistant_message) {
         last_message = inputs.messages.back();
@@ -749,7 +779,7 @@ static json oaicompat_completion_params_parse(
     }
 
     // Apply chat template to the list of messages
-    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+    auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
 
     /* Append assistant prefilled message */
     if (prefill_assistant_message) {
@@ -1040,7 +1070,7 @@ struct server_tokens {
 private: // disallow accessing these members directly, risking out-of-sync
 
     // map a **start** position in tokens to the image chunk
-    std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_image;
+    std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
 
     // list of tokens
     // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
@@ -1051,7 +1081,7 @@ private: // disallow accessing these members directly, risking out-of-sync
     // for ex. with input of 5 text tokens and 2 images:
     //      [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
     // pos  0   1   2   3   4   5      6      7      8      9
-    // map_pos_to_image will contain: {5, img0}, {8, img1}
+    // map_pos_to_media will contain: {5, img0}, {8, img1}
 
 public:
     server_tokens() = default;
@@ -1090,15 +1120,15 @@ public:
         }
         oss << "\n";
         oss << "image pos: ";
-        for (const auto & it : map_pos_to_image) {
+        for (const auto & it : map_pos_to_media) {
             oss << it.first << ", ";
         }
         return oss.str();
     }
 
     const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
-        auto it = map_pos_to_image.find(pos);
-        if (it != map_pos_to_image.end()) {
+        auto it = map_pos_to_media.find(pos);
+        if (it != map_pos_to_media.end()) {
             return it->second;
         } else {
             throw std::runtime_error("Chunk not found");
@@ -1115,16 +1145,15 @@ public:
     // will create a copy of the chunk if it contains non-text data
     void push_back(const mtmd_input_chunk * chunk) {
         auto type = mtmd_input_chunk_get_type(chunk);
-        if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+        if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
-            auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk);
-            const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
+            const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
             llama_pos start_pos = tokens.size();
             for (int i = 0; i < n_pos; ++i) {
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
             }
             mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
-            map_pos_to_image[start_pos] = std::move(new_chunk);
+            map_pos_to_media[start_pos] = std::move(new_chunk);
         } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             size_t n_tokens;
             auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
@@ -1169,6 +1198,9 @@ public:
     void keep_first(size_t n) {
         GGML_ASSERT(n <= tokens.size());
         if (has_mtmd) {
+            if (n == tokens.size()) {
+                return; // nothing to do
+            }
             // we throw an error if we try to remove a token in the middle of an image
             // for ex. with input of 5 text tokens and 2 images:
             //    [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
@@ -1183,10 +1215,10 @@ public:
                 }
             }
             // remove all image chunks that are not used anymore
-            for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
+            for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
                 llama_pos pos = it->first;
                 if (pos >= (llama_pos)n) {
-                    it = map_pos_to_image.erase(it);
+                    it = map_pos_to_media.erase(it);
                 } else {
                     ++it;
                 }
@@ -1217,14 +1249,12 @@ public:
                 const auto & a_chunk =   find_chunk(i);
                 const auto & b_chunk = b.find_chunk(i);
                 GGML_ASSERT(a_chunk && b_chunk);
-                const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get());
-                const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get());
-                std::string ai_id  = mtmd_image_tokens_get_id(a_img);
-                std::string bi_id  = mtmd_image_tokens_get_id(b_img);
-                size_t a_pos       = mtmd_image_tokens_get_n_pos(a_img);
-                size_t b_pos       = mtmd_image_tokens_get_n_pos(b_img);
+                std::string ai_id  = mtmd_input_chunk_get_id(a_chunk.get());
+                std::string bi_id  = mtmd_input_chunk_get_id(b_chunk.get());
+                size_t a_pos       = mtmd_input_chunk_get_n_pos(a_chunk.get());
+                size_t b_pos       = mtmd_input_chunk_get_n_pos(b_chunk.get());
                 if (ai_id == bi_id && a_pos == b_pos) {
-                    GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
+                    GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
                     i += a_pos - 1; // will be +1 by the for loop
                     continue;
                 } else {
@@ -1250,8 +1280,7 @@ public:
             if (t == LLAMA_TOKEN_NULL) {
                 try {
                     const auto & chunk = find_chunk(i);
-                    const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
-                    size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
+                    size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
                     i += n_pos - 1; // will be +1 by the for loop
                 } catch (const std::exception & e) {
                     return false;
@@ -1270,22 +1299,21 @@ public:
                 llama_pos n_past,
                 int32_t seq_id,
                 llama_pos & n_pos_out) {
-        auto it = map_pos_to_image.find(n_past);
-        if (it == map_pos_to_image.end()) {
-            throw std::runtime_error("Chunk not found");
-        }
-        SRV_INF("%s\n", "processing image...");
+        auto & chunk = find_chunk(n_past);
+        const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
+                            ? "image" : "audio";
+        SRV_INF("processing %s...\n", name);
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
         llama_pos new_n_past = n_past;
         int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
-            it->second.get(), // chunk
+            chunk.get(),
             n_past,
             seq_id,
             n_batch,
             true, // logits last
             &new_n_past);
-        SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
+        SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
         if (result != 0) {
             LOG_ERR("mtmd_helper_eval failed with status %d", result);
             n_pos_out = n_past;
index 4f28f887482a631dc77bc7d38325fe876a554a6f..2d4179ea4703ed2c3a8a795e9d03ddf5154da69c 100644 (file)
@@ -1,4 +1,8 @@
-import { DocumentTextIcon, XMarkIcon } from '@heroicons/react/24/outline';
+import {
+  DocumentTextIcon,
+  SpeakerWaveIcon,
+  XMarkIcon,
+} from '@heroicons/react/24/outline';
 import { MessageExtra } from '../utils/types';
 import { useState } from 'react';
 import { classNames } from '../utils/misc';
@@ -66,7 +70,11 @@ export default function ChatInputExtraContextItem({
                   className="w-14 h-14 flex items-center justify-center"
                   aria-description="Document icon"
                 >
-                  <DocumentTextIcon className="h-8 w-14 text-base-content/50" />
+                  {item.type === 'audioFile' ? (
+                    <SpeakerWaveIcon className="h-8 w-8 text-gray-500" />
+                  ) : (
+                    <DocumentTextIcon className="h-8 w-8 text-gray-500" />
+                  )}
                 </div>
 
                 <div className="text-xs pr-4">
@@ -98,6 +106,19 @@ export default function ChatInputExtraContextItem({
                 src={showingItem.base64Url}
                 alt={`Preview image for ${showingItem.name}`}
               />
+            ) : showingItem.type === 'audioFile' ? (
+              <audio
+                controls
+                className="w-full"
+                aria-description={`Audio file ${showingItem.name}`}
+              >
+                <source
+                  src={`data:${showingItem.mimeType};base64,${showingItem.base64Data}`}
+                  type={showingItem.mimeType}
+                  aria-description={`Audio file ${showingItem.name}`}
+                />
+                Your browser does not support the audio element.
+              </audio>
             ) : (
               <div className="overflow-x-auto">
                 <pre className="whitespace-pre-wrap break-words text-sm">
index 09c601ef2366aab6099fd685d3af659b9333a549..c1a669144550797f1e2bd60de074b471da134a3e 100644 (file)
@@ -278,6 +278,13 @@ export default function ChatScreen() {
 
 function ServerInfo() {
   const { serverProps } = useAppContext();
+  const modalities = [];
+  if (serverProps?.modalities?.audio) {
+    modalities.push('audio');
+  }
+  if (serverProps?.modalities?.vision) {
+    modalities.push('vision');
+  }
   return (
     <div
       className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
@@ -291,6 +298,13 @@ function ServerInfo() {
           <br />
           <b>Build</b>: {serverProps?.build_info}
           <br />
+          {modalities.length > 0 ? (
+            <>
+              <b>Supported modalities:</b> {modalities.join(', ')}
+            </>
+          ) : (
+            ''
+          )}
         </p>
       </div>
     </div>
index b9794405a5da578499c4ec213afb55649b2dc65b..42765524067e22e1d3110962e46e7f65516e7b0f 100644 (file)
@@ -11,6 +11,7 @@ pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
 // This file handles uploading extra context items (a.k.a files)
 // It allows processing these kinds of files:
 // - image files (converted to base64)
+// - audio files (converted to base64)
 // - text files (including code files)
 // - pdf (converted to text)
 
@@ -41,96 +42,73 @@ export function useChatExtraContext(): ChatExtraContextApi {
 
   const isSupportVision = serverProps?.modalities?.vision;
 
-  const onFileAdded = (files: File[]) => {
-    for (const file of files) {
-      const mimeType = file.type;
-      console.debug({ mimeType, file });
-      if (file.size > 10 * 1024 * 1024) {
-        toast.error('File is too large. Maximum size is 10MB.');
-        break;
-      }
-
-      if (mimeType.startsWith('image/')) {
-        if (!isSupportVision) {
-          toast.error('Multimodal is not supported by this server or model.');
+  const onFileAdded = async (files: File[]) => {
+    try {
+      for (const file of files) {
+        const mimeType = file.type;
+        if (file.size > 10 * 1024 * 1024) {
+          toast.error('File is too large. Maximum size is 10MB.');
           break;
         }
-        const reader = new FileReader();
-        reader.onload = async (event) => {
-          if (event.target?.result) {
-            let base64Url = event.target.result as string;
-
-            if (mimeType === 'image/svg+xml') {
-              // Convert SVG to PNG
-              base64Url = await svgBase64UrlToPngDataURL(base64Url);
-            }
 
-            addItems([
-              {
-                type: 'imageFile',
-                name: file.name,
-                base64Url,
-              },
-            ]);
+        if (mimeType.startsWith('image/')) {
+          if (!isSupportVision) {
+            toast.error('Multimodal is not supported by this server or model.');
+            break;
           }
-        };
-        reader.readAsDataURL(file);
-      } else if (
-        mimeType.startsWith('video/') ||
-        mimeType.startsWith('audio/')
-      ) {
-        toast.error('Video and audio files are not supported yet.');
-        break;
-      } else if (mimeType.startsWith('application/pdf')) {
-        if (config.pdfAsImage && !isSupportVision) {
-          toast(
-            'Multimodal is not supported, PDF will be converted to text instead of image.'
-          );
+
+          let base64Url = await getFileAsBase64(file);
+          if (mimeType === 'image/svg+xml') {
+            // Convert SVG to PNG
+            base64Url = await svgBase64UrlToPngDataURL(base64Url);
+          }
+          addItems([
+            {
+              type: 'imageFile',
+              name: file.name,
+              base64Url,
+            },
+          ]);
+        } else if (mimeType.startsWith('video/')) {
+          toast.error('Video files are not supported yet.');
           break;
-        }
+        } else if (mimeType.startsWith('audio/')) {
+          if (!/mpeg|wav/.test(mimeType)) {
+            toast.error('Only mp3 and wav audio files are supported.');
+            break;
+          }
 
-        const promise =
-          config.pdfAsImage && isSupportVision
-            ? convertPDFToImage(file).then((base64Urls) => {
-                addItems(
-                  base64Urls.map((base64Url) => ({
-                    type: 'imageFile',
-                    name: file.name,
-                    base64Url,
-                  }))
-                );
-              })
-            : convertPDFToText(file).then((content) => {
-                if (isSupportVision) {
-                  toast.success(
-                    'PDF file converted to text. You can also convert it to image, see in Settings.'
-                  );
-                }
-                addItems([
-                  {
-                    type: 'textFile',
-                    name: file.name,
-                    content,
-                  },
-                ]);
-              });
-
-        promise.catch((error) => {
-          console.error(error);
-          toast.error('Failed to parse PDF file.');
-        });
-        break;
-      } else {
-        // Because there can be many text file types (like code file), we will not check the mime type
-        // and will just check if the file is not binary.
-        const reader = new FileReader();
-        reader.onload = (event) => {
-          if (event.target?.result) {
-            const content = event.target.result as string;
-            if (!isLikelyNotBinary(content)) {
-              toast.error('File is binary. Please upload a text file.');
-              return;
-            }
+          // plain base64, not a data URL
+          const base64Data = await getFileAsBase64(file, false);
+          addItems([
+            {
+              type: 'audioFile',
+              name: file.name,
+              mimeType,
+              base64Data,
+            },
+          ]);
+        } else if (mimeType.startsWith('application/pdf')) {
+          if (config.pdfAsImage && !isSupportVision) {
+            toast(
+              'Multimodal is not supported, PDF will be converted to text instead of image.'
+            );
+            break;
+          }
+
+          if (config.pdfAsImage && isSupportVision) {
+            // Convert PDF to images
+            const base64Urls = await convertPDFToImage(file);
+            addItems(
+              base64Urls.map((base64Url) => ({
+                type: 'imageFile',
+                name: file.name,
+                base64Url,
+              }))
+            );
+          } else {
+            // Convert PDF to text
+            const content = await convertPDFToText(file);
             addItems([
               {
                 type: 'textFile',
@@ -138,10 +116,40 @@ export function useChatExtraContext(): ChatExtraContextApi {
                 content,
               },
             ]);
+            if (isSupportVision) {
+              toast.success(
+                'PDF file converted to text. You can also convert it to image, see in Settings.'
+              );
+            }
           }
-        };
-        reader.readAsText(file);
+          break;
+        } else {
+          // Because there can be many text file types (like code file), we will not check the mime type
+          // and will just check if the file is not binary.
+          const reader = new FileReader();
+          reader.onload = (event) => {
+            if (event.target?.result) {
+              const content = event.target.result as string;
+              if (!isLikelyNotBinary(content)) {
+                toast.error('File is binary. Please upload a text file.');
+                return;
+              }
+              addItems([
+                {
+                  type: 'textFile',
+                  name: file.name,
+                  content,
+                },
+              ]);
+            }
+          };
+          reader.readAsText(file);
+        }
       }
+    } catch (error) {
+      const message = error instanceof Error ? error.message : String(error);
+      const errorMessage = `Error processing file: ${message}`;
+      toast.error(errorMessage);
     }
   };
 
@@ -154,6 +162,25 @@ export function useChatExtraContext(): ChatExtraContextApi {
   };
 }
 
+async function getFileAsBase64(file: File, outputUrl = true): Promise<string> {
+  return new Promise((resolve, reject) => {
+    const reader = new FileReader();
+    reader.onload = (event) => {
+      if (event.target?.result) {
+        let result = event.target.result as string;
+        if (!outputUrl) {
+          // remove base64 url prefix and correct characters
+          result = result.substring(result.indexOf(',') + 1);
+        }
+        resolve(result);
+      } else {
+        reject(new Error('Failed to read file.'));
+      }
+    };
+    reader.readAsDataURL(file);
+  });
+}
+
 async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
   return new Promise((resolve, reject) => {
     const reader = new FileReader();
index ba760e83bb2822f682991754c91ecd34d5cf3371..d60a68cd2431bfbb3b73bd7012bee8a0179f5e34 100644 (file)
@@ -89,6 +89,14 @@ export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
           type: 'image_url',
           image_url: { url: extra.base64Url },
         });
+      } else if (extra.type === 'audioFile') {
+        contentArr.push({
+          type: 'input_audio',
+          input_audio: {
+            data: extra.base64Data,
+            format: /wav/.test(extra.mimeType) ? 'wav' : 'mp3',
+          },
+        });
       } else {
         throw new Error('Unknown extra type');
       }
index ba673dd9432da4aaa76201d2234d8140a7fa9cc5..ea7d641dc748bb8bb25c1b356c6ab9aa9ed4b4f8 100644 (file)
@@ -51,6 +51,7 @@ export interface Message {
 export type MessageExtra =
   | MessageExtraTextFile
   | MessageExtraImageFile
+  | MessageExtraAudioFile
   | MessageExtraContext;
 
 export interface MessageExtraTextFile {
@@ -65,6 +66,13 @@ export interface MessageExtraImageFile {
   base64Url: string;
 }
 
+export interface MessageExtraAudioFile {
+  type: 'audioFile';
+  name: string;
+  base64Data: string;
+  mimeType: string;
+}
+
 export interface MessageExtraContext {
   type: 'context';
   name: string;
@@ -79,6 +87,10 @@ export type APIMessageContentPart =
   | {
       type: 'image_url';
       image_url: { url: string };
+    }
+  | {
+      type: 'input_audio';
+      input_audio: { data: string; format: 'wav' | 'mp3' };
     };
 
 export type APIMessage = {
@@ -120,6 +132,7 @@ export interface LlamaCppServerProps {
   n_ctx: number;
   modalities?: {
     vision: boolean;
+    audio: boolean;
   };
   // TODO: support params
 }