]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
rpc : add support for multiple devices (#16276)
authorRadoslav Gerganov <redacted>
Sat, 4 Oct 2025 09:49:16 +0000 (12:49 +0300)
committerGitHub <redacted>
Sat, 4 Oct 2025 09:49:16 +0000 (12:49 +0300)
* rpc : add support for multiple devices

Allow rpc-server to expose multiple devices from a single endpoint.
Change RPC protocol to include device identifier where needed.

closes: #15210

* fixes

* use ggml_backend_reg_t

* address review comments

* fix llama-bench backend report

* address review comments, change device naming

* fix cmd order

common/arg.cpp
ggml/include/ggml-backend.h
ggml/include/ggml-rpc.h
ggml/src/ggml-backend-impl.h
ggml/src/ggml-rpc/ggml-rpc.cpp
tools/llama-bench/llama-bench.cpp
tools/rpc/rpc-server.cpp

index 577048c201b7692bdb9586c188038da2509bafa9..a020ac44132907ba808cc1ae6b0abaf11c75c484 100644 (file)
@@ -1615,18 +1615,14 @@ static void add_rpc_devices(const std::string & servers) {
     if (!rpc_reg) {
         throw std::invalid_argument("failed to find RPC backend");
     }
-    typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
-    ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-    if (!ggml_backend_rpc_add_device_fn) {
-        throw std::invalid_argument("failed to find RPC device add function");
+    typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
+    ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
+    if (!ggml_backend_rpc_add_server_fn) {
+        throw std::invalid_argument("failed to find RPC add server function");
     }
     for (const auto & server : rpc_servers) {
-        ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
-        if (dev) {
-            ggml_backend_device_register(dev);
-        } else {
-            throw std::invalid_argument("failed to register RPC device");
-        }
+        auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
+        ggml_backend_register(reg);
     }
 }
 
index 62b6d65e514452d129caf0e1af895a548b9fcd75..f1b740785914ed3577c6517937af0a34d4bcaaec 100644 (file)
@@ -215,6 +215,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
+
     GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Backend (reg) enumeration
index 1e674112767c9e9f6d641aa3c5fc2ff1629efa69..72eff0027351a91d7456f6340710154f5bea2ae0 100644 (file)
@@ -7,26 +7,25 @@
 extern "C" {
 #endif
 
-#define RPC_PROTO_MAJOR_VERSION    2
+#define RPC_PROTO_MAJOR_VERSION    3
 #define RPC_PROTO_MINOR_VERSION    0
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 
 // backend API
-GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
 GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
 
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
 
-GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
 
-GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
-                                                    const char * cache_dir,
-                                                    size_t free_mem, size_t total_mem);
+GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+                                                    size_t n_threads, size_t n_devices,
+                                                    ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
 
 GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
-
-GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
 
 #ifdef  __cplusplus
 }
index 07784d6f66ce666788674fe3df49125f57bbcc5f..6792ba986e8ed5041de450ddaab7fe2ba6af8446 100644 (file)
@@ -209,9 +209,6 @@ extern "C" {
         void * context;
     };
 
-    // Internal backend registry API
-    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-
     // Add backend dynamic loading support to the backend
 
     // Initialize the backend
index f99681c84cbabeec9f622815630dabc09cca1de3..1a8739e788e76098e408c3ac66e038944c683dec 100644 (file)
@@ -105,9 +105,12 @@ enum rpc_cmd {
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_HELLO,
+    RPC_CMD_DEVICE_COUNT,
     RPC_CMD_COUNT,
 };
 
+static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
+
 // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
 const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
 
@@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp {
     uint8_t patch;
 };
 
+struct rpc_msg_device_count_rsp {
+    uint32_t device_count;
+};
+
 struct rpc_msg_get_alloc_size_req {
+    uint32_t   device;
     rpc_tensor tensor;
 };
 
@@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req {
 };
 
 struct rpc_msg_alloc_buffer_req {
+    uint32_t device;
     uint64_t size;
 };
 
@@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp {
     uint64_t remote_size;
 };
 
+struct rpc_msg_get_alignment_req {
+    uint32_t device;
+};
+
 struct rpc_msg_get_alignment_rsp {
     uint64_t alignment;
 };
 
+struct rpc_msg_get_max_size_req {
+    uint32_t device;
+};
+
 struct rpc_msg_get_max_size_rsp {
     uint64_t max_size;
 };
@@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp {
     uint8_t result;
 };
 
+struct rpc_msg_get_device_memory_req {
+    uint32_t device;
+};
+
 struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
@@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() {
 
 struct ggml_backend_rpc_buffer_type_context {
     std::string endpoint;
+    uint32_t    device;
     std::string name;
-    size_t alignment;
-    size_t max_size;
+    size_t      alignment;
+    size_t      max_size;
 };
 
 struct ggml_backend_rpc_context {
     std::string endpoint;
+    uint32_t    device;
     std::string name;
 };
 
@@ -653,7 +676,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
 
 static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    rpc_msg_alloc_buffer_req request = {size};
+    rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
     rpc_msg_alloc_buffer_rsp response;
     auto sock = get_socket(buft_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
@@ -669,9 +692,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
     }
 }
 
-static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
+static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
+    rpc_msg_get_alignment_req request = {device};
     rpc_msg_get_alignment_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
     RPC_STATUS_ASSERT(status);
     return response.alignment;
 }
@@ -681,9 +705,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
     return buft_ctx->alignment;
 }
 
-static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
+static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
+    rpc_msg_get_max_size_req request = {device};
     rpc_msg_get_max_size_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
     RPC_STATUS_ASSERT(status);
     return response.max_size;
 }
@@ -700,7 +725,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
         auto sock = get_socket(buft_ctx->endpoint);
 
         rpc_msg_get_alloc_size_req request;
-
+        request.device = buft_ctx->device;
         request.tensor = serialize_tensor(tensor);
 
         rpc_msg_get_alloc_size_rsp response;
@@ -754,7 +779,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
     tensors.push_back(serialize_tensor(tensor));
 }
 
-static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
+static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
     uint32_t n_nodes = cgraph->n_nodes;
     std::vector<rpc_tensor> tensors;
     std::unordered_set<ggml_tensor*> visited;
@@ -762,24 +787,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
         add_tensor(cgraph->nodes[i], tensors, visited);
     }
     // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     uint32_t n_tensors = tensors.size();
-    int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+    int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
     output.resize(output_size, 0);
-    memcpy(output.data(), &n_nodes, sizeof(n_nodes));
+    uint8_t * dest = output.data();
+    memcpy(dest, &device, sizeof(device));
+    dest += sizeof(device);
+    memcpy(dest, &n_nodes, sizeof(n_nodes));
+    dest += sizeof(n_nodes);
     for (uint32_t i = 0; i < n_nodes; i++) {
-        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
+        memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
     }
-    uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
-    *out_ntensors = n_tensors;
-    rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
+    dest += n_nodes * sizeof(uint64_t);
+    memcpy(dest, &n_tensors, sizeof(n_tensors));
+    dest += sizeof(n_tensors);
+    rpc_tensor * out_tensors = (rpc_tensor *)dest;
     memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
 }
 
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     std::vector<uint8_t> input;
-    serialize_graph(cgraph, input);
+    serialize_graph(rpc_ctx->device, cgraph, input);
     rpc_msg_graph_compute_rsp response;
     auto sock = get_socket(rpc_ctx->endpoint);
     bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
@@ -804,12 +834,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .graph_optimize          = */ NULL,
 };
 
-ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
+ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
+    std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
     // NOTE: buffer types are allocated and never freed; this is by design
     static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
-    auto it = buft_map.find(endpoint);
+    auto it = buft_map.find(buft_name);
     if (it != buft_map.end()) {
         return it->second;
     }
@@ -818,34 +849,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
         GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
         return nullptr;
     }
-    size_t alignment = get_alignment(sock);
-    size_t max_size = get_max_size(sock);
+    size_t alignment = get_alignment(sock, device);
+    size_t max_size = get_max_size(sock, device);
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
         /* .endpoint  = */ endpoint,
-        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
+        /* .device    = */ device,
+        /* .name      = */ buft_name,
         /* .alignment = */ alignment,
         /* .max_size  = */ max_size
     };
-
+    auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
-        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
+        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
         /* .context = */ buft_ctx
     };
-    buft_map[endpoint] = buft;
+    buft_map[buft_name] = buft;
     return buft;
 }
 
-ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
+ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
+    std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-        /* .endpoint  = */ endpoint,
-        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
+        /* .endpoint = */ endpoint,
+        /* .device   = */ device,
+        /* .name     = */ dev_name
     };
-
+    auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_t backend = new ggml_backend {
         /* .guid    = */ ggml_backend_rpc_guid(),
         /* .iface   = */ ggml_backend_rpc_interface,
-        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
+        /* .device  = */ ggml_backend_reg_dev_get(reg, device),
         /* .context = */ ctx
     };
     return backend;
@@ -855,37 +889,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
 }
 
-static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
+static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
+    rpc_msg_get_device_memory_req request;
+    request.device = device;
     rpc_msg_get_device_memory_rsp response;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
     RPC_STATUS_ASSERT(status);
     *free = response.free_mem;
     *total = response.total_mem;
 }
 
-void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
+void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
     auto sock = get_socket(endpoint);
     if (sock == nullptr) {
         *free = 0;
         *total = 0;
         return;
     }
-    get_device_memory(sock, free, total);
+    get_device_memory(sock, device, free, total);
 }
 
 // RPC server-side implementation
 
 class rpc_server {
 public:
-    rpc_server(ggml_backend_t backend, const char * cache_dir)
-        : backend(backend), cache_dir(cache_dir) {
+    rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
+        : backends(std::move(backends)), cache_dir(cache_dir) {
     }
     ~rpc_server();
 
     void hello(rpc_msg_hello_rsp & response);
-    void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
-    void get_alignment(rpc_msg_get_alignment_rsp & response);
-    void get_max_size(rpc_msg_get_max_size_rsp & response);
+    bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
+    bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
+    bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
     bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
     bool free_buffer(const rpc_msg_free_buffer_req & request);
     bool buffer_clear(const rpc_msg_buffer_clear_req & request);
@@ -906,7 +942,7 @@ private:
                               std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
 
 
-    ggml_backend_t backend;
+    std::vector<ggml_backend_t> backends;
     const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
@@ -919,6 +955,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
 }
 
 bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
@@ -935,10 +975,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
         return false;
     }
-    LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
+    LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
     if (tensor->buffer == nullptr) {
         //No buffer allocated.
-        buft = ggml_backend_get_default_buffer_type(backend);
+        buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     } else {
         buft = tensor->buffer->buft;
     }
@@ -948,33 +988,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     return true;
 }
 
-void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
     response.remote_ptr = 0;
     response.remote_size = 0;
     if (buffer != nullptr) {
         response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
         response.remote_size = buffer->size;
-        LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
+        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
+            __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
         buffers.insert(buffer);
     } else {
-        LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
+        LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
     }
+    return true;
 }
 
-void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     size_t alignment = ggml_backend_buft_get_alignment(buft);
-    LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
+    LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
     response.alignment = alignment;
+    return true;
 }
 
-void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
-    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
+    uint32_t dev_id = request.device;
+    if (dev_id >= backends.size()) {
+        return false;
+    }
+    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
     size_t max_size = ggml_backend_buft_get_max_size(buft);
-    LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
+    LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
     response.max_size = max_size;
+    return true;
 }
 
 bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
@@ -1332,23 +1388,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
 
 bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
     // serialization format:
-    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
-    if (input.size() < sizeof(uint32_t)) {
+    // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    if (input.size() < 2*sizeof(uint32_t)) {
+        return false;
+    }
+    const uint8_t * src = input.data();
+    uint32_t device;
+    memcpy(&device, src, sizeof(device));
+    src += sizeof(device);
+    if (device >= backends.size()) {
         return false;
     }
     uint32_t n_nodes;
-    memcpy(&n_nodes, input.data(), sizeof(n_nodes));
-    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
+    memcpy(&n_nodes, src, sizeof(n_nodes));
+    src += sizeof(n_nodes);
+    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
         return false;
     }
-    const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
+    const uint64_t * nodes = (const uint64_t *)src;
+    src += n_nodes*sizeof(uint64_t);
     uint32_t n_tensors;
-    memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
-    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
+    memcpy(&n_tensors, src, sizeof(n_tensors));
+    src += sizeof(n_tensors);
+    if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
         return false;
     }
-    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
-    LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
+    const rpc_tensor * tensors = (const rpc_tensor *)src;
+    LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
 
     size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
 
@@ -1380,7 +1446,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
             return false;
         }
     }
-    ggml_status status = ggml_backend_graph_compute(backend, graph);
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
     response.result = status;
     return true;
 }
@@ -1391,9 +1457,9 @@ rpc_server::~rpc_server() {
     }
 }
 
-static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
-                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
-    rpc_server server(backend, cache_dir);
+static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
+                             sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
+    rpc_server server(backends, cache_dir);
     uint8_t cmd;
     if (!recv_data(sockfd, &cmd, 1)) {
         return;
@@ -1425,13 +1491,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                 // HELLO command is handled above
                 return;
             }
+            case RPC_CMD_DEVICE_COUNT: {
+                if (!recv_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_device_count_rsp response;
+                response.device_count = backends.size();
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_ALLOC_BUFFER: {
                 rpc_msg_alloc_buffer_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {
                     return;
                 }
                 rpc_msg_alloc_buffer_rsp response;
-                server.alloc_buffer(request, response);
+                if (!server.alloc_buffer(request, response)) {
+                    return;
+                }
                 if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
@@ -1452,22 +1531,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                 break;
             }
             case RPC_CMD_GET_ALIGNMENT: {
-                if (!recv_msg(sockfd, nullptr, 0)) {
+                rpc_msg_get_alignment_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
                     return;
                 }
                 rpc_msg_get_alignment_rsp response;
-                server.get_alignment(response);
+                if (!server.get_alignment(request, response)) {
+                    return;
+                }
                 if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
                 break;
             }
             case RPC_CMD_GET_MAX_SIZE: {
-                if (!recv_msg(sockfd, nullptr, 0)) {
+                rpc_msg_get_max_size_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
                     return;
                 }
                 rpc_msg_get_max_size_rsp response;
-                server.get_max_size(response);
+                if (!server.get_max_size(request, response)) {
+                    return;
+                }
                 if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
@@ -1593,12 +1678,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                 break;
             }
             case RPC_CMD_GET_DEVICE_MEMORY: {
-                if (!recv_msg(sockfd, nullptr, 0)) {
+                rpc_msg_get_device_memory_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                auto dev_id = request.device;
+                if (dev_id >= backends.size()) {
                     return;
                 }
                 rpc_msg_get_device_memory_rsp response;
-                response.free_mem = free_mem;
-                response.total_mem = total_mem;
+                response.free_mem = free_mem[dev_id];
+                response.total_mem = total_mem[dev_id];
+                LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
+                    response.free_mem, response.total_mem);
                 if (!send_msg(sockfd, &response, sizeof(response))) {
                     return;
                 }
@@ -1612,16 +1704,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
     }
 }
 
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
-                                   const char * cache_dir,
-                                   size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+                                   size_t n_threads, size_t n_devices,
+                                   ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
+    if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
+        fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
+        return;
+    }
+    std::vector<ggml_backend_t> backends;
+    std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
+    std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
     printf("Starting RPC server v%d.%d.%d\n",
         RPC_PROTO_MAJOR_VERSION,
         RPC_PROTO_MINOR_VERSION,
         RPC_PROTO_PATCH_VERSION);
     printf("  endpoint       : %s\n", endpoint);
     printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
-    printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));
+    printf("Devices:\n");
+    for (size_t i = 0; i < n_devices; i++) {
+        auto dev = devices[i];
+        printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+               total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
+        auto backend = ggml_backend_dev_init(dev, nullptr);
+        if (!backend) {
+            fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
+            return;
+        }
+        backends.push_back(backend);
+        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+        if (reg) {
+            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+            if (ggml_backend_set_n_threads_fn) {
+                ggml_backend_set_n_threads_fn(backend, n_threads);
+            }
+        }
+    }
 
     std::string host;
     int port;
@@ -1649,22 +1766,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
             fprintf(stderr, "Failed to accept client connection\n");
             return;
         }
-        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+        printf("Accepted client connection\n");
         fflush(stdout);
-        rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
+        rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
         printf("Client connection closed\n");
         fflush(stdout);
     }
 #ifdef _WIN32
     WSACleanup();
 #endif
+    for (auto backend : backends) {
+        ggml_backend_free(backend);
+    }
 }
 
 // device interface
 
 struct ggml_backend_rpc_device_context {
     std::string endpoint;
+    uint32_t    device;
     std::string name;
+    std::string description;
 };
 
 static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
@@ -1676,15 +1798,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
 static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 
-    return ctx->name.c_str();
+    return ctx->description.c_str();
 }
 
 static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 
-    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
-
-    GGML_UNUSED(dev);
+    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
 }
 
 static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
@@ -1710,7 +1830,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
 static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 
-    return ggml_backend_rpc_init(ctx->endpoint.c_str());
+    return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
 
     GGML_UNUSED(params);
 }
@@ -1718,7 +1838,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
 static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
 
-    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
 
     GGML_UNUSED(dev);
 }
@@ -1736,7 +1856,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
     }
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
-    return buft_ctx->endpoint == dev_ctx->endpoint;
+    return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
 }
 
 static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
@@ -1759,28 +1879,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
 
 // backend reg interface
 
-static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
-    return "RPC";
+struct ggml_backend_rpc_reg_context {
+    std::string                     name;
+    std::vector<ggml_backend_dev_t> devices;
+};
 
-    GGML_UNUSED(reg);
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    return ctx ? ctx->name.c_str() : "RPC";
 }
 
 static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
-    return 0;
-
-    GGML_UNUSED(reg);
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    return ctx ? ctx->devices.size() : 0;
 }
 
 static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
-    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(index);
+    ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+    if (ctx == nullptr) {
+        GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
+    } else {
+        GGML_ASSERT(index < ctx->devices.size());
+        return ctx->devices[index];
+    }
 }
 
 static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
-        return (void *)ggml_backend_rpc_add_device;
+    if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
+        return (void *)ggml_backend_rpc_add_server;
     }
     if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
         return (void *)ggml_backend_rpc_start_server;
@@ -1807,30 +1933,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
     return &ggml_backend_rpc_reg;
 }
 
-ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
-    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
+static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
+    auto sock = get_socket(endpoint);
+    rpc_msg_device_count_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
+    RPC_STATUS_ASSERT(status);
+    return response.device_count;
+}
 
