bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
+ bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
return true;
}
+bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
+ uint32_t dev_id = request.device;
+ if (dev_id >= backends.size()) {
+ return false;
+ }
+ size_t free, total;
+ ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
+ ggml_backend_dev_memory(dev, &free, &total);
+ response.free_mem = free;
+ response.total_mem = total;
+ LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
+ return true;
+}
+
rpc_server::~rpc_server() {
for (auto buffer : buffers) {
ggml_backend_buffer_free(buffer);
}
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
- sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
+ sockfd_t sockfd) {
rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
- auto dev_id = request.device;
- if (dev_id >= backends.size()) {
+ rpc_msg_get_device_memory_rsp response;
+ if (!server.get_device_memory(request, response)) {
return;
}
- rpc_msg_get_device_memory_rsp response;
- response.free_mem = free_mem[dev_id];
- response.total_mem = total_mem[dev_id];
- LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
- response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
}
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
- size_t n_threads, size_t n_devices,
- ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
- if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
+ size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
+ if (n_devices == 0 || devices == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
- std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
- std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
+ size_t free, total;
+ ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
- total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
+ total / 1024 / 1024, free / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
}
printf("Accepted client connection\n");
fflush(stdout);
- rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
+ rpc_serve_client(backends, cache_dir, client_socket->fd);
printf("Client connection closed\n");
fflush(stdout);
}
bool use_cache = false;
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
std::vector<std::string> devices;
- std::vector<size_t> dev_mem;
};
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
- fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
}
}
} else if (arg == "-c" || arg == "--cache") {
params.use_cache = true;
- } else if (arg == "-m" || arg == "--mem") {
- if (++i >= argc) {
- return false;
- }
- const std::regex regex{ R"([,/]+)" };
- std::string mem_str = argv[i];
- std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
- std::sregex_token_iterator end;
- for ( ; iter != end; ++iter) {
- try {
- size_t mem = std::stoul(*iter) * 1024 * 1024;
- params.dev_mem.push_back(mem);
- } catch (const std::exception & ) {
- fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
- return false;
- }
- }
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params);
exit(0);
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
- std::vector<size_t> free_mem, total_mem;
- for (size_t i = 0; i < devices.size(); i++) {
- if (i < params.dev_mem.size()) {
- free_mem.push_back(params.dev_mem[i]);
- total_mem.push_back(params.dev_mem[i]);
- } else {
- size_t free, total;
- ggml_backend_dev_memory(devices[i], &free, &total);
- free_mem.push_back(free);
- total_mem.push_back(total);
- }
- }
const char * cache_dir = nullptr;
std::string cache_dir_str;
if (params.use_cache) {
return 1;
}
- start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
- devices.data(), free_mem.data(), total_mem.data());
+ start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
return 0;
}