]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
rpc : track allocated buffers (llama/7411)
authorRadoslav Gerganov <redacted>
Mon, 20 May 2024 13:36:55 +0000 (16:36 +0300)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
* rpc : track allocated buffers

ref: #7407

* rpc : pack rpc_tensor tightly

src/ggml-rpc.cpp

index 4a9bfa52d87b124304613318b0f3e2562c512e4a..cc1d3ace1ddac012b87d2ac31aac81111f4d15b6 100644 (file)
@@ -56,6 +56,7 @@ struct socket_t {
 };
 
 // ggml_tensor is serialized into rpc_tensor
+#pragma pack(push, 1)
 struct rpc_tensor {
     uint64_t id;
     uint32_t type;
@@ -71,6 +72,7 @@ struct rpc_tensor {
     uint64_t data;
     char name[GGML_MAX_NAME];
 };
+#pragma pack(pop)
 
 // RPC commands
 enum rpc_cmd {
@@ -340,23 +342,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     return result;
 }
 
-static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
-    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
-        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
-    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
-        result->nb[i] = tensor->nb[i];
-    }
-    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
-    result->op = (ggml_op) tensor->op;
-    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
-        result->op_params[i] = tensor->op_params[i];
-    }
-    result->flags = tensor->flags;
-    result->data = reinterpret_cast<void *>(tensor->data);
-    ggml_set_name(result, tensor->name);
-    return result;
-}
-
 GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     UNUSED(buffer);
     if (ggml_is_quantized(tensor->type)) {
@@ -465,13 +450,15 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
     size_t remote_size;
     memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
-
-    ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
-        ggml_backend_rpc_buffer_interface,
-        new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
-        remote_size);
-
-    return buffer;
+    if (remote_ptr != 0) {
+        ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
+            ggml_backend_rpc_buffer_interface,
+            new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
+            remote_size);
+        return buffer;
+    } else {
+        return nullptr;
+    }
 }
 
 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
@@ -658,7 +645,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
         }
     }
 #endif
-    GGML_PRINT_DEBUG("Connecting to %s\n", endpoint);
+    fprintf(stderr, "Connecting to %s\n", endpoint);
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -731,22 +718,61 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
 
 // RPC server-side implementation
 
-static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+class rpc_server {
+public:
+    rpc_server(ggml_backend_t backend) : backend(backend) {}
+    ~rpc_server();
+
+    bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    void get_alignment(std::vector<uint8_t> & output);
+    void get_max_size(std::vector<uint8_t> & output);
+    bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool free_buffer(const std::vector<uint8_t> & input);
+    bool buffer_clear(const std::vector<uint8_t> & input);
+    bool set_tensor(const std::vector<uint8_t> & input);
+    bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+
+private:
+    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
+    ggml_tensor * create_node(uint64_t id,
+                              struct ggml_context * ctx,
+                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
+                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
+
+
+    ggml_backend_t backend;
+    std::unordered_set<ggml_backend_buffer_t> buffers;
+};
+
+bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // input serialization format: | size (8 bytes) |
+    if (input.size() != sizeof(uint64_t)) {
+        return false;
+    }
     uint64_t size;
     memcpy(&size, input.data(), sizeof(size));
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
-    uint64_t remote_ptr = reinterpret_cast<uint64_t>(buffer);
-    uint64_t remote_size = buffer->size;
-    GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+    uint64_t remote_ptr = 0;
+    uint64_t remote_size = 0;
+    if (buffer != nullptr) {
+        remote_ptr = reinterpret_cast<uint64_t>(buffer);
+        remote_size = buffer->size;
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+        buffers.insert(buffer);
+    } else {
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
+    }
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
     output.resize(2*sizeof(uint64_t), 0);
     memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
     memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
+    return true;
 }
 
-static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
+void rpc_server::get_alignment(std::vector<uint8_t> & output) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t alignment = ggml_backend_buft_get_alignment(buft);
     GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
@@ -755,7 +781,7 @@ static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & out
     memcpy(output.data(), &alignment, sizeof(alignment));
 }
 
-static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
+void rpc_server::get_max_size(std::vector<uint8_t> & output) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t max_size = ggml_backend_buft_get_max_size(buft);
     GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
@@ -764,41 +790,90 @@ static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & outp
     memcpy(output.data(), &max_size, sizeof(max_size));
 }
 
-static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // input serialization format: | remote_ptr (8 bytes) |
+    if (input.size() != sizeof(uint64_t)) {
+        return false;
+    }
     uint64_t remote_ptr;
     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        return false;
+    }
     void * base = ggml_backend_buffer_get_base(buffer);
     // output serialization format: | base_ptr (8 bytes) |
     uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
     output.resize(sizeof(uint64_t), 0);
     memcpy(output.data(), &base_ptr, sizeof(base_ptr));
+    return true;
 }
 
-static void rpc_free_buffer(const std::vector<uint8_t> & input) {
+bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
     // input serialization format: | remote_ptr (8 bytes) |
+    if (input.size() != sizeof(uint64_t)) {
+        return false;
+    }
     uint64_t remote_ptr;
     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        return false;
+    }
     ggml_backend_buffer_free(buffer);
+    buffers.erase(buffer);
+    return true;
 }
 
-static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
+bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
     // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
+    if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
+        return false;
+    }
     uint64_t remote_ptr;
     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
     uint8_t value;
     memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+    if (buffers.find(buffer) == buffers.end()) {
+        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        return false;
+    }
     ggml_backend_buffer_clear(buffer, value);