+static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
+    /* .get_name          = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count  = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device        = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address  = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
+    static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
     static std::mutex mutex;
+    static uint32_t dev_id = 0;
     std::lock_guard<std::mutex> lock(mutex);
-
-    if (dev_map.find(endpoint) != dev_map.end()) {
-        return dev_map[endpoint];
+    if (reg_map.find(endpoint) != reg_map.end()) {
+        return reg_map[endpoint];
     }
-
-    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
-        /* .endpoint = */ endpoint,
-        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
-    };
-
-    ggml_backend_dev_t dev = new ggml_backend_device {
-        /* .iface   = */ ggml_backend_rpc_device_i,
-        /* .reg     = */ ggml_backend_rpc_reg(),
-        /* .context = */ ctx,
+    uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
+    if (dev_count == 0) {
+        return nullptr;
+    }
+    ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
+    ctx->name = "RPC[" + std::string(endpoint) + "]";
+    for (uint32_t ind = 0; ind < dev_count; ind++) {
+        std::string dev_name = "RPC" + std::to_string(dev_id);
+        std::string dev_desc = std::string(endpoint);
+        ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
+            /* .endpoint    = */ endpoint,
+            /* .device      = */ ind,
+            /* .name        = */ dev_name,
+            /* .description = */ dev_desc
+        };
+
+        ggml_backend_dev_t dev = new ggml_backend_device {
+            /* .iface   = */ ggml_backend_rpc_device_i,
+            /* .reg     = */ ggml_backend_rpc_reg(),
+            /* .context = */ dev_ctx,
+        };
+        ctx->devices.push_back(dev);
+        dev_id++;
+    }
+    ggml_backend_reg_t reg = new ggml_backend_reg {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_rpc_reg_interface,
+        /* .context     = */ ctx
     };
-
-    dev_map[endpoint] = dev;
-
-    return dev;
+    reg_map[endpoint] = reg;
+    return reg;
 }
 
+
 GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
index 275ba367c02f1d6e03dfbe1f56a48766f8eeda1f..89bc01b48546c220163185192154e8350e1a2f31 100644 (file)
@@ -168,7 +168,7 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
     return devices;
 }
 
