RPC_STATUS_ASSERT(status);
}
+static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
+ return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
+}
+
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
- // check if src and dst are on the same server
- ggml_backend_buffer_t src_buffer = src->buffer;
- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
- ggml_backend_buffer_t dst_buffer = dst->buffer;
- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
- if (src_ctx->sock != dst_ctx->sock) {
- return false;
+ if (ggml_backend_buffer_is_rpc(src->buffer)) {
+ // check if src and dst are on the same server
+ ggml_backend_buffer_t src_buffer = src->buffer;
+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
+ ggml_backend_buffer_t dst_buffer = dst->buffer;
+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
+ if (src_ctx->sock != dst_ctx->sock) {
+ return false;
+ }
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ rpc_msg_copy_tensor_req request;
+ request.src = serialize_tensor(src);
+ request.dst = serialize_tensor(dst);
+ rpc_msg_copy_tensor_rsp response;
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
+ RPC_STATUS_ASSERT(status);
+ return response.result;
}
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- rpc_msg_copy_tensor_req request;
- request.src = serialize_tensor(src);
- request.dst = serialize_tensor(dst);
- rpc_msg_copy_tensor_rsp response;
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
- RPC_STATUS_ASSERT(status);
- return response.result;
+ return false;
}
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {