]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
rpc : check src buffer when copying tensor (llama/16421)
authorRadoslav Gerganov <redacted>
Sat, 4 Oct 2025 13:22:45 +0000 (16:22 +0300)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 04:57:25 +0000 (07:57 +0300)
Only dst buffer is guaranteed to be an RPC buffer. Add check for the src
one.

src/ggml-rpc/ggml-rpc.cpp

index 1a8739e788e76098e408c3ac66e038944c683dec..aad48d62a850c156736fe8c4f2449adf6b72d385 100644 (file)
@@ -631,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
     RPC_STATUS_ASSERT(status);
 }
 
+static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
+}
+
 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
-    // check if src and dst are on the same server
-    ggml_backend_buffer_t src_buffer = src->buffer;
-    ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
-    ggml_backend_buffer_t dst_buffer = dst->buffer;
-    ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
-    if (src_ctx->sock != dst_ctx->sock) {
-        return false;
+    if (ggml_backend_buffer_is_rpc(src->buffer)) {
+        // check if src and dst are on the same server
+        ggml_backend_buffer_t src_buffer = src->buffer;
+        ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
+        ggml_backend_buffer_t dst_buffer = dst->buffer;
+        ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
+        if (src_ctx->sock != dst_ctx->sock) {
+            return false;
+        }
+        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+        rpc_msg_copy_tensor_req request;
+        request.src = serialize_tensor(src);
+        request.dst = serialize_tensor(dst);
+        rpc_msg_copy_tensor_rsp response;
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
+        RPC_STATUS_ASSERT(status);
+        return response.result;
     }
-    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    rpc_msg_copy_tensor_req request;
-    request.src = serialize_tensor(src);
-    request.dst = serialize_tensor(dst);
-    rpc_msg_copy_tensor_rsp response;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
-    return response.result;
+    return false;
 }
 
 static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {