]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Track descriptor pools/sets per-context (llama/14109)
authorJeff Bolz <redacted>
Wed, 11 Jun 2025 05:19:25 +0000 (00:19 -0500)
committerGeorgi Gerganov <redacted>
Wed, 18 Jun 2025 09:40:34 +0000 (12:40 +0300)
Use the same descriptor set layout for all pipelines (MAX_PARAMETER_COUNT == 8)
and move it to the vk_device. Move all the descriptor pool and set tracking to
the context - none of it is specific to pipelines anymore. It has a single vector
of pools and vector of sets, and a single counter to track requests and a single
counter to track use.

ggml/src/ggml-vulkan/ggml-vulkan.cpp

index 8ccc73e7422fef436e40a1a66c2529eff9563179..e5200b96d0d8dc15f9d616074c1649fd90ca3f6c 100644 (file)
@@ -78,7 +78,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 #define VK_VENDOR_ID_INTEL 0x8086
 #define VK_VENDOR_ID_NVIDIA 0x10de
 
-#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32
+#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
 
 #define GGML_VK_MAX_NODES 8192
 
@@ -114,13 +114,11 @@ struct vk_queue {
     bool transfer_only;
 };
 
+#define MAX_PARAMETER_COUNT 8
+
 struct vk_pipeline_struct {
     std::string name;
     vk::ShaderModule shader_module;
-    vk::DescriptorSetLayout dsl;
-    std::vector<vk::DescriptorPool> descriptor_pools;
-    std::vector<vk::DescriptorSet> descriptor_sets;
-    uint32_t descriptor_set_idx;
     vk::PipelineLayout layout;
     vk::Pipeline pipeline;
     uint32_t push_constant_size;
@@ -341,6 +339,8 @@ struct vk_device_struct {
     // set to true to indicate that some shaders need to be compiled after the dryrun
     bool need_compiles {};
 
+    vk::DescriptorSetLayout dsl;
+
     vk_matmul_pipeline pipeline_matmul_f32 {};
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
     vk_matmul_pipeline pipeline_matmul_bf16 {};
@@ -458,7 +458,6 @@ struct vk_device_struct {
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
-    std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
 
     std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
 
@@ -498,6 +497,8 @@ struct vk_device_struct {
         }
         pipelines.clear();
 
+        device.destroyDescriptorSetLayout(dsl);
+
         device.destroy();
     }
 };
@@ -930,6 +931,11 @@ struct ggml_backend_vk_context {
     vk_context_ref transfer_ctx;
 
     std::vector<vk_context_ref> tensor_ctxs;
+
+    std::vector<vk::DescriptorPool> descriptor_pools;
+    std::vector<vk::DescriptorSet> descriptor_sets;
+    uint32_t descriptor_set_idx {};
+    uint32_t pipeline_descriptor_set_requirements {};
 };
 
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
@@ -1060,39 +1066,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
                  ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
                  disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
     GGML_ASSERT(parameter_count > 0);
+    GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
 
     vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
     pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
 
-    std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
-    std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
-    for (uint32_t i = 0; i < parameter_count; i++) {
-        dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
-        dsl_binding_flags.push_back({});
-    }
-
-    vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
-
     vk::PushConstantRange pcr(
         vk::ShaderStageFlagBits::eCompute,
         0,
         pipeline->push_constant_size
     );
 
-    vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
-        {},
-        dsl_binding);
-    descriptor_set_layout_create_info.setPNext(&dslbfci);
-    pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
-
-    vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
-    vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
-    pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
-
-    pipeline->descriptor_set_idx = 0;
-
-    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
+    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr);
     pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
 
     std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
@@ -1167,15 +1153,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
 
 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
