// RPC commands
enum rpc_cmd {
- ALLOC_BUFFER = 0,
- GET_ALIGNMENT,
- GET_MAX_SIZE,
- BUFFER_GET_BASE,
- FREE_BUFFER,
- BUFFER_CLEAR,
- SET_TENSOR,
- GET_TENSOR,
- COPY_TENSOR,
- GRAPH_COMPUTE,
- GET_DEVICE_MEMORY,
+ RPC_CMD_ALLOC_BUFFER = 0,
+ RPC_CMD_GET_ALIGNMENT,
+ RPC_CMD_GET_MAX_SIZE,
+ RPC_CMD_BUFFER_GET_BASE,
+ RPC_CMD_FREE_BUFFER,
+ RPC_CMD_BUFFER_CLEAR,
+ RPC_CMD_SET_TENSOR,
+ RPC_CMD_GET_TENSOR,
+ RPC_CMD_COPY_TENSOR,
+ RPC_CMD_GRAPH_COMPUTE,
+ RPC_CMD_GET_DEVICE_MEMORY,
+ RPC_CMD_COUNT,
};
// RPC data structures
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.empty());
delete ctx;
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | base_ptr (8 bytes) |
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
GGML_ASSERT(status);
}
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == size);
// output serialization format: | data (size bytes) |
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
GGML_ASSERT(status);
// output serialization format: | result (1 byte) |
GGML_ASSERT(output.size() == 1);
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
GGML_ASSERT(status);
}
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
auto sock = get_socket(buft_ctx->endpoint);
- bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
+ bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | alignment (8 bytes) |
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | max_size (8 bytes) |
serialize_graph(cgraph, input);
std::vector<uint8_t> output;
auto sock = get_socket(rpc_ctx->endpoint);
- bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0];
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
- bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | free (8 bytes) | total (8 bytes) |
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
+ if (cmd >= RPC_CMD_COUNT) {
+ // fail fast if the command is invalid
+ fprintf(stderr, "Unknown command: %d\n", cmd);
+ break;
+ }
std::vector<uint8_t> input;
std::vector<uint8_t> output;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
- input.resize(input_size);
+ try {
+ input.resize(input_size);
+ } catch (const std::bad_alloc & e) {
+ fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
+ break;
+ }
if (!recv_data(sockfd, input.data(), input_size)) {
break;
}
bool ok = true;
switch (cmd) {
- case ALLOC_BUFFER: {
+ case RPC_CMD_ALLOC_BUFFER: {
ok = server.alloc_buffer(input, output);
break;
}
- case GET_ALIGNMENT: {
+ case RPC_CMD_GET_ALIGNMENT: {
server.get_alignment(output);
break;
}
- case GET_MAX_SIZE: {
+ case RPC_CMD_GET_MAX_SIZE: {
server.get_max_size(output);
break;
}
- case BUFFER_GET_BASE: {
+ case RPC_CMD_BUFFER_GET_BASE: {
ok = server.buffer_get_base(input, output);
break;
}
- case FREE_BUFFER: {
+ case RPC_CMD_FREE_BUFFER: {
ok = server.free_buffer(input);
break;
}
- case BUFFER_CLEAR: {
+ case RPC_CMD_BUFFER_CLEAR: {
ok = server.buffer_clear(input);
break;
}
- case SET_TENSOR: {
+ case RPC_CMD_SET_TENSOR: {
ok = server.set_tensor(input);
break;
}
- case GET_TENSOR: {
+ case RPC_CMD_GET_TENSOR: {
ok = server.get_tensor(input, output);
break;
}
- case COPY_TENSOR: {
+ case RPC_CMD_COPY_TENSOR: {
ok = server.copy_tensor(input, output);
break;
}
- case GRAPH_COMPUTE: {
+ case RPC_CMD_GRAPH_COMPUTE: {
ok = server.graph_compute(input, output);
break;
}
- case GET_DEVICE_MEMORY: {
+ case RPC_CMD_GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &free_mem, sizeof(free_mem));
return;
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+ fflush(stdout);
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
+ fflush(stdout);
}
#ifdef _WIN32
WSACleanup();