#include <unordered_map>
#include <memory>
#include <mutex>
+#include <future>
+#include <thread>
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
-static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+// variables to track number of compiles in progress
+static uint32_t compile_count = 0;
+static std::mutex compile_count_mutex;
+static std::condition_variable compile_count_cond;
+
+static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
GGML_ASSERT(parameter_count > 0);
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
- std::lock_guard<std::mutex> guard(device->mutex);
-
pipeline = std::make_shared<vk_pipeline_struct>();
pipeline->name = name;
pipeline->parameter_count = parameter_count;
pipeline->layout);
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
- device->pipelines.insert({ pipeline->name, pipeline });
+ {
+ std::lock_guard<std::mutex> guard(device->mutex);
+ device->pipelines.insert({ pipeline->name, pipeline });
+ }
+
+ {
+ std::lock_guard<std::mutex> guard(compile_count_mutex);
+ assert(compile_count > 0);
+ compile_count--;
+ }
+ compile_count_cond.notify_all();
}
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
+ std::vector<std::future<void>> compiles;
+ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+ {
+ // wait until fewer than N compiles are in progress
+ uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+ std::unique_lock<std::mutex> guard(compile_count_mutex);
+ while (compile_count >= N) {
+ compile_count_cond.wait(guard);
+ }
+ compile_count++;
+ }
+ compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
+ };
+
if (device->fp16) {
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
+
+ for (auto &c : compiles) {
+ c.wait();
+ }
}
static vk_device ggml_vk_get_device(size_t idx) {