]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
rpc : add RPC_CMD_HELLO (llama/12955)
authorRadoslav Gerganov <redacted>
Fri, 18 Apr 2025 07:13:42 +0000 (10:13 +0300)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
Add RPC_CMD_HELLO for getting the version of the protocol implemend by
the server. Follow the semantic versioning rules at https://semver.org

Hopefully this bring better user experience when we make breaking
changes at the protocol level and avoid issues like #12465

ggml/include/ggml-rpc.h
ggml/src/ggml-rpc/ggml-rpc.cpp

index 4e0d210f8ec973a3a60bc75e2d2b200bcf98d53d..c8b6097f7e5730026ccf05241d8c2cf5b494efa3 100644 (file)
@@ -7,6 +7,9 @@
 extern "C" {
 #endif
 
+#define RPC_PROTO_MAJOR_VERSION    1
+#define RPC_PROTO_MINOR_VERSION    0
+#define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 
 // backend API
index 3189ae85d55f97bee6798d8f87972e461de8fe75..a0667b7d702b2ab7afa56f521c692a73e74b6a60 100644 (file)
@@ -92,12 +92,19 @@ enum rpc_cmd {
     RPC_CMD_GET_DEVICE_MEMORY,
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
+    RPC_CMD_HELLO,
     RPC_CMD_COUNT,
 };
 
 // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
 const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
 
+struct rpc_msg_hello_rsp {
+    uint8_t major;
+    uint8_t minor;
+    uint8_t patch;
+};
+
 struct rpc_msg_get_alloc_size_req {
     rpc_tensor tensor;
 };
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 
 // RPC client-side implementation
 
+static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
+    rpc_msg_hello_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
+    GGML_ASSERT(status);
+    if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
+        fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+        return false;
+    }
+    if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
+        fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+    }
+    return true;
+}
+
 static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     if (sock == nullptr) {
         return nullptr;
     }
+    if (!check_server_version(sock)) {
+        return nullptr;
+    }
     GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
     sockets[endpoint] = sock;
     return sock;
@@ -818,6 +842,7 @@ public:
     }
     ~rpc_server();
 
+    void hello(rpc_msg_hello_rsp & response);
     void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
     void get_alignment(rpc_msg_get_alignment_rsp & response);
     void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ private:
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
+void rpc_server::hello(rpc_msg_hello_rsp & response) {
+    response.major = RPC_PROTO_MAJOR_VERSION;
+    response.minor = RPC_PROTO_MINOR_VERSION;
+    response.patch = RPC_PROTO_PATCH_VERSION;
+    GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
+}
+
 bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
 static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                              sockfd_t sockfd, size_t free_mem, size_t total_mem) {
     rpc_server server(backend, cache_dir);
+    uint8_t cmd;
+    if (!recv_data(sockfd, &cmd, 1)) {
+        return;
+    }
+    // the first command sent by the client must be HELLO
+    if (cmd != RPC_CMD_HELLO) {
+        fprintf(stderr, "Expected HELLO command, update client\n");
+        return;
+    }
+    if (!recv_msg(sockfd, nullptr, 0)) {
+        return;
+    }
+    rpc_msg_hello_rsp response;
+    server.hello(response);
+    if (!send_msg(sockfd, &response, sizeof(response))) {
+        return;
+    }
     while (true) {
-        uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
             break;
         }
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
             break;
         }
         switch (cmd) {
+            case RPC_CMD_HELLO: {
+                // HELLO command is handled above
+                return;
+            }
             case RPC_CMD_ALLOC_BUFFER: {
                 rpc_msg_alloc_buffer_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {