]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama: string_split fix (#10022)
authorMichael Podvitskiy <redacted>
Fri, 25 Oct 2024 15:57:54 +0000 (17:57 +0200)
committerGitHub <redacted>
Fri, 25 Oct 2024 15:57:54 +0000 (17:57 +0200)
* llama: Refactor string_split to use template specialization,  fixes parsing strings with spaces

* llama: Add static_assert in the string_split template to ensure the correct template specialization is used for std::string

common/arg.cpp
common/common.cpp
common/common.h
examples/server/server.cpp

index cd9d315dc78ffb3b6c9968124e8b6c316b534685..608e46e0202b385312d88964b15e8c30f6fc46c2 100644 (file)
@@ -128,13 +128,13 @@ static void common_params_handle_model_default(common_params & params) {
             }
             params.hf_file = params.model;
         } else if (params.model.empty()) {
-            params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
+            params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
         }
     } else if (!params.model_url.empty()) {
         if (params.model.empty()) {
-            auto f = string_split(params.model_url, '#').front();
-            f = string_split(f, '?').front();
-            params.model = fs_get_cache_file(string_split(f, '/').back());
+            auto f = string_split<std::string>(params.model_url, '#').front();
+            f = string_split<std::string>(f, '?').front();
+            params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
         }
     } else if (params.model.empty()) {
         params.model = DEFAULT_MODEL_PATH;
@@ -879,7 +879,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--samplers"}, "SAMPLERS",
         string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
         [](common_params & params, const std::string & value) {
-            const auto sampler_names = string_split(value, ';');
+            const auto sampler_names = string_split<std::string>(value, ';');
             params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
         }
     ).set_sparam());
index a8eebb68b5351becd1f144c6662ea9457c213007..faaa420d9ce85ca17e3fbc5817a93039220496f3 100644 (file)
@@ -416,19 +416,6 @@ std::string string_format(const char * fmt, ...) {
     return std::string(buf.data(), size);
 }
 
-std::vector<std::string> string_split(std::string input, char separator) {
-    std::vector<std::string> parts;
-    size_t separator_pos = input.find(separator);
-    while (separator_pos != std::string::npos) {
-        std::string part = input.substr(0, separator_pos);
-        parts.emplace_back(part);
-        input = input.substr(separator_pos + 1);
-        separator_pos = input.find(separator);
-    }
-    parts.emplace_back(input);
-    return parts;
-}
-
 std::string string_strip(const std::string & str) {
     size_t start = 0;
     size_t end = str.size();
index 19d928777ccd5ab8c3ca50fe408fdaf507f1f9ca..f9333395c22086d404a3aebc2a91a4976f70ff98 100644 (file)
@@ -380,8 +380,6 @@ bool set_process_priority(enum ggml_sched_priority prio);
 LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
 std::string string_format(const char * fmt, ...);
 
-std::vector<std::string> string_split(std::string input, char separator);
-
 std::string string_strip(const std::string & str);
 std::string string_get_sortable_timestamp();
 
@@ -389,6 +387,7 @@ void string_replace_all(std::string & s, const std::string & search, const std::
 
 template<class T>
 static std::vector<T> string_split(const std::string & str, char delim) {
+    static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
     std::vector<T> values;
     std::istringstream str_stream(str);
     std::string token;
@@ -401,6 +400,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
     return values;
 }
 
+template<>
+std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
+{
+    std::vector<std::string> parts;
+    size_t begin_pos = 0;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(begin_pos, separator_pos - begin_pos);
+        parts.emplace_back(part);
+        begin_pos = separator_pos + 1;
+        separator_pos = input.find(separator, begin_pos);
+    }
+    parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
+    return parts;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
 
index 2821877b2a6fb51b15e4662ab920efbe6c25091b..3c12ef6f0f1196c111280c84deeec483f3f48b75 100644 (file)
@@ -2380,7 +2380,7 @@ int main(int argc, char ** argv) {
     auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
         server_state current_state = state.load();
         if (current_state == SERVER_STATE_LOADING_MODEL) {
-            auto tmp = string_split(req.path, '.');
+            auto tmp = string_split<std::string>(req.path, '.');
             if (req.path == "/" || tmp.back() == "html") {
                 res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
                 res.status = 503;