bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
+ bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
return true;
}
+bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
+ uint32_t dev_id = request.device;
+ if (dev_id >= backends.size()) {
+ return false;
+ }
+ size_t free, total;
+ ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
+ ggml_backend_dev_memory(dev, &free, &total);
+ response.free_mem = free;
+ response.total_mem = total;
+ LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
+ return true;
+}
+
rpc_server::~rpc_server() {
for (auto buffer : buffers) {
ggml_backend_buffer_free(buffer);
}
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
- sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
+ sockfd_t sockfd) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
- auto dev_id = request.device;
- if (dev_id >= backends.size()) {
+ rpc_msg_get_device_memory_rsp response;
+ if (!server.get_device_memory(request, response)) {
return;
}
- rpc_msg_get_device_memory_rsp response;
- response.free_mem = free_mem[dev_id];
- response.total_mem = total_mem[dev_id];
- LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
- response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
}
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
- size_t n_threads, size_t n_devices,
- ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
- if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
+ size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
+ if (n_devices == 0 || devices == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
- std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
- std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
+ size_t free, total;
+ ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
- total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
+ total / 1024 / 1024, free / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
}
printf("Accepted client connection\n");
fflush(stdout);
- rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
+ rpc_serve_client(backends, cache_dir, client_socket->fd);
printf("Client connection closed\n");
fflush(stdout);
}