-    for (auto& pool : pipeline->descriptor_pools) {
-        device.destroyDescriptorPool(pool);
-    }
-    pipeline->descriptor_pools.clear();
-    pipeline->descriptor_sets.clear();
-    pipeline->descriptor_set_idx = 0;
-
-    device.destroyDescriptorSetLayout(pipeline->dsl);
-
     device.destroyPipelineLayout(pipeline->layout);
 
     device.destroyShaderModule(pipeline->shader_module);
@@ -1183,60 +1160,49 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
     device.destroyPipeline(pipeline->pipeline);
 }
 
-static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
+static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {
     VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
-    device->pipeline_descriptor_set_requirements[pipeline->name] += n;
+    ctx->pipeline_descriptor_set_requirements += n;
     if (!pipeline->compiled) {
         pipeline->needed = true;
-        device->need_compiles = true;
+        ctx->device->need_compiles = true;
     }
 }
 
-static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
-    std::lock_guard<std::mutex> guard(device->mutex);
-
-    for (auto& pair : device->pipeline_descriptor_set_requirements) {
-        vk_pipeline pipeline = device->pipelines.at(pair.first).lock();
-        const uint64_t n = pair.second;
-
-        VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")");
+static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {
 
-        if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
-            // Enough descriptors are available
-            continue;
-        }
+    if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {
+        // Enough descriptors are available
+        return;
+    }
 
-        uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
-        uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
-        uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
+    vk_device& device = ctx->device;
 
-        while (to_alloc > 0) {
-            const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
-            to_alloc -= alloc_count;
-            pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
+    uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size();
+    uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
+    uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
 
-            if (pool_idx >= pipeline->descriptor_pools.size()) {
-                vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
-                vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
-                pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
-            }
+    while (to_alloc > 0) {
+        const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
+        to_alloc -= alloc_count;
+        pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
 
-            std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
-            for (uint32_t i = 0; i < alloc_count; i++) {
-                layouts[i] = pipeline->dsl;
-            }
-            vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data());
-            std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
-            pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
+        if (pool_idx >= ctx->descriptor_pools.size()) {
+            vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
+            vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
+            ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
+        }
 
-            pool_idx++;
+        std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
+        for (uint32_t i = 0; i < alloc_count; i++) {
+            layouts[i] = device->dsl;
         }
-    }
-}
+        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
+        std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+        ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
 
-static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
-    VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")");
-    pipeline->descriptor_set_idx = 0;
+        pool_idx++;
+    }
 }
 
 static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) {
@@ -3369,6 +3335,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
             }
         }
 
+
+        std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
+        std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
+        for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) {
+            dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
+            dsl_binding_flags.push_back({});
+        }
+
+        vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
+
+        vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
+            {},
+            dsl_binding);
+        descriptor_set_layout_create_info.setPNext(&dslbfci);
+        device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
+
         ggml_vk_load_shaders(device);
 
         if (!device->single_queue) {
@@ -4154,10 +4136,10 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
         std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
     }
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
-    GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
-    GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count);
+    GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
+    GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
 
-    vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
+    vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
     ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
 
@@ -4964,18 +4946,18 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         }
 
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
         }
         if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
         if (quantize_y) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
         if (split_k > 1) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
         }
         return;
     }
@@ -5157,12 +5139,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
         // Request descriptor sets
         if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
         }
         if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
         return;
     }
 
@@ -5295,7 +5277,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
 
     if (dryrun) {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
         return;
     }
 
@@ -5384,7 +5366,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 
     if (dryrun) {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
         return;
     }
 
@@ -5571,12 +5553,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         }
 
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
         }
         if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
         return;
     }
@@ -5765,12 +5747,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
 
         // Request descriptor sets
         if (qx_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
         }
         if (qy_needs_dequant) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
         return;
     }
 
@@ -6090,9 +6072,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     if (dryrun) {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (split_k > 1) {
-            ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
         }
         return;
     }
@@ -6655,7 +6637,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     }
 
     if (dryrun) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         return;
     }
 
@@ -7036,7 +7018,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_ASSERT(pipeline != nullptr);
 
     if (dryrun) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         return;
     }
 
@@ -7175,7 +7157,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
     GGML_ASSERT(pipeline != nullptr);
 
     if (dryrun) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         return;
     }
 
@@ -7853,9 +7835,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         }
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
+    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
         if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
             // Resize buffer
@@ -7870,7 +7852,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         ggml_vk_load_shaders(ctx->device);
     }
 
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
+    ggml_pipeline_allocate_descriptor_sets(ctx);
 
     vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
     vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -8036,9 +8018,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     ggml_vk_destroy_buffer(d_Y);
     ggml_vk_destroy_buffer(d_D);
 
-    ggml_pipeline_cleanup(p);
-    ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
-
     free(x);
     free(y);
     free(d);
@@ -8116,13 +8095,13 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
     ggml_vk_quantize_data(x, qx, ne, quant);
     ggml_vk_dequantize_data(qx, x_ref, ne, quant);
 
-    ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
+    ggml_pipeline_request_descriptor_sets(ctx, p, 1);
 
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);
     }
 
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
+    ggml_pipeline_allocate_descriptor_sets(ctx);
 
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
 
@@ -8216,13 +8195,13 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
 //
 //     vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
 //
-//     ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
+//     ggml_pipeline_request_descriptor_sets(ctx, p, 1);
 //
 //     if (ctx->device->need_compiles) {
 //         ggml_vk_load_shaders(ctx->device);
 //     }
 //
-//     ggml_pipeline_allocate_descriptor_sets(ctx->device);
+//     ggml_pipeline_allocate_descriptor_sets(ctx);
 //
 //     ggml_vk_buffer_write(x_buf, 0, x, x_sz);
 //
@@ -8375,9 +8354,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         // y[i] = i % k;
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
+    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it);
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
         if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
             // Resize buffer
@@ -8388,14 +8367,14 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         }
     }
     if (mmq) {
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it);
     }
 
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);
     }
 
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
+    ggml_pipeline_allocate_descriptor_sets(ctx);
 
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
     ggml_vk_buffer_write(y_buf, 0, y, y_sz);
@@ -8797,7 +8776,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
                 // These operations all go through ggml_vk_op_f32, so short-circuit and
                 // do the only thing needed for the dryrun.
                 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
-                ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
                 return false;
             }
         default:
@@ -9189,17 +9168,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
     }
     ctx->gc.temp_buffers.clear();
 
-    for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) {
-        vk_pipeline_ref plr = ctx->device->pipelines[dsr.first];
-
-        if (plr.expired()) {
-            continue;
-        }
-
-        vk_pipeline pl = plr.lock();
-        ggml_pipeline_cleanup(pl);
-    }
-
     ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue);
     ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue);
 
@@ -9222,7 +9190,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
 
     ctx->tensor_ctxs.clear();
     ctx->gc.contexts.clear();
-    ctx->device->pipeline_descriptor_set_requirements.clear();
+    ctx->pipeline_descriptor_set_requirements = 0;
+    ctx->descriptor_set_idx = 0;
 }
 
 // Clean up on backend free
@@ -9249,6 +9218,12 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
 
     ctx->device->device.destroyFence(ctx->fence);
     ctx->device->device.destroyFence(ctx->almost_ready_fence);
+
+    for (auto& pool : ctx->descriptor_pools) {
+        ctx->device->device.destroyDescriptorPool(pool);
+    }
+    ctx->descriptor_pools.clear();
+    ctx->descriptor_sets.clear();
 }
 
 static int ggml_vk_get_device_count() {
@@ -9622,7 +9597,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_load_shaders(ctx->device);
     }
     ggml_vk_preallocate_buffers(ctx);
-    ggml_pipeline_allocate_descriptor_sets(ctx->device);
+    ggml_pipeline_allocate_descriptor_sets(ctx);
 
     int last_node = cgraph->n_nodes - 1;