+    return true;
 }
 
-static void rpc_set_tensor(const std::vector<uint8_t> & input) {
+ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
+    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
+        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = tensor->nb[i];
+    }
+    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
+    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
+        return nullptr;
+    }
+    result->op = (ggml_op) tensor->op;
+    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+        result->op_params[i] = tensor->op_params[i];
+    }
+    result->flags = tensor->flags;
+    result->data = reinterpret_cast<void *>(tensor->data);
+    ggml_set_name(result, tensor->name);
+    return result;
+}
+
+
+bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
+    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
+        return false;
+    }
     const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
     uint64_t offset;
     memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -811,14 +886,23 @@ static void rpc_set_tensor(const std::vector<uint8_t> & input) {
     };
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr) {
+        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+        ggml_free(ctx);
+        return false;
+    }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
     const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
     ggml_backend_tensor_set(tensor, data, offset, size);
     ggml_free(ctx);
+    return true;
 }
 
-static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
+    if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
+        return false;
+    }
     const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
     uint64_t offset;
     memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
@@ -832,15 +916,24 @@ static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8
     };
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    if (tensor == nullptr) {
+        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+        ggml_free(ctx);
+        return false;
+    }
     GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
     // output serialization format: | data (size bytes) |
     output.resize(size, 0);
     ggml_backend_tensor_get(tensor, output.data(), offset, size);
     ggml_free(ctx);
+    return true;
 }
 
-static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // serialization format: | rpc_tensor src | rpc_tensor dst |
+    if (input.size() != 2*sizeof(rpc_tensor)) {
+        return false;
+    }
     const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
     const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
 
@@ -852,18 +945,24 @@ static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
     ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
+    if (src == nullptr || dst == nullptr) {
+        GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
+        ggml_free(ctx);
+        return false;
+    }
     GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
     bool result = ggml_backend_buffer_copy_tensor(src, dst);
     // output serialization format: | result (1 byte) |
     output.resize(1, 0);
     output[0] = result;
     ggml_free(ctx);
+    return true;
 }
 
-static struct ggml_tensor * create_node(uint64_t id,
-                                        struct ggml_context * ctx,
-                                        const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
-                                        std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
+ggml_tensor * rpc_server::create_node(uint64_t id,
+                                      struct ggml_context * ctx,
+                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
+                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
     if (id == 0) {
         return nullptr;
     }
@@ -872,6 +971,9 @@ static struct ggml_tensor * create_node(uint64_t id,
     }
     const rpc_tensor * tensor = tensor_ptrs.at(id);
     struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
+    if (result == nullptr) {
+        return nullptr;
+    }
     tensor_map[id] = result;
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
@@ -881,14 +983,23 @@ static struct ggml_tensor * create_node(uint64_t id,
     return result;
 }
 
-static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // serialization format:
     // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    if (input.size() < sizeof(uint32_t)) {
+        return false;
+    }
     uint32_t n_nodes;
     memcpy(&n_nodes, input.data(), sizeof(n_nodes));
+    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
+        return false;
+    }
     const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
     uint32_t n_tensors;
     memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
+    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
+        return false;
+    }
     const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
     GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
 
@@ -914,9 +1025,17 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
     output.resize(1, 0);
     output[0] = status;
     ggml_free(ctx);
+    return true;
+}
+
+rpc_server::~rpc_server() {
+    for (auto buffer : buffers) {
+        ggml_backend_buffer_free(buffer);
+    }
 }
 
 static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    rpc_server server(backend);
     while (true) {
         uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
@@ -932,45 +1051,46 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
         if (!recv_data(sockfd, input.data(), input_size)) {
             break;
         }
+        bool ok = true;
         switch (cmd) {
             case ALLOC_BUFFER: {
-                rpc_alloc_buffer(backend, input, output);
+                ok = server.alloc_buffer(input, output);
                 break;
             }
             case GET_ALIGNMENT: {
-                rpc_get_alignment(backend, output);
+                server.get_alignment(output);
                 break;
             }
             case GET_MAX_SIZE: {
-                rpc_get_max_size(backend, output);
+                server.get_max_size(output);
                 break;
             }
             case BUFFER_GET_BASE: {
-                rpc_buffer_get_base(input, output);
+                ok = server.buffer_get_base(input, output);
                 break;
             }
             case FREE_BUFFER: {
-                rpc_free_buffer(input);
+                ok = server.free_buffer(input);
                 break;
             }
             case BUFFER_CLEAR: {
-                rpc_buffer_clear(input);
+                ok = server.buffer_clear(input);
                 break;
             }
             case SET_TENSOR: {
-                rpc_set_tensor(input);
+                ok = server.set_tensor(input);
                 break;
             }
             case GET_TENSOR: {
-                rpc_get_tensor(input, output);
+                ok = server.get_tensor(input, output);
                 break;
             }
             case COPY_TENSOR: {
-                rpc_copy_tensor(input, output);
+                ok = server.copy_tensor(input, output);
                 break;
             }
             case GRAPH_COMPUTE: {
-                rpc_graph_compute(backend, input, output);
+                ok = server.graph_compute(input, output);
                 break;
             }
             case GET_DEVICE_MEMORY: {
@@ -982,9 +1102,12 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
             }
             default: {
                 fprintf(stderr, "Unknown command: %d\n", cmd);
-                return;
+                ok = false;
             }
         }
+        if (!ok) {
+            break;
+        }
         uint64_t output_size = output.size();
         if (!send_data(sockfd, &output_size, sizeof(output_size))) {
             break;