RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
+ RPC_CMD_GRAPH_RECOMPUTE,
RPC_CMD_COUNT,
};
uint8_t result;
};
-struct rpc_msg_graph_compute_rsp {
- uint8_t result;
-};
-
struct rpc_msg_get_device_memory_req {
uint32_t device;
};
uint64_t free_mem;
uint64_t total_mem;
};
+
+struct rpc_msg_graph_recompute_req {
+ uint32_t device;
+};
+
#pragma pack(pop)
// RPC data structures
size_t max_size;
};
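+// Client-side cache of the last computed graph. When an identical graph
+// is computed again, the client asks the server to re-run the graph it
+// already has instead of serializing and sending it a second time.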
+struct graph_cache {
+
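+ // returns true if cgraph is node-for-node identical to the cached graph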
+ bool is_cached(const ggml_cgraph * cgraph) {
+ if ((int)last_graph.size() != cgraph->n_nodes) {
+ return false;
+ }
+ for (int i = 0; i < cgraph->n_nodes; i++) {
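+ // compare the raw tensor structs: any change in op, shape, data
+ // pointers or sources invalidates the cache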
+ if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
+ return false;
+ }
+ }
+ return true;
+ }
+
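+ // snapshot the nodes of cgraph as the new cached graph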
+ void add(const ggml_cgraph * cgraph) {
+ last_graph.resize(cgraph->n_nodes);
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
+ }
+ }
+
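+ // copies of the nodes of the last computed graph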
+ std::vector<ggml_tensor> last_graph;
+};
+
struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
+ graph_cache gc;
};
struct ggml_backend_rpc_buffer_context {
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
- std::vector<uint8_t> input;
- serialize_graph(rpc_ctx->device, cgraph, input);
- rpc_msg_graph_compute_rsp response;
- auto sock = get_socket(rpc_ctx->endpoint);
- bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
- RPC_STATUS_ASSERT(status);
- return (enum ggml_status)response.result;
+
+ GGML_ASSERT(cgraph->n_nodes > 0);
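+ // if the graph is identical to the last one we sent, request a
+ // recompute of the graph already stored on the server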
+ bool reuse = rpc_ctx->gc.is_cached(cgraph);
+ if (reuse) {
+ rpc_msg_graph_recompute_req request;
+ request.device = rpc_ctx->device;
+ auto sock = get_socket(rpc_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
+ RPC_STATUS_ASSERT(status);
+ } else {
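+ // graph changed: cache it locally and send the full serialized graph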
+ rpc_ctx->gc.add(cgraph);
+ std::vector<uint8_t> input;
+ serialize_graph(rpc_ctx->device, cgraph, input);
+ auto sock = get_socket(rpc_ctx->endpoint);
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
+ RPC_STATUS_ASSERT(status);
+ }
+ return GGML_STATUS_SUCCESS;
}
static ggml_backend_i ggml_backend_rpc_interface = {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .device = */ device,
- /* .name = */ dev_name
+ /* .name = */ dev_name,
+ /* .gc = */ {},
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
class rpc_server {
public:
- rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
- : backends(std::move(backends)), cache_dir(cache_dir) {
+ rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
+ : backends(std::move(all_backends)), cache_dir(cache_dir) {
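+ // one stored-graph slot per backend, indexed by device id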
+ stored_graphs.resize(backends.size());
}
~rpc_server();
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
- bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+ bool graph_compute(const std::vector<uint8_t> & input);
+ bool graph_recompute(const rpc_msg_graph_recompute_req & request);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
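+ // graph kept on the server for RPC_CMD_GRAPH_RECOMPUTE; the context
+ // keeps the deserialized tensors alive between calls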
+ struct stored_graph {
+ ggml_context_ptr ctx_ptr;
+ ggml_cgraph * graph;
+ };
+
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
std::vector<ggml_backend_t> backends;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
+ // store the last computed graph for each backend
+ std::vector<stored_graph> stored_graphs;
};
void rpc_server::hello(rpc_msg_hello_rsp & response) {
return result;
}
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
}
}
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
- response.result = status;
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
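+ // keep the deserialized graph and its context so that a later
+ // RPC_CMD_GRAPH_RECOMPUTE can re-run it without re-serialization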
+ stored_graphs[device].ctx_ptr.swap(ctx_ptr);
+ stored_graphs[device].graph = graph;
+ return true;
+}
+
+bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
+ uint32_t device = request.device;
+ if (device >= backends.size()) {
+ return false;
+ }
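+ // nothing to recompute if no graph has been stored for this device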
+ if (stored_graphs[device].graph == nullptr) {
+ return false;
+ }
+ ggml_cgraph * graph = stored_graphs[device].graph;
+ LOG_DBG("[%s] device: %u\n", __func__, device);
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
+ GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
return true;
}
if (!recv_msg(sockfd, input)) {
return;
}
- rpc_msg_graph_compute_rsp response;
- if (!server.graph_compute(input, response)) {
+ if (!server.graph_compute(input)) {
return;
}
- if (!send_msg(sockfd, &response, sizeof(response))) {
+ break;
+ }
+ case RPC_CMD_GRAPH_RECOMPUTE: {
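+ // re-run the last graph stored for the requested device; no response
+ // is sent back to the client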
+ rpc_msg_graph_recompute_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ if (!server.graph_recompute(request)) {
return;
}
break;