]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan : multithread pipeline creation (ggml/963)
authorJeff Bolz <redacted>
Sun, 29 Sep 2024 16:50:17 +0000 (11:50 -0500)
committerGeorgi Gerganov <redacted>
Sun, 29 Sep 2024 18:15:37 +0000 (21:15 +0300)
ggml/src/ggml-vulkan.cpp

index 70b7291540b2bc3e34161f1a06411cba5bc76047..c677a27287cc0ca5d1abb6f2fcedb9dd117e9b9c 100644 (file)
@@ -20,6 +20,8 @@
 #include <unordered_map>
 #include <memory>
 #include <mutex>
+#include <future>
+#include <thread>
 
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
@@ -607,13 +609,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
 
 GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
 
-static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+// variables to track number of compiles in progress
+static uint32_t compile_count = 0;
+static std::mutex compile_count_mutex;
+static std::condition_variable compile_count_cond;
+
+static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
     VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
     GGML_ASSERT(parameter_count > 0);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
 
-    std::lock_guard<std::mutex> guard(device->mutex);
-
     pipeline = std::make_shared<vk_pipeline_struct>();
     pipeline->name = name;
     pipeline->parameter_count = parameter_count;
@@ -681,7 +686,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
         pipeline->layout);
     pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
 
-    device->pipelines.insert({ pipeline->name, pipeline });
+    {
+        std::lock_guard<std::mutex> guard(device->mutex);
+        device->pipelines.insert({ pipeline->name, pipeline });
+    }
+
+    {
+        std::lock_guard<std::mutex> guard(compile_count_mutex);
+        assert(compile_count > 0);
+        compile_count--;
+    }
+    compile_count_cond.notify_all();
 }
 
 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -1194,6 +1209,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
     device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
 
+    std::vector<std::future<void>> compiles;
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+        {
+            // wait until fewer than N compiles are in progress
+            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+            std::unique_lock<std::mutex> guard(compile_count_mutex);
+            while (compile_count >= N) {
+                compile_count_cond.wait(guard);
+            }
+            compile_count++;
+        }
+        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
+    };
+
     if (device->fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
         ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
@@ -1743,6 +1772,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
+
+    for (auto &c : compiles) {
+        c.wait();
+    }
 }
 
 static vk_device ggml_vk_get_device(size_t idx) {