]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
rpc : prevent crashes on invalid input (llama/9040)
authorRadoslav Gerganov <redacted>
Mon, 19 Aug 2024 07:10:21 +0000 (10:10 +0300)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:14 +0000 (22:01 +0300)
Add more checks which prevent RPC server from crashing if invalid input
is received from client

src/ggml-rpc.cpp

index 7757615f5a24bdb6c0b07eef72bfec5c5198a677..3e50a0091aa1706922c5922fb84bc87f8300d4f2 100644 (file)
@@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
 
 // RPC commands
 enum rpc_cmd {
-    ALLOC_BUFFER = 0,
-    GET_ALIGNMENT,
-    GET_MAX_SIZE,
-    BUFFER_GET_BASE,
-    FREE_BUFFER,
-    BUFFER_CLEAR,
-    SET_TENSOR,
-    GET_TENSOR,
-    COPY_TENSOR,
-    GRAPH_COMPUTE,
-    GET_DEVICE_MEMORY,
+    RPC_CMD_ALLOC_BUFFER = 0,
+    RPC_CMD_GET_ALIGNMENT,
+    RPC_CMD_GET_MAX_SIZE,
+    RPC_CMD_BUFFER_GET_BASE,
+    RPC_CMD_FREE_BUFFER,
+    RPC_CMD_BUFFER_CLEAR,
+    RPC_CMD_SET_TENSOR,
+    RPC_CMD_GET_TENSOR,
+    RPC_CMD_COPY_TENSOR,
+    RPC_CMD_GRAPH_COMPUTE,
+    RPC_CMD_GET_DEVICE_MEMORY,
+    RPC_CMD_COUNT,
 };
 
 // RPC data structures
@@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
     uint64_t remote_ptr = ctx->remote_ptr;
     memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.empty());
     delete ctx;
@@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
     uint64_t remote_ptr = ctx->remote_ptr;
     memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == sizeof(uint64_t));
     // output serialization format: | base_ptr (8 bytes) |
@@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
     GGML_ASSERT(status);
 }
 
@@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == size);
     // output serialization format: | data (size bytes) |
@@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
     memcpy(input.data(), &rpc_src, sizeof(rpc_src));
     memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
     GGML_ASSERT(status);
     // output serialization format: | result (1 byte) |
     GGML_ASSERT(output.size() == 1);
@@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
     memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
     memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
     GGML_ASSERT(status);
 }
 
@@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     memcpy(input.data(), &size, sizeof(size));
     std::vector<uint8_t> output;
     auto sock = get_socket(buft_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
     // input serialization format: | 0 bytes |
     std::vector<uint8_t> input;
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == sizeof(uint64_t));
     // output serialization format: | alignment (8 bytes) |
@@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
     // input serialization format: | 0 bytes |
     std::vector<uint8_t> input;
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == sizeof(uint64_t));
     // output serialization format: | max_size (8 bytes) |
@@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
     serialize_graph(cgraph, input);
     std::vector<uint8_t> output;
     auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 1);
     return (enum ggml_status)output[0];
@@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
     // input serialization format: | 0 bytes |
     std::vector<uint8_t> input;
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
     // output serialization format: | free (8 bytes) | total (8 bytes) |
@@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
         if (!recv_data(sockfd, &cmd, 1)) {
             break;
         }
+        if (cmd >= RPC_CMD_COUNT) {
+            // fail fast if the command is invalid
+            fprintf(stderr, "Unknown command: %d\n", cmd);
+            break;
+        }
         std::vector<uint8_t> input;
         std::vector<uint8_t> output;
         uint64_t input_size;
         if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
             break;
         }
-        input.resize(input_size);
+        try {
+            input.resize(input_size);
+        } catch (const std::bad_alloc & e) {
+            fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
+            break;
+        }
         if (!recv_data(sockfd, input.data(), input_size)) {
             break;
         }
         bool ok = true;
         switch (cmd) {
-            case ALLOC_BUFFER: {
+            case RPC_CMD_ALLOC_BUFFER: {
                 ok = server.alloc_buffer(input, output);
                 break;
             }
-            case GET_ALIGNMENT: {
+            case RPC_CMD_GET_ALIGNMENT: {
                 server.get_alignment(output);
                 break;
             }
-            case GET_MAX_SIZE: {
+            case RPC_CMD_GET_MAX_SIZE: {
                 server.get_max_size(output);
                 break;
             }
-            case BUFFER_GET_BASE: {
+            case RPC_CMD_BUFFER_GET_BASE: {
                 ok = server.buffer_get_base(input, output);
                 break;
             }
-            case FREE_BUFFER: {
+            case RPC_CMD_FREE_BUFFER: {
                 ok = server.free_buffer(input);
                 break;
             }
-            case BUFFER_CLEAR: {
+            case RPC_CMD_BUFFER_CLEAR: {
                 ok = server.buffer_clear(input);
                 break;
             }
-            case SET_TENSOR: {
+            case RPC_CMD_SET_TENSOR: {
                 ok = server.set_tensor(input);
                 break;
             }
-            case GET_TENSOR: {
+            case RPC_CMD_GET_TENSOR: {
                 ok = server.get_tensor(input, output);
                 break;
             }
-            case COPY_TENSOR: {
+            case RPC_CMD_COPY_TENSOR: {
                 ok = server.copy_tensor(input, output);
                 break;
             }
-            case GRAPH_COMPUTE: {
+            case RPC_CMD_GRAPH_COMPUTE: {
                 ok = server.graph_compute(input, output);
                 break;
             }
-            case GET_DEVICE_MEMORY: {
+            case RPC_CMD_GET_DEVICE_MEMORY: {
                 // output serialization format: | free (8 bytes) | total (8 bytes) |
                 output.resize(2*sizeof(uint64_t), 0);
                 memcpy(output.data(), &free_mem, sizeof(free_mem));
@@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
             return;
         }
         printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+        fflush(stdout);
         rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
         printf("Client connection closed\n");
+        fflush(stdout);
     }
 #ifdef _WIN32
     WSACleanup();