# include <unistd.h>
#endif
#include <cstring>
+#include <fstream>
+#include <filesystem>
+
+namespace fs = std::filesystem;
#ifdef _WIN32
typedef SOCKET sockfd_t;
RPC_CMD_FREE_BUFFER,
RPC_CMD_BUFFER_CLEAR,
RPC_CMD_SET_TENSOR,
+ RPC_CMD_SET_TENSOR_HASH,
RPC_CMD_GET_TENSOR,
RPC_CMD_COPY_TENSOR,
RPC_CMD_GRAPH_COMPUTE,
RPC_CMD_COUNT,
};
+// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
+const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
+
struct rpc_msg_get_alloc_size_req {
rpc_tensor tensor;
};
uint8_t value;
};
+struct rpc_msg_set_tensor_hash_rsp {
+ uint8_t result;
+};
+
struct rpc_msg_get_tensor_req {
rpc_tensor tensor;
uint64_t offset;
// RPC helper functions
+// Computes FNV-1a hash of the data
+static uint64_t fnv_hash(const uint8_t * data, size_t len) {
+ const uint64_t fnv_prime = 0x100000001b3ULL;
+ uint64_t hash = 0xcbf29ce484222325ULL;
+
+ for (size_t i = 0; i < len; ++i) {
+ hash ^= data[i];
+ hash *= fnv_prime;
+ }
+ return hash;
+}
+
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
if (fd == INVALID_SOCKET) {
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+ rpc_tensor rpc_tensor = serialize_tensor(tensor);
+ if (size > HASH_THRESHOLD) {
+ // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
+ size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
+ std::vector<uint8_t> input(input_size, 0);
+ uint64_t hash = fnv_hash((const uint8_t*)data, size);
+ memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
+ memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
+ memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
+ rpc_msg_set_tensor_hash_rsp response;
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
+ GGML_ASSERT(status);
+ if (response.result) {
+ // the server has the same data, no need to send it
+ return;
+ }
+ }
+ // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
std::vector<uint8_t> input(input_size, 0);
- rpc_tensor rpc_tensor = serialize_tensor(tensor);
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
class rpc_server {
public:
- rpc_server(ggml_backend_t backend) : backend(backend) {}
+ rpc_server(ggml_backend_t backend, const char * cache_dir)
+ : backend(backend), cache_dir(cache_dir) {
+ }
~rpc_server();
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
bool free_buffer(const rpc_msg_free_buffer_req & request);
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
bool set_tensor(const std::vector<uint8_t> & input);
+ bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
private:
+ bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
ggml_tensor * create_node(uint64_t id,
struct ggml_context * ctx,
ggml_backend_t backend;
+ const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
};
}
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
+ if (cache_dir && size > HASH_THRESHOLD) {
+ uint64_t hash = fnv_hash((const uint8_t*)data, size);
+ char hash_str[17];
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+ // save to cache_dir/hash_str
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
+ std::ofstream ofs(cache_file, std::ios::binary);
+ ofs.write((const char *)data, size);
+ printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
+ }
ggml_backend_tensor_set(tensor, data, offset, size);
ggml_free(ctx);
return true;
}
+bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
+ if (!cache_dir) {
+ return false;
+ }
+ char hash_str[17];
+ snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
+ fs::path cache_file = fs::path(cache_dir) / hash_str;
+ if (!fs::exists(cache_file)) {
+ return false;
+ }
+ std::ifstream ifs(cache_file, std::ios::binary);
+ ifs.seekg(0, std::ios::end);
+ size_t size = ifs.tellg();
+ ifs.seekg(0, std::ios::beg);
+ data.resize(size);
+ ifs.read((char *)data.data(), size);
+ return true;
+}
+
+bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
+{
+ // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
+ if (input.size() != sizeof(rpc_tensor) + 16) {
+ return false;
+ }
+ const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
+ uint64_t offset;
+ memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
+ const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
+ std::vector<uint8_t> cached_file;
+ if (!get_cached_file(*hash, cached_file)) {
+ response.result = 0;
+ return true;
+ }
+ size_t size = cached_file.size();
+ struct ggml_init_params params {
+ /*.mem_size =*/ ggml_tensor_overhead(),
+ /*.mem_buffer =*/ NULL,
+ /*.no_alloc =*/ true,
+ };
+ struct ggml_context * ctx = ggml_init(params);
+ ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+ if (tensor == nullptr) {
+ GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
+ ggml_free(ctx);
+ return false;
+ }
+ GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
+
+ // sanitize tensor->data
+ {
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
+
+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
+ GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+ }
+ }
+ ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
+ response.result = 1;
+ ggml_free(ctx);
+ return true;
+}
+
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
}
}
-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
- rpc_server server(backend);
+static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
+ sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+ rpc_server server(backend, cache_dir);
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
}
break;
}
+ case RPC_CMD_SET_TENSOR_HASH: {
+ std::vector<uint8_t> input;
+ if (!recv_msg(sockfd, input)) {
+ return;
+ }
+ rpc_msg_set_tensor_hash_rsp response;
+ if (!server.set_tensor_hash(input, response)) {
+ return;
+ }
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
+ break;
+ }
case RPC_CMD_INIT_TENSOR: {
rpc_msg_init_tensor_req request;
if (!recv_msg(sockfd, &request,sizeof(request))) {
}
}
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
+ const char * cache_dir,
+ size_t free_mem, size_t total_mem) {
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
fflush(stdout);
- rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+ rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
fflush(stdout);
}