]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix build + rename enums (#4870)
authorGeorgi Gerganov <redacted>
Thu, 11 Jan 2024 07:10:34 +0000 (09:10 +0200)
committerGitHub <redacted>
Thu, 11 Jan 2024 07:10:34 +0000 (09:10 +0200)
examples/server/server.cpp

index 1cca634d5461f4e8870e4561951fa42ecc21cd5e..4a0714997f17d4d31a53eab7c6467fd1fa99b733 100644 (file)
@@ -147,15 +147,15 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
 // parallel
 //
 
-enum ServerState {
-    LOADING_MODEL,  // Server is starting up, model not fully loaded yet
-    READY,          // Server is ready and model is loaded
-    ERROR           // An error occurred, load_model failed
+enum server_state {
+    SERVER_STATE_LOADING_MODEL,  // Server is starting up, model not fully loaded yet
+    SERVER_STATE_READY,          // Server is ready and model is loaded
+    SERVER_STATE_ERROR           // An error occurred, load_model failed
 };
 
 enum task_type {
-    COMPLETION_TASK,
-    CANCEL_TASK
+    TASK_TYPE_COMPLETION,
+    TASK_TYPE_CANCEL,
 };
 
 struct task_server {
@@ -1402,7 +1402,7 @@ struct llama_server_context
         task.data = std::move(data);
         task.infill_mode = infill;
         task.embedding_mode = embedding;
-        task.type = COMPLETION_TASK;
+        task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
 
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
@@ -1524,7 +1524,7 @@ struct llama_server_context
         std::unique_lock<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
-        task.type = CANCEL_TASK;
+        task.type = TASK_TYPE_CANCEL;
         task.target_id = task_id;
         queue_tasks.push_back(task);
         condition_tasks.notify_one();
@@ -1560,7 +1560,7 @@ struct llama_server_context
             queue_tasks.erase(queue_tasks.begin());
             switch (task.type)
             {
-                case COMPLETION_TASK: {
+                case TASK_TYPE_COMPLETION: {
                     llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
                     if (slot == nullptr)
                     {
@@ -1589,7 +1589,7 @@ struct llama_server_context
                         break;
                     }
                 } break;
-                case CANCEL_TASK: { // release slot linked with the task id
+                case TASK_TYPE_CANCEL: { // release slot linked with the task id
                     for (auto & slot : slots)
                     {
                         if (slot.task_id == task.target_id)
@@ -2798,24 +2798,24 @@ int main(int argc, char **argv)
 
     httplib::Server svr;
 
-    std::atomic<ServerState> server_state{LOADING_MODEL};
+    std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
 
     svr.set_default_headers({{"Server", "llama.cpp"},
                              {"Access-Control-Allow-Origin", "*"},
                              {"Access-Control-Allow-Headers", "content-type"}});
 
     svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
-        ServerState current_state = server_state.load();
+        server_state current_state = state.load();
         switch(current_state) {
-            case READY:
+            case SERVER_STATE_READY:
                 res.set_content(R"({"status": "ok"})", "application/json");
                 res.status = 200; // HTTP OK
                 break;
-            case LOADING_MODEL:
+            case SERVER_STATE_LOADING_MODEL:
                 res.set_content(R"({"status": "loading model"})", "application/json");
                 res.status = 503; // HTTP Service Unavailable
                 break;
-            case ERROR:
+            case SERVER_STATE_ERROR:
                 res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
                 res.status = 500; // HTTP Internal Server Error
                 break;
@@ -2891,7 +2891,7 @@ int main(int argc, char **argv)
             {
                 if (!svr.listen_after_bind())
                 {
-                    server_state.store(ERROR);
+                    state.store(SERVER_STATE_ERROR);
                     return 1;
                 }
 
@@ -2901,11 +2901,11 @@ int main(int argc, char **argv)
     // load the model
     if (!llama.load_model(params))
     {
-        server_state.store(ERROR);
+        state.store(SERVER_STATE_ERROR);
         return 1;
     } else {
         llama.initialize();
-        server_state.store(READY);
+        state.store(SERVER_STATE_READY);
     }
 
     // Middleware for API key validation