struct ggml_backend_rpc_buffer_context {
std::shared_ptr<socket_t> sock;
- std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
+ void * base_ptr;
uint64_t remote_ptr;
};
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
- return ctx->base_cache[buffer];
+ if (ctx->base_ptr != nullptr) {
+ return ctx->base_ptr;
}
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
rpc_msg_buffer_get_base_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status);
- void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
- ctx->base_cache[buffer] = base_ptr;
- return base_ptr;
+ ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
+ return ctx->base_ptr;
}
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
if (response.remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
- new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
+ new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
response.remote_size);
return buffer;
} else {