-static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) {
+static void register_rpc_server_list(const std::string & servers) {
     auto rpc_servers = string_split<std::string>(servers, ',');
     if (rpc_servers.empty()) {
         throw std::invalid_argument("no RPC servers specified");
@@ -179,36 +179,15 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin
         throw std::invalid_argument("failed to find RPC backend");
     }
 
-    using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
-    auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-    if (!ggml_backend_rpc_add_device_fn) {
-        throw std::invalid_argument("failed to find RPC device add function");
+    using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint);
+    auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
+    if (!ggml_backend_rpc_add_server_fn) {
+        throw std::invalid_argument("failed to find RPC add server function");
     }
-
-    static std::unordered_set<std::string> registered;
-    std::vector<ggml_backend_dev_t> devices;
     for (const auto & server : rpc_servers) {
-        ggml_backend_dev_t dev = nullptr;
-
-        std::string name = string_format("RPC[%s]", server.c_str());
-
-        if (registered.find(server) != registered.end()) {
-            dev = ggml_backend_dev_by_name(name.c_str());
-        }
-
-        if (!dev) {
-            dev = ggml_backend_rpc_add_device_fn(server.c_str());
-            if (!dev) {
-                throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
-            }
-            ggml_backend_device_register(dev);
-            registered.insert(server);
-        }
-
-        devices.push_back(dev);
+        auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
+        ggml_backend_register(reg);
     }
-
-    return devices;
 }
 
 static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
@@ -714,7 +693,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                     break;
                 }
                 try {
-                    register_rpc_device_list(argv[i]);
+                    register_rpc_server_list(argv[i]);
                 } catch (const std::exception & e) {
                     fprintf(stderr, "error: %s\n", e.what());
                     invalid_param = true;
@@ -1368,13 +1347,23 @@ struct test {
 
     static std::string get_backend() {
         std::vector<std::string> backends;
+        bool                     rpc_used = false;
         for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
             auto *      reg  = ggml_backend_reg_get(i);
             std::string name = ggml_backend_reg_name(reg);
-            if (name != "CPU") {
-                backends.push_back(ggml_backend_reg_name(reg));
+            if (string_starts_with(name, "RPC")) {
+                if (ggml_backend_reg_dev_count(reg) > 0) {
+                    rpc_used = true;
+                }
+            } else {
+                if (name != "CPU") {
+                    backends.push_back(ggml_backend_reg_name(reg));
+                }
             }
         }
+        if (rpc_used) {
+            backends.push_back("RPC");
+        }
         return backends.empty() ? "CPU" : join(backends, ",");
     }
 
index dc8e077f34a73412a144f9ec32c87ea43c72ea7e..088515612772ddae0e596ed3f8336060078e4e0a 100644 (file)
@@ -22,6 +22,7 @@
 #include <filesystem>
 #include <algorithm>
 #include <thread>
+#include <regex>
 
 namespace fs = std::filesystem;
 
@@ -131,24 +132,24 @@ static std::string fs_get_cache_directory() {
 }
 
 struct rpc_server_params {
-    std::string host        = "127.0.0.1";
-    int         port        = 50052;
-    size_t      backend_mem = 0;
-    bool        use_cache   = false;
-    int         n_threads   = std::max(1U, std::thread::hardware_concurrency()/2);
-    std::string device;
+    std::string              host        = "127.0.0.1";
+    int                      port        = 50052;
+    bool                     use_cache   = false;
+    int                      n_threads   = std::max(1U, std::thread::hardware_concurrency()/2);
+    std::vector<std::string> devices;
+    std::vector<size_t>      dev_mem;
 };
 
 static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
     fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h, --help                show this help message and exit\n");
-    fprintf(stderr, "  -t,      --threads        number of threads for the CPU backend (default: %d)\n", params.n_threads);
-    fprintf(stderr, "  -d DEV,  --device         device to use\n");
-    fprintf(stderr, "  -H HOST, --host HOST      host to bind to (default: %s)\n", params.host.c_str());
-    fprintf(stderr, "  -p PORT, --port PORT      port to bind to (default: %d)\n", params.port);
-    fprintf(stderr, "  -m MEM,  --mem MEM        backend memory size (in MB)\n");
-    fprintf(stderr, "  -c,      --cache          enable local file cache\n");
+    fprintf(stderr, "  -h, --help                       show this help message and exit\n");
+    fprintf(stderr, "  -t, --threads N                  number of threads for the CPU device (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -d, --device <dev1,dev2,...>     comma-separated list of devices\n");
+    fprintf(stderr, "  -H, --host HOST                  host to bind to (default: %s)\n", params.host.c_str());
+    fprintf(stderr, "  -p, --port PORT                  port to bind to (default: %d)\n", params.port);
+    fprintf(stderr, "  -m, --mem <M1,M2,...>            memory size for each device (in MB)\n");
+    fprintf(stderr, "  -c, --cache                      enable local file cache\n");
     fprintf(stderr, "\n");
 }
 
@@ -174,17 +175,17 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
             if (++i >= argc) {
                 return false;
             }
-            params.device = argv[i];
-            if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) {
-                fprintf(stderr, "error: unknown device: %s\n", params.device.c_str());
-                fprintf(stderr, "available devices:\n");
-                for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
-                    auto * dev = ggml_backend_dev_get(i);
-                    size_t free, total;
-                    ggml_backend_dev_memory(dev, &free, &total);
-                    printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
+            const std::regex regex{ R"([,/]+)" };
+            std::string dev_str = argv[i];
+            std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
+            std::sregex_token_iterator end;
+            for ( ; iter != end; ++iter) {
+                try {
+                    params.devices.push_back(*iter);
+                } catch (const std::exception & ) {
+                    fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
+                    return false;
                 }
-                return false;
             }
         } else if (arg == "-p" || arg == "--port") {
             if (++i >= argc) {
@@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
             if (++i >= argc) {
                 return false;
             }
-            params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
+            const std::regex regex{ R"([,/]+)" };
+            std::string mem_str = argv[i];
+            std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
+            std::sregex_token_iterator end;
+            for ( ; iter != end; ++iter) {
+                try {
+                    size_t mem = std::stoul(*iter) * 1024 * 1024;
+                    params.dev_mem.push_back(mem);
+                } catch (const std::exception & ) {
+                    fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
+                    return false;
+                }
+            }
         } else if (arg == "-h" || arg == "--help") {
             print_usage(argc, argv, params);
             exit(0);
@@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
     return true;
 }
 
-static ggml_backend_t create_backend(const rpc_server_params & params) {
-    ggml_backend_t backend = nullptr;
-
-    if (!params.device.empty()) {
-        ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str());
-        if (dev) {
-            backend = ggml_backend_dev_init(dev, nullptr);
-            if (!backend) {
-                fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str());
-                return nullptr;
+static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
+    std::vector<ggml_backend_dev_t> devices;
+    if (!params.devices.empty()) {
+        for (auto device : params.devices) {
+            ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
+            if (dev) {
+                devices.push_back(dev);
+            } else {
+                fprintf(stderr, "error: unknown device: %s\n", device.c_str());
+                fprintf(stderr, "available devices:\n");
+                for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+                    auto * dev = ggml_backend_dev_get(i);
+                    size_t free, total;
+                    ggml_backend_dev_memory(dev, &free, &total);
+                    printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
+                }
+                return {};
             }
         }
     }
 
-    if (!backend) {
-        backend = ggml_backend_init_best();
-    }
-
-    if (backend) {
-        fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
-
-        // set the number of threads
-        ggml_backend_dev_t dev = ggml_backend_get_device(backend);
-        ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
-        if (reg) {
-            auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
-            if (ggml_backend_set_n_threads_fn) {
-                ggml_backend_set_n_threads_fn(backend, params.n_threads);
+    // Try non-CPU devices first
+    if (devices.empty()) {
+        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
+                devices.push_back(dev);
             }
         }
     }
 
-    return backend;
-}
+    // If there are no accelerators, fallback to CPU device
+    if (devices.empty()) {
+        ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        if (dev) {
+            devices.push_back(dev);
+        }
+    }
 
-static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
-    ggml_backend_dev_t dev = ggml_backend_get_device(backend);
-    GGML_ASSERT(dev != nullptr);
-    ggml_backend_dev_memory(dev, free_mem, total_mem);
+    return devices;
 }
 
 int main(int argc, char * argv[]) {
@@ -273,18 +287,23 @@ int main(int argc, char * argv[]) {
         fprintf(stderr, "\n");
     }
 
-    ggml_backend_t backend = create_backend(params);
-    if (!backend) {
-        fprintf(stderr, "Failed to create backend\n");
+    auto devices = get_devices(params);
+    if (devices.empty()) {
+        fprintf(stderr, "No devices found\n");
         return 1;
     }
     std::string endpoint = params.host + ":" + std::to_string(params.port);
-    size_t free_mem, total_mem;
-    if (params.backend_mem > 0) {
-        free_mem = params.backend_mem;
-        total_mem = params.backend_mem;
-    } else {
-        get_backend_memory(backend, &free_mem, &total_mem);
+    std::vector<size_t> free_mem, total_mem;
+    for (size_t i = 0; i < devices.size(); i++) {
+        if (i < params.dev_mem.size()) {
+            free_mem.push_back(params.dev_mem[i]);
+            total_mem.push_back(params.dev_mem[i]);
+        } else {
+            size_t free, total;
+            ggml_backend_dev_memory(devices[i], &free, &total);
+            free_mem.push_back(free);
+            total_mem.push_back(total);
+        }
     }
     const char * cache_dir = nullptr;
     std::string cache_dir_str;
@@ -309,8 +328,7 @@ int main(int argc, char * argv[]) {
         return 1;
     }
 
-    start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
-
-    ggml_backend_free(backend);
+    start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
+        devices.data(), free_mem.data(), total_mem.data());
     return 0;
 }