RPC_CMD_GET_DEVICE_MEMORY,
RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE,
+ RPC_CMD_HELLO,
RPC_CMD_COUNT,
};
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+struct rpc_msg_hello_rsp {
+ uint8_t major;
+ uint8_t minor;
+ uint8_t patch;
+};
+
struct rpc_msg_get_alloc_size_req {
rpc_tensor tensor;
};
// RPC client-side implementation
+static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
+ rpc_msg_hello_rsp response;
+ bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
+ GGML_ASSERT(status);
+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
+ fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+ return false;
+ }
+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
+ fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+ }
+ return true;
+}
+
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (sock == nullptr) {
return nullptr;
}
+ if (!check_server_version(sock)) {
+ return nullptr;
+ }
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
}
~rpc_server();
+ void hello(rpc_msg_hello_rsp & response);
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
void get_alignment(rpc_msg_get_alignment_rsp & response);
void get_max_size(rpc_msg_get_max_size_rsp & response);
std::unordered_set<ggml_backend_buffer_t> buffers;
};
+void rpc_server::hello(rpc_msg_hello_rsp & response) {
+ response.major = RPC_PROTO_MAJOR_VERSION;
+ response.minor = RPC_PROTO_MINOR_VERSION;
+ response.patch = RPC_PROTO_PATCH_VERSION;
+ GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
+}
+
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
ggml_backend_buffer_type_t buft;
struct ggml_init_params params {
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
rpc_server server(backend, cache_dir);
+ uint8_t cmd;
+ if (!recv_data(sockfd, &cmd, 1)) {
+ return;
+ }
+ // the first command sent by the client must be HELLO
+ if (cmd != RPC_CMD_HELLO) {
+ fprintf(stderr, "Expected HELLO command, update client\n");
+ return;
+ }
+ if (!recv_msg(sockfd, nullptr, 0)) {
+ return;
+ }
+ rpc_msg_hello_rsp response;
+ server.hello(response);
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
while (true) {
- uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
break;
}
switch (cmd) {
+ case RPC_CMD_HELLO: {
+ // HELLO command is handled above
+ return;
+ }
case RPC_CMD_ALLOC_BUFFER: {
rpc_msg_alloc_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {