]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Support for models with non-512-aligned tensors over RPC. (llama/11047)
authormatt23654 <redacted>
Sat, 4 Jan 2025 16:10:30 +0000 (16:10 +0000)
committerGeorgi Gerganov <redacted>
Tue, 14 Jan 2025 08:38:01 +0000 (10:38 +0200)
* Added init tensor calling code

* Added get_alloc_size forwarding

* Cleaned up and improved type/error handling.

* fix: remove trailing whitespaces.

* Cleanup and use GGML error logging functions.

* Handle potentially dangerous edge cases.

* Apply suggestions from code review

Co-authored-by: Diego Devesa <redacted>
---------

Co-authored-by: Diego Devesa <redacted>
ggml/src/ggml-rpc/ggml-rpc.cpp

index 43108242639a3192083809f5a242422f058c921e..2213aba9f121a6ca775c33d9719bdb0671ab1e65 100644 (file)
@@ -93,9 +93,23 @@ enum rpc_cmd {
     RPC_CMD_COPY_TENSOR,
     RPC_CMD_GRAPH_COMPUTE,
     RPC_CMD_GET_DEVICE_MEMORY,
+    RPC_CMD_INIT_TENSOR,
+    RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_COUNT,
 };
 
+struct rpc_msg_get_alloc_size_req {
+    rpc_tensor tensor;
+};
+
+struct rpc_msg_get_alloc_size_rsp {
+    uint64_t alloc_size;
+};
+
+struct rpc_msg_init_tensor_req {
+    rpc_tensor tensor;
+};
+
 struct rpc_msg_alloc_buffer_req {
     uint64_t size;
 };
@@ -461,10 +475,18 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
 }
 
 static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    UNUSED(buffer);
-    if (ggml_is_quantized(tensor->type)) {
-        // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
-        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+
+    // CUDA backend on the server pads everything to 512 due to CUDA limitations.
+    // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
+    // In particular, only quantized tensors need padding
+    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+        rpc_msg_init_tensor_req request;
+
+        request.tensor = serialize_tensor(tensor);
+
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
+        GGML_ASSERT(status);
     }
 }
 
@@ -577,8 +599,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    UNUSED(buft);
-    return ggml_nbytes(tensor);
+    // See comments in init_tensor.
+    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+        ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+        auto sock = get_socket(buft_ctx->endpoint);
+
+        rpc_msg_get_alloc_size_req request;
+
+        request.tensor = serialize_tensor(tensor);
+
+        rpc_msg_get_alloc_size_rsp response;
+        bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
+        GGML_ASSERT(status);
+
+        return response.alloc_size;
+    } else {
+        return ggml_nbytes(tensor);
+    }
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -757,6 +794,8 @@ public:
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool init_tensor(const rpc_msg_init_tensor_req & request);
+    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -770,6 +809,36 @@ private:
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
+bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
+    ggml_backend_buffer_type_t buft;
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    if (tensor->buffer == nullptr) {
+        //No buffer allocated.
+        buft = ggml_backend_get_default_buffer_type(backend);
+    } else {
+        buft = tensor->buffer->buft;
+    }
+
+    response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
+
+    ggml_free(ctx);
+    return true;
+}
+
 void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -905,6 +974,40 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     return true;
 }
 
+bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    // Call the backend's buffer_init_tensor function
+    ggml_backend_buffer_t buffer = tensor->buffer;
+    if (buffer && buffer->iface.init_tensor) {
+        buffer->iface.init_tensor(buffer, tensor);
+    } else {
+        GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
+    }
+
+    if (tensor->extra != nullptr) {
+        // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
+        // Currently unimplemented.
+        GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    ggml_free(ctx);
+    return true;
+}
+
 bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
@@ -1058,6 +1161,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 }
                 break;
             }
+            case RPC_CMD_GET_ALLOC_SIZE: {
+                rpc_msg_get_alloc_size_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_get_alloc_size_rsp response;
+                server.get_alloc_size(request, response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_ALIGNMENT: {
                 if (!recv_msg(sockfd, nullptr, 0)) {
                     return;
@@ -1133,6 +1248,19 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 }
                 break;
             }
+            case RPC_CMD_INIT_TENSOR: {
+                rpc_msg_init_tensor_req request;
+                if (!recv_msg(sockfd, &request,sizeof(request))) {
+                    return;
+                }
+                if (!server.init_tensor(request)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_TENSOR: {
                 rpc_msg_get_tensor_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {