]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cli : add command and file auto-completion (#19985)
authorSigbjørn Skjæret <redacted>
Thu, 5 Mar 2026 09:47:28 +0000 (10:47 +0100)
committerGitHub <redacted>
Thu, 5 Mar 2026 09:47:28 +0000 (10:47 +0100)
common/console.cpp
common/console.h
tools/cli/cli.cpp

index 2ea178f81edd69702fc52bbff33a3a9dfe9d1816..a770416ab7a9f9b102cf881800adf2ef6df03030 100644 (file)
@@ -80,6 +80,8 @@ namespace console {
     static termios      initial_state;
 #endif
 
+    static completion_callback completion_cb = nullptr;
+
     //
     // Init and cleanup
     //
@@ -493,7 +495,7 @@ namespace console {
     }
 
     static void set_line_contents(std::string new_line, std::string & line, std::vector<int> & widths, size_t & char_pos,
-                                  size_t & byte_pos) {
+                                  size_t & byte_pos, int cursor_byte_pos = -1) {
         move_to_line_start(char_pos, byte_pos, widths);
         clear_current_line(widths);
 
@@ -503,6 +505,7 @@ namespace console {
         char_pos = 0;
 
         size_t idx = 0;
+        int back_width = 0;
         while (idx < line.size()) {
             size_t advance = 0;
             char32_t cp = decode_utf8(line, idx, advance);
@@ -511,8 +514,15 @@ namespace console {
             if (real_width < 0) real_width = 0;
             widths.push_back(real_width);
             idx += advance;
-            ++char_pos;
-            byte_pos = idx;
+            if (cursor_byte_pos >= 0 && static_cast<size_t>(cursor_byte_pos) < idx) {
+                back_width += real_width;
+            } else {
+                ++char_pos;
+                byte_pos = idx;
+            }
+        }
+        if (cursor_byte_pos >= 0) {
+            move_cursor(-back_width);
         }
     }
 
@@ -784,6 +794,20 @@ namespace console {
                 break;
             }
 
+            if (completion_cb && input_char == '\t') {
+                auto candidates = completion_cb(line, byte_pos);
+
+                if (!candidates.empty()) {
+                    if (candidates.size() > 1 || candidates[0].first != line) {
+                        // TODO?: Display all candidates
+                        set_line_contents(candidates[0].first, line, widths, char_pos, byte_pos, candidates[0].second);
+                    } else {
+                        // TODO: Move cursor to new byte_pos
+                    }
+                    continue;
+                }
+            }
+
             if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D */) {
                 end_of_stream = true;
                 break;
@@ -1062,6 +1086,10 @@ namespace console {
         return readline_advanced(line, multiline_input);
     }
 
+    void set_completion_callback(completion_callback cb) {
+        completion_cb = cb;
+    }
+
     namespace spinner {
         static const char LOADING_CHARS[] = {'|', '/', '-', '\\'};
         static std::condition_variable cv_stop;
index fad6d3953163f582a2334e52b47b475139e91fc3..72781bea6f607152bb08f3dc959a1d030c32c0c0 100644 (file)
@@ -4,7 +4,9 @@
 
 #include "common.h"
 
+#include <functional>
 #include <string>
+#include <vector>
 
 enum display_type {
     DISPLAY_TYPE_RESET = 0,
@@ -21,6 +23,9 @@ namespace console {
     void set_display(display_type display);
     bool readline(std::string & line, bool multiline_input);
 
+    using completion_callback = std::function<std::vector<std::pair<std::string, size_t>>(std::string_view, size_t)>;
+    void set_completion_callback(completion_callback cb);
+
     namespace spinner {
         void start();
         void stop();
index e57bf52e36c1f95933b9bb2ca838b448bf6745c1..65ff4ac6c09de8fe2275757319274b328cdde337 100644 (file)
@@ -6,7 +6,10 @@
 #include "server-context.h"
 #include "server-task.h"
 
+#include <array>
 #include <atomic>
+#include <algorithm>
+#include <filesystem>
 #include <fstream>
 #include <thread>
 #include <signal.h>
@@ -195,6 +198,122 @@ struct cli_context {
     }
 };
 
+// TODO?: Make this reusable, enums, docs
+static const std::array<const std::string, 6> cmds = {
+    "/audio ",
+    "/clear",
+    "/exit",
+    "/image ",
+    "/read ",
+    "/regen",
+};
+
+static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
+    std::vector<std::pair<std::string, size_t>> matches;
+    std::string cmd;
+
+    if (line.length() > 1 && line[0] == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](const std::string & prefix) {
+        return string_starts_with(line, prefix);
+    })) {
+        auto it = cmds.begin();
+
+        while ((it = std::find_if(it, cmds.end(), [line](const std::string & cmd_line) {
+            return string_starts_with(cmd_line, line);
+        })) != cmds.end()) {
+            matches.emplace_back(*it, (*it).length());
+            ++it;
+        }
+    } else {
+        auto it = std::find_if(cmds.begin(), cmds.end(), [line](const std::string & prefix) {
+            return prefix.back() == ' ' && string_starts_with(line, prefix);
+        });
+
+        if (it != cmds.end()) {
+            cmd = *it;
+        }
+    }
+
+    if (!cmd.empty() && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
+        const std::string path_prefix  = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
+        const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
+        auto cur_dir = std::filesystem::current_path();
+        std::string cur_dir_str = cur_dir.string();
+        std::string expanded_prefix = path_prefix;
+
+#if !defined(_WIN32)
+        if (string_starts_with(path_prefix, "~")) {
+            const char * home = std::getenv("HOME");
+            if (home && home[0]) {
+                expanded_prefix = std::string(home) + path_prefix.substr(1);
+            }
+        }
+        if (string_starts_with(expanded_prefix, "/")) {
+#else
+        if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
+#endif
+            cur_dir = std::filesystem::path(expanded_prefix).parent_path();
+            cur_dir_str = "";
+        } else if (!path_prefix.empty()) {
+            cur_dir /= std::filesystem::path(path_prefix).parent_path();
+        }
+
+        std::error_code ec;
+        for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
+            if (ec) {
+                break;
+            }
+            if (!entry.exists(ec)) {
+                ec.clear();
+                continue;
+            }
+
+            const std::string path_full = entry.path().string();
+            std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
+
+            if (entry.is_directory(ec)) {
+                path_entry.push_back(std::filesystem::path::preferred_separator);
+            }
+
+            if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
+                std::string updated_line = cmd + path_entry;
+                matches.emplace_back(updated_line + path_postfix, updated_line.length());
+            }
+
+            if (ec) {
+                ec.clear();
+            }
+        }
+
+        if (matches.empty()) {
+            std::string updated_line = cmd + path_prefix;
+            matches.emplace_back(updated_line + path_postfix, updated_line.length());
+        }
+
+        // Add the longest common prefix
+        if (!expanded_prefix.empty() && matches.size() > 1) {
+            const std::string_view match0(matches[0].first);
+            const std::string_view match1(matches[1].first);
+            auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
+            size_t len = it.first - match0.begin();
+
+            for (size_t i = 2; i < matches.size(); ++i) {
+                const std::string_view matchi(matches[i].first);
+                auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
+                len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
+            }
+
+            std::string updated_line = std::string(match0.substr(0, len));
+            matches.emplace_back(updated_line + path_postfix, updated_line.length());
+        }
+
+        std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
+            return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
+        });
+    }
+
+    return matches;
+}
+
 int main(int argc, char ** argv) {
     common_params params;
 
@@ -223,6 +342,7 @@ int main(int argc, char ** argv) {
     atexit([]() { console::cleanup(); });
 
     console::set_display(DISPLAY_TYPE_RESET);
+    console::set_completion_callback(auto_completion_callback);
 
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;