]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Vulkan Improvements (llama/5835)
author0cc4m <redacted>
Tue, 5 Mar 2024 12:33:42 +0000 (13:33 +0100)
committerGeorgi Gerganov <redacted>
Fri, 8 Mar 2024 09:38:33 +0000 (11:38 +0200)
* Improve dequant shaders, add fast q4_0 dequant

* Optimize dmmv non-kquants for GCN

Remove unnecessary SPIR-V shader duplication

* Fix q4_0 dequant dispatch sizes

Fix backend free bug

* Optimize dequant shaders for q4_1, q5_0, q5_1 and q8_0

* Add unary and binary op shader templates

* Fix Vulkan check results

* Enable non-contiguous support for simple ops

* Add argsort

Basic q4_0 mmq shader and unit test

* Speed up q4_0 dequant code, enable mmq for q4_0

* Rework matmul pipeline selection

* Add soft_max alibi support

* Add q4_1, q5_0, q5_1 and q8_0 dequant mat mat mul shaders

* Add environment variable GGML_VK_FORCE_MAX_ALLOCATION_SIZE to limit max buffer size

Rename GGML_VULKAN_DISABLE_F16 to GGML_VK_DISABLE_F16 for consistency

ggml-vulkan.cpp
ggml-vulkan.h

index bc316c3f3944d9f4f6e066194ec053f173d4a11b..5a1b3f477618116af0d25ee957f69a5c4a9ce6d4 100644 (file)
@@ -69,6 +69,33 @@ struct vk_queue {
     vk::PipelineStageFlags stage_flags;
 };
 
+struct vk_pipeline_struct {
+    std::string name;
+    vk::ShaderModule shader_module;
+    vk::DescriptorSetLayout dsl;
+    std::vector<vk::DescriptorPool> descriptor_pools;
+    std::vector<vk::DescriptorSet> descriptor_sets;
+    uint32_t descriptor_set_idx;
+    vk::PipelineLayout layout;
+    vk::Pipeline pipeline;
+    uint32_t push_constant_size;
+    uint32_t parameter_count;
+    std::array<uint32_t, 3> wg_denoms;
+    uint32_t align;
+};
+
+typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
+typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
+
+static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
+
+struct vk_matmul_pipeline_struct {
+    vk_pipeline l, m, s;
+    vk_pipeline a_l, a_m, a_s;
+};
+
+typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
+
 struct vk_device {
     vk::PhysicalDevice physical_device;
     vk::PhysicalDeviceProperties properties;
@@ -84,10 +111,61 @@ struct vk_device {
     uint32_t subgroup_size;
     bool uma;
 
+    bool initialized;
+    size_t idx;
+
+    vk_matmul_pipeline pipeline_matmul_f32;
+    vk_matmul_pipeline pipeline_matmul_f16;
+    vk_matmul_pipeline pipeline_matmul_f16_f32;
+    vk_pipeline pipeline_matmul_split_k_reduce;
+
+    vk_matmul_pipeline pipeline_dequant_mul_mat_mat[VK_NUM_TYPES];
+
+    vk_pipeline pipeline_dequant[VK_NUM_TYPES];
+    vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
+
+    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+    vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
+    vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
+    vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
+    vk_pipeline pipeline_mul_f32;
+    vk_pipeline pipeline_add_f32;
+    vk_pipeline pipeline_scale_f32;
+    vk_pipeline pipeline_sqr_f32;
+    vk_pipeline pipeline_clamp_f32;
+    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
+    vk_pipeline pipeline_norm_f32;
+    vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_gelu_f32;
+    vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_relu_f32;
+    vk_pipeline pipeline_diag_mask_inf_f32;
+    vk_pipeline pipeline_soft_max_f32;
+    vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
+    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+    vk_pipeline pipeline_argsort_f32;
+
+    std::vector<vk_pipeline_ref> pipelines;
+
     ~vk_device() {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "destroy device " << name << std::endl;
 #endif
+        device.destroyCommandPool(compute_queue.pool);
+        if (!single_queue) {
+            device.destroyCommandPool(transfer_queue.pool);
+        }
+
+        for (auto& pipeline : pipelines) {
+            if (pipeline.expired()) {
+                continue;
+            }
+
+            vk_pipeline pl = pipeline.lock();
+            ggml_vk_destroy_pipeline(device, pl);
+        }
+        pipelines.clear();
+
         device.destroy();
     }
 };
@@ -125,21 +203,6 @@ struct vk_subbuffer {
     uint64_t size;
 };
 
-struct vk_pipeline {
-    std::string name;
-    vk::ShaderModule shader_module;
-    vk::DescriptorSetLayout dsl;
-    std::vector<vk::DescriptorPool> descriptor_pools;
-    std::vector<vk::DescriptorSet> descriptor_sets;
-    uint32_t descriptor_set_idx;
-    vk::PipelineLayout layout;
-    vk::Pipeline pipeline;
-    uint32_t push_constant_size;
-    uint32_t parameter_count;
-    std::array<uint32_t, 3> wg_denoms;
-    uint32_t align;
-};
-
 struct vk_semaphore {
     vk::Semaphore s;
     uint64_t value;
@@ -160,11 +223,21 @@ struct vk_op_push_constants {
     float param2;
 };
 
-struct vk_op_cpy_push_constants {
+struct vk_op_unary_push_constants {
     uint32_t ne;
-    uint32_t ne00; uint32_t ne01; uint32_t nb00; uint32_t nb01; uint32_t nb02;
-    uint32_t ne10; uint32_t ne11; uint32_t nb10; uint32_t nb11; uint32_t nb12;
+    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
     uint32_t d_offset;
+    float param1; float param2;
+};
+
+struct vk_op_binary_push_constants {
+    uint32_t ne;
+    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
+    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
+    uint32_t d_offset;
+    float param1; float param2;
 };
 
 struct vk_op_diag_mask_push_constants {
@@ -196,6 +269,22 @@ struct vk_op_rope_neox_push_constants {
     float inv_ndims;
 };
 
+struct vk_op_soft_max_push_constants {
+    uint32_t KX;
+    uint32_t KY;
+    uint32_t KZ;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+    uint32_t n_head_log2;
+};
+
+struct vk_op_argsort_push_constants {
+    uint32_t ncols;
+    bool ascending;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -236,7 +325,6 @@ struct ggml_tensor_extra_gpu {
 };
 
 struct ggml_vk_garbage_collector {
-    std::vector<vk_pipeline *> pipelines;
     std::vector<vk_semaphore> tl_semaphores;
     std::vector<vk_semaphore> semaphores;
     std::vector<vk::Event> events;
@@ -247,35 +335,7 @@ struct ggml_vk_garbage_collector {
 struct ggml_backend_vk_context {
     std::string name;
 
-    std::weak_ptr<vk_device> device;
-    vk_pipeline pipeline_matmul_f32_l, pipeline_matmul_f32_m, pipeline_matmul_f32_s;
-    vk_pipeline pipeline_matmul_f32_aligned_l, pipeline_matmul_f32_aligned_m, pipeline_matmul_f32_aligned_s;
-    vk_pipeline pipeline_matmul_f16_l, pipeline_matmul_f16_m, pipeline_matmul_f16_s;
-    vk_pipeline pipeline_matmul_f16_aligned_l, pipeline_matmul_f16_aligned_m, pipeline_matmul_f16_aligned_s;
-    vk_pipeline pipeline_matmul_f16_f32_l, pipeline_matmul_f16_f32_m, pipeline_matmul_f16_f32_s;
-    vk_pipeline pipeline_matmul_f16_f32_aligned_l, pipeline_matmul_f16_f32_aligned_m, pipeline_matmul_f16_f32_aligned_s;
-    vk_pipeline pipeline_matmul_split_k_reduce;
-    vk_pipeline pipeline_dequant[VK_NUM_TYPES];
-    vk_pipeline pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
-    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
-    vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
-    vk_pipeline pipeline_get_rows[VK_NUM_TYPES];
-    vk_pipeline pipeline_get_rows_f32[VK_NUM_TYPES];
-    vk_pipeline pipeline_mul_f32;
-    vk_pipeline pipeline_add_f32;
-    vk_pipeline pipeline_scale_f32;
-    vk_pipeline pipeline_sqr_f32;
-    vk_pipeline pipeline_clamp_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
-    vk_pipeline pipeline_norm_f32;
-    vk_pipeline pipeline_rms_norm_f32;
-    vk_pipeline pipeline_gelu_f32;
-    vk_pipeline pipeline_silu_f32;
-    vk_pipeline pipeline_relu_f32;
-    vk_pipeline pipeline_diag_mask_inf_f32;
-    vk_pipeline pipeline_soft_max_f32;
-    vk_pipeline pipeline_rope_f32, pipeline_rope_f16;
-    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
+    std::shared_ptr<vk_device> device;
 
     size_t semaphore_idx, event_idx;
     ggml_vk_garbage_collector gc;
@@ -304,13 +364,31 @@ struct vk_instance {
 
     std::vector<size_t> device_indices;
 
-    std::shared_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
     ggml_backend_t backends[GGML_VK_MAX_DEVICES];
     ggml_backend_vk_context contexts[GGML_VK_MAX_DEVICES];
     ggml_backend_buffer_type buffer_types[GGML_VK_MAX_DEVICES];
     bool initialized[GGML_VK_MAX_DEVICES];
 };
 
+static std::shared_ptr<vk_device> ggml_vk_get_device(size_t idx) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "ggml_vk_get_device(" << idx << ")" << std::endl;
+#endif
+    static std::weak_ptr<vk_device> devices[GGML_VK_MAX_DEVICES];
+
+    if (devices[idx].expired()) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "Initializing new vk_device" << std::endl;
+#endif
+        std::shared_ptr<vk_device> device = std::make_shared<vk_device>();
+        device->initialized = false;
+        devices[idx] = device;
+        return device;
+    }
+
+    return devices[idx].lock();
+}
+
 #ifdef GGML_VULKAN_CHECK_RESULTS
 static size_t vk_skip_checks;
 static size_t vk_output_tensor;
@@ -334,14 +412,15 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
     GGML_ASSERT(parameter_count > 0);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
 
-    pipeline.name = name;
-    pipeline.parameter_count = parameter_count;
-    pipeline.push_constant_size = push_constant_size;
-    pipeline.wg_denoms = wg_denoms;
-    pipeline.align = align;
+    pipeline = std::make_shared<vk_pipeline_struct>();
+    pipeline->name = name;
+    pipeline->parameter_count = parameter_count;
+    pipeline->push_constant_size = push_constant_size;
+    pipeline->wg_denoms = wg_denoms;
+    pipeline->align = align;
 
     vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
-    pipeline.shader_module = ctx->device.lock()->device.createShaderModule(shader_module_create_info);
+    pipeline->shader_module = ctx->device->device.createShaderModule(shader_module_create_info);
 
     std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
     std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@@ -355,49 +434,49 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
     vk::PushConstantRange pcr(
         vk::ShaderStageFlagBits::eCompute,
         0,
-        pipeline.push_constant_size
+        pipeline->push_constant_size
     );
 
     vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
         {},
         dsl_binding);
     descriptor_set_layout_create_info.setPNext(&dslbfci);
-    pipeline.dsl = ctx->device.lock()->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
+    pipeline->dsl = ctx->device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
 
     // Check if device supports multiple descriptors per pool
-    if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
+    if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN) {
         const uint32_t alloc_count = 2;
 
         // Try allocating multiple sets from one pool
         // This fails on AMD for some reason, so add a fall back to allocating one pool per set
-        vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
+        vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
         vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, alloc_count, descriptor_pool_size);
-        vk::DescriptorPool pool = ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info);
+        vk::DescriptorPool pool = ctx->device->device.createDescriptorPool(descriptor_pool_create_info);
 
         std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
         for (uint32_t i = 0; i < alloc_count; i++) {
-            layouts[i] = pipeline.dsl;
+            layouts[i] = pipeline->dsl;
         }
         try {
             vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pool, alloc_count, layouts.data());
-            std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
+            std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
         } catch(vk::OutOfPoolMemoryError const&) {
-            ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
+            ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_SINGLE;
         }
 
-        ctx->device.lock()->device.destroyDescriptorPool(pool);
+        ctx->device->device.destroyDescriptorPool(pool);
     }
 
-    if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
-        vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
+    if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+        vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
         vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 128, descriptor_pool_size);
-        pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
+        pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
     }
 
-    pipeline.descriptor_set_idx = 0;
+    pipeline->descriptor_set_idx = 0;
 
-    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline.dsl, pcr);
-    pipeline.layout = ctx->device.lock()->device.createPipelineLayout(pipeline_layout_create_info);
+    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr);
+    pipeline->layout = ctx->device->device.createPipelineLayout(pipeline_layout_create_info);
 
     std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
 
@@ -417,72 +496,75 @@ static void ggml_vk_create_pipeline(ggml_backend_vk_context * ctx, vk_pipeline&
     vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
             vk::PipelineShaderStageCreateFlags(),
             vk::ShaderStageFlagBits::eCompute,
-            pipeline.shader_module,
+            pipeline->shader_module,
             entrypoint.c_str(),
             &specialization_info);
     vk::ComputePipelineCreateInfo compute_pipeline_create_info(
         vk::PipelineCreateFlags(),
         pipeline_shader_create_info,
-        pipeline.layout);
-    pipeline.pipeline = ctx->device.lock()->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
+        pipeline->layout);
+    pipeline->pipeline = ctx->device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
 
-    ctx->gc.pipelines.push_back(&pipeline);
+    ctx->device->pipelines.push_back(pipeline);
 }
 
-static void ggml_vk_destroy_pipeline(ggml_backend_vk_context * ctx, vk_pipeline * pipeline) {
+static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "ggml_pipeline_destroy_pipeline(" << pipeline->name << ")" << std::endl;
+#endif
     for (auto& pool : pipeline->descriptor_pools) {
-        ctx->device.lock()->device.destroyDescriptorPool(pool);
+        device.destroyDescriptorPool(pool);
     }
     pipeline->descriptor_pools.clear();
     pipeline->descriptor_sets.clear();
     pipeline->descriptor_set_idx = 0;
 
-    ctx->device.lock()->device.destroyDescriptorSetLayout(pipeline->dsl);
+    device.destroyDescriptorSetLayout(pipeline->dsl);
 
-    ctx->device.lock()->device.destroyPipelineLayout(pipeline->layout);
+    device.destroyPipelineLayout(pipeline->layout);
 
-    ctx->device.lock()->device.destroyShaderModule(pipeline->shader_module);
+    device.destroyShaderModule(pipeline->shader_module);
 
-    ctx->device.lock()->device.destroyPipeline(pipeline->pipeline);
+    device.destroyPipeline(pipeline->pipeline);
 }
 
 static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx, vk_pipeline& pipeline, uint32_t n) {
 #ifdef GGML_VULKAN_DEBUG
-    std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline.name << ", " << n << ")" << std::endl;
+    std::cerr << "ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")" << std::endl;
 #endif
-    if (pipeline.descriptor_sets.size() >= pipeline.descriptor_set_idx + n) {
+    if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) {
         // Enough descriptors are available
         return;
     }
 
-    if (ctx->device.lock()->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
-        const uint32_t alloc_count = pipeline.descriptor_set_idx + n - pipeline.descriptor_sets.size();
+    if (ctx->device->descriptor_set_mode == VK_DEVICE_DESCRIPTOR_POOL_MODE_MULTI) {
+        const uint32_t alloc_count = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size();
 
         std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
         for (uint32_t i = 0; i < alloc_count; i++) {
-            layouts[i] = pipeline.dsl;
+            layouts[i] = pipeline->dsl;
         }
-        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[0], alloc_count, layouts.data());
-        std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
-        pipeline.descriptor_sets.insert(pipeline.descriptor_sets.end(), sets.begin(), sets.end());
+        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[0], alloc_count, layouts.data());
+        std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+        pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end());
     } else {
-        for (uint32_t i = pipeline.descriptor_sets.size(); i < pipeline.descriptor_set_idx + n; i++) {
-            vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline.parameter_count);
+        for (uint32_t i = pipeline->descriptor_sets.size(); i < pipeline->descriptor_set_idx + n; i++) {
+            vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count);
             vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, 1, descriptor_pool_size);
-            pipeline.descriptor_pools.push_back(ctx->device.lock()->device.createDescriptorPool(descriptor_pool_create_info));
+            pipeline->descriptor_pools.push_back(ctx->device->device.createDescriptorPool(descriptor_pool_create_info));
 
-            vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline.descriptor_pools[i], 1, &pipeline.dsl);
-            std::vector<vk::DescriptorSet> sets = ctx->device.lock()->device.allocateDescriptorSets(descriptor_set_alloc_info);
-            pipeline.descriptor_sets.push_back(sets[0]);
+            vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[i], 1, &pipeline->dsl);
+            std::vector<vk::DescriptorSet> sets = ctx->device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+            pipeline->descriptor_sets.push_back(sets[0]);
         }
     }
 }
 
 static void ggml_pipeline_cleanup(vk_pipeline& pipeline) {
 #ifdef GGML_VULKAN_DEBUG
-    std::cerr << "ggml_pipeline_cleanup(" << pipeline.name << ")" << std::endl;
+    std::cerr << "ggml_pipeline_cleanup(" << pipeline->name << ")" << std::endl;
 #endif
-    pipeline.descriptor_set_idx = 0;
+    pipeline->descriptor_set_idx = 0;
 }
 
 static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx, vk_queue& q) {
@@ -498,7 +580,7 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(ggml_backend_vk_context * ctx
         q.pool,
         vk::CommandBufferLevel::ePrimary,
         1);
-    const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device.lock()->device.allocateCommandBuffers(command_buffer_alloc_info);
+    const std::vector<vk::CommandBuffer> cmd_buffers = ctx->device->device.allocateCommandBuffers(command_buffer_alloc_info);
     auto buf = cmd_buffers.front();
 
     q.cmd_buffers.push_back(buf);
@@ -643,11 +725,11 @@ static void ggml_vk_create_queue(ggml_backend_vk_context * ctx, vk_queue& q, uin
     q.queue_family_index = queue_family_index;
 
     vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index);
-    q.pool = ctx->device.lock()->device.createCommandPool(command_pool_create_info_compute);
+    q.pool = ctx->device->device.createCommandPool(command_pool_create_info_compute);
 
     q.cmd_buffer_idx = 0;
 
-    q.queue = ctx->device.lock()->device.getQueue(queue_family_index, queue_index);
+    q.queue = ctx->device->device.getQueue(queue_family_index, queue_index);
 
     q.stage_flags = stage_flags;
 }
@@ -671,7 +753,7 @@ static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context *
     vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
     vk::SemaphoreCreateInfo ci{};
     ci.setPNext(&tci);
-    vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
+    vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
     ctx->gc.semaphores.push_back({ semaphore, 0 });
     return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
 }
@@ -684,7 +766,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context
         vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
         vk::SemaphoreCreateInfo ci{};
         ci.setPNext(&tci);
-        vk::Semaphore semaphore = ctx->device.lock()->device.createSemaphore(ci);
+        vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
         ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
     }
     return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
@@ -692,7 +774,7 @@ static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context
 
 static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
     if (ctx->event_idx >= ctx->gc.events.size()) {
-        ctx->gc.events.push_back(ctx->device.lock()->device.createEvent({}));
+        ctx->gc.events.push_back(ctx->device->device.createEvent({}));
     }
     return ctx->gc.events[ctx->event_idx++];
 }
@@ -703,7 +785,7 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
 #endif
     // Requires command buffers to be done
 
-    ctx->device.lock()->device.resetCommandPool(q.pool);
+    ctx->device->device.resetCommandPool(q.pool);
     q.cmd_buffer_idx = 0;
 }
 
@@ -740,11 +822,11 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
         nullptr,
     };
 
-    buf->buffer = ctx->device.lock()->device.createBuffer(buffer_create_info);
+    buf->buffer = ctx->device->device.createBuffer(buffer_create_info);
 
-    vk::MemoryRequirements mem_req = ctx->device.lock()->device.getBufferMemoryRequirements(buf->buffer);
+    vk::MemoryRequirements mem_req = ctx->device->device.getBufferMemoryRequirements(buf->buffer);
 
-    vk::PhysicalDeviceMemoryProperties mem_props = ctx->device.lock()->physical_device.getMemoryProperties();
+    vk::PhysicalDeviceMemoryProperties mem_props = ctx->device->physical_device.getMemoryProperties();
 
     uint32_t memory_type_index = UINT32_MAX;
 
@@ -757,30 +839,30 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
     }
 
     if (memory_type_index == UINT32_MAX) {
-        ctx->device.lock()->device.destroyBuffer(buf->buffer);
+        ctx->device->device.destroyBuffer(buf->buffer);
         buf->size = 0;
         throw vk::OutOfDeviceMemoryError("No suitable memory type found");
     }
 
     try {
-        buf->device_memory = ctx->device.lock()->device.allocateMemory({ mem_req.size, memory_type_index });
+        buf->device_memory = ctx->device->device.allocateMemory({ mem_req.size, memory_type_index });
     } catch (const vk::SystemError& e) {
         // Out of Host/Device memory, clean up buffer
-        ctx->device.lock()->device.destroyBuffer(buf->buffer);
+        ctx->device->device.destroyBuffer(buf->buffer);
         buf->size = 0;
         throw e;
     }
     buf->ptr = nullptr;
 
     if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
-        buf->ptr = ctx->device.lock()->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
+        buf->ptr = ctx->device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
     }
 
-    ctx->device.lock()->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
+    ctx->device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
 
     buf->ctx = ctx;
 
-    buf->device = ctx->device.lock();
+    buf->device = ctx->device;
 
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "Created buffer " << buf->buffer << std::endl;
@@ -802,7 +884,7 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
 static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) {
     vk_buffer buf;
     try {
-        if (ctx->device.lock()->uma) {
+        if (ctx->device->uma) {
             // Fall back to host memory type
             buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
         } else {
@@ -883,10 +965,16 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
     std::cerr << "ggml_vk_load_shaders(" << ctx->name << ")" << std::endl;
 #endif
 
+    const std::shared_ptr<vk_device> device = ctx->device;
+
     // mulmat
-    std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, ctx->device.lock()->subgroup_size * 2, 64, 2, 4, 4, ctx->device.lock()->subgroup_size };
-    std::initializer_list<uint32_t> warptile_m = { 128,  64,  64, 16, ctx->device.lock()->subgroup_size, 32, 2, 4, 2, ctx->device.lock()->subgroup_size };
-    std::initializer_list<uint32_t> warptile_s = { ctx->device.lock()->subgroup_size,  32,  32, 16, 32, 32, 2, 2, 2, ctx->device.lock()->subgroup_size };
+    std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_m = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_s = { device->subgroup_size,  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
+
+    std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_mmq_m = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size,  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
 
     std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
     std::array<uint32_t, 3> m_wg_denoms = { 64,  64, 1 };
@@ -896,126 +984,206 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
     uint32_t m_align =  64;
     uint32_t s_align =  32;
 
-    if (ctx->device.lock()->fp16) {
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_len, matmul_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_len, matmul_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_len, matmul_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_len, matmul_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_len, matmul_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_len, matmul_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_len, matmul_f16_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_len, matmul_f16_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_len, matmul_f16_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_len, matmul_f16_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_len, matmul_f16_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_len, matmul_f16_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_len, matmul_f16_f32_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_len, matmul_f16_f32_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_len, matmul_f16_f32_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_len, matmul_f16_f32_aligned_l_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_len, matmul_f16_f32_aligned_m_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_len, matmul_f16_f32_aligned_s_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
+    ctx->device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    ctx->device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    ctx->device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
+    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
+    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
+    ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
+
+    if (device->fp16) {
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_len, matmul_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_len, matmul_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_0_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_0_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_0_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
     } else {
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_l, "matmul_f32_l", matmul_f32_l_fp32_len, matmul_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_m, "matmul_f32_m", matmul_f32_m_fp32_len, matmul_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_s, "matmul_f32_s", matmul_f32_s_fp32_len, matmul_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_l, "matmul_f32_aligned_l", matmul_f32_aligned_l_fp32_len, matmul_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_m, "matmul_f32_aligned_m", matmul_f32_aligned_m_fp32_len, matmul_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f32_aligned_s, "matmul_f32_aligned_s", matmul_f32_aligned_s_fp32_len, matmul_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_l, "matmul_f16_l", matmul_f16_l_fp32_len, matmul_f16_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_m, "matmul_f16_m", matmul_f16_m_fp32_len, matmul_f16_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_s, "matmul_f16_s", matmul_f16_s_fp32_len, matmul_f16_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_l, "matmul_f16_aligned_l", matmul_f16_aligned_l_fp32_len, matmul_f16_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_m, "matmul_f16_aligned_m", matmul_f16_aligned_m_fp32_len, matmul_f16_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_aligned_s, "matmul_f16_aligned_s", matmul_f16_aligned_s_fp32_len, matmul_f16_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
-
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_l, "matmul_f16_f32_l", matmul_f16_f32_l_fp32_len, matmul_f16_f32_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_m, "matmul_f16_f32_m", matmul_f16_f32_m_fp32_len, matmul_f16_f32_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_s, "matmul_f16_f32_s", matmul_f16_f32_s_fp32_len, matmul_f16_f32_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_l_fp32_len, matmul_f16_f32_aligned_l_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_m_fp32_len, matmul_f16_f32_aligned_m_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
-        ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_f16_f32_aligned_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_s_fp32_len, matmul_f16_f32_aligned_s_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
-    }
-
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32",  mul_mat_vec_f16_f32_len,  mul_mat_vec_f16_f32_data,  "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(int), {1, 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_fp32_len, matmul_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_aligned_fp32_len, matmul_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, 1);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), l_wg_denoms, warptile_mmq_l, l_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), m_wg_denoms, warptile_mmq_m, m_align);
+        ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, 14 * sizeof(uint32_t), s_wg_denoms, warptile_mmq_s, s_align);
+    }
+
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32",  mul_mat_vec_f16_f32_len,  mul_mat_vec_f16_f32_data,  "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32", mul_mat_vec_q4_0_f32_len, mul_mat_vec_q4_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32", mul_mat_vec_q4_1_f32_len, mul_mat_vec_q4_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32", mul_mat_vec_q5_0_f32_len, mul_mat_vec_q5_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32", mul_mat_vec_q5_1_f32_len, mul_mat_vec_q5_1_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32", mul_mat_vec_q8_0_f32_len, mul_mat_vec_q8_0_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_K_f32", mul_mat_vec_q2_K_f32_len, mul_mat_vec_q2_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_K_f32", mul_mat_vec_q3_K_f32_len, mul_mat_vec_q3_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_K_f32", mul_mat_vec_q4_K_f32_len, mul_mat_vec_q4_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_K_f32", mul_mat_vec_q5_K_f32_len, mul_mat_vec_q5_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant_mul_mat_vec_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_K_f32", mul_mat_vec_q6_K_f32_len, mul_mat_vec_q6_K_f32_data, "main", 3, 3 * sizeof(uint32_t), {1, 1, 1}, { device->subgroup_size }, 1);
 
     // dequant shaders
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   f32_to_f16_len,   f32_to_f16_data,   "main", 2, 4 * sizeof(int), {      64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_F16 ], "dequant_f16",  dequant_f16_len,  dequant_f16_data,  "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 4 * sizeof(int), {256 * 32, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 4 * sizeof(int), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_K", dequant_q2_K_len, dequant_q2_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_K", dequant_q3_K_len, dequant_q3_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_K", dequant_q4_K_len, dequant_q4_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_K", dequant_q5_K_len, dequant_q5_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_K", dequant_q6_K_len, dequant_q6_K_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
 
     // get_rows
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_cpy_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 3, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1);
+
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
 }
 
 static void ggml_vk_print_gpu_info(size_t idx) {
@@ -1057,8 +1225,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         }
     }
 
-    const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
-    bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
+    const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
+    bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
     bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
@@ -1188,140 +1356,152 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
         throw std::runtime_error("Device not found");
     }
 
-    vk_instance.devices[idx] = std::make_shared<vk_device>();
-    ctx->device = vk_instance.devices[idx];
-    ctx->device.lock()->physical_device = devices[dev_num];
-    const std::vector<vk::ExtensionProperties> ext_props = ctx->device.lock()->physical_device.enumerateDeviceExtensionProperties();
+    ctx->device = ggml_vk_get_device(idx);
+    if (!ctx->device->initialized) {
+        ctx->device->physical_device = devices[dev_num];
+        const std::vector<vk::ExtensionProperties> ext_props = ctx->device->physical_device.enumerateDeviceExtensionProperties();
 
-    bool maintenance4_support = false;
+        bool maintenance4_support = false;
 
-    // Check if maintenance4 is supported
-    for (const auto& properties : ext_props) {
-        if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
-            maintenance4_support = true;
+        // Check if maintenance4 is supported
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
+                maintenance4_support = true;
+            }
         }
-    }
 
-    vk::PhysicalDeviceProperties2 props2;
-    vk::PhysicalDeviceMaintenance3Properties props3;
-    vk::PhysicalDeviceMaintenance4Properties props4;
-    vk::PhysicalDeviceSubgroupProperties subgroup_props;
-    props2.pNext = &props3;
-    props3.pNext = &subgroup_props;
-    if (maintenance4_support) {
-        subgroup_props.pNext = &props4;
-    }
-    ctx->device.lock()->physical_device.getProperties2(&props2);
-    ctx->device.lock()->properties = props2.properties;
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceMaintenance3Properties props3;
+        vk::PhysicalDeviceMaintenance4Properties props4;
+        vk::PhysicalDeviceSubgroupProperties subgroup_props;
+        props2.pNext = &props3;
+        props3.pNext = &subgroup_props;
+        if (maintenance4_support) {
+            subgroup_props.pNext = &props4;
+        }
+        ctx->device->physical_device.getProperties2(&props2);
+        ctx->device->properties = props2.properties;
 
-    if (maintenance4_support) {
-        ctx->device.lock()->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
-    } else {
-        ctx->device.lock()->max_memory_allocation_size = props3.maxMemoryAllocationSize;
-    }
+        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
 
-    ctx->device.lock()->vendor_id = ctx->device.lock()->properties.vendorID;
-    ctx->device.lock()->subgroup_size = subgroup_props.subgroupSize;
-    ctx->device.lock()->uma = ctx->device.lock()->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
+            ctx->device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
+        } else if (maintenance4_support) {
+            ctx->device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
+        } else {
+            ctx->device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
+        }
 
-    bool fp16_storage = false;
-    bool fp16_compute = false;
+        ctx->device->vendor_id = ctx->device->properties.vendorID;
+        ctx->device->subgroup_size = subgroup_props.subgroupSize;
+        ctx->device->uma = ctx->device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
 
-    for (const auto& properties : ext_props) {
-        if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
-            fp16_storage = true;
-        } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
-            fp16_compute = true;
+        bool fp16_storage = false;
+        bool fp16_compute = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
+                fp16_storage = true;
+            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
+                fp16_compute = true;
+            }
         }
-    }
 
-    const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
-    bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != nullptr;
+        const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
+        const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
-    ctx->device.lock()->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
+        ctx->device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-    std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device.lock()->physical_device.getQueueFamilyProperties();
+        std::vector<vk::QueueFamilyProperties> queue_family_props = ctx->device->physical_device.getQueueFamilyProperties();
 
-    // Try to find a non-graphics compute queue and transfer-focused queues
-    const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
-    const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
+        // Try to find a non-graphics compute queue and transfer-focused queues
+        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
+        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
 
-    const float priorities[] = { 1.0f, 1.0f };
-    ctx->device.lock()->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
+        const float priorities[] = { 1.0f, 1.0f };
+        ctx->device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
 
-    std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
-    if (compute_queue_family_index != transfer_queue_family_index) {
-        device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
-        device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
-    } else if(!ctx->device.lock()->single_queue) {
-        device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
-    } else {
-        device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
-    }
-    vk::DeviceCreateInfo device_create_info;
-    std::vector<const char *> device_extensions;
-    vk::PhysicalDeviceFeatures device_features = ctx->device.lock()->physical_device.getFeatures();
+        std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
+        if (compute_queue_family_index != transfer_queue_family_index) {
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
+        } else if(!ctx->device->single_queue) {
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
+        } else {
+            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
+        }
+        vk::DeviceCreateInfo device_create_info;
+        std::vector<const char *> device_extensions;
+        vk::PhysicalDeviceFeatures device_features = ctx->device->physical_device.getFeatures();
 
-    VkPhysicalDeviceFeatures2 device_features2;
-    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
-    device_features2.pNext = nullptr;
-    device_features2.features = (VkPhysicalDeviceFeatures)device_features;
+        VkPhysicalDeviceFeatures2 device_features2;
+        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
+        device_features2.pNext = nullptr;
+        device_features2.features = (VkPhysicalDeviceFeatures)device_features;
 
-    VkPhysicalDeviceVulkan11Features vk11_features;
-    vk11_features.pNext = nullptr;
-    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
-    device_features2.pNext = &vk11_features;
+        VkPhysicalDeviceVulkan11Features vk11_features;
+        vk11_features.pNext = nullptr;
+        vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
+        device_features2.pNext = &vk11_features;
 
-    VkPhysicalDeviceVulkan12Features vk12_features;
-    vk12_features.pNext = nullptr;
-    vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
-    vk11_features.pNext = &vk12_features;
+        VkPhysicalDeviceVulkan12Features vk12_features;
+        vk12_features.pNext = nullptr;
+        vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
+        vk11_features.pNext = &vk12_features;
 
-    vkGetPhysicalDeviceFeatures2(ctx->device.lock()->physical_device, &device_features2);
+        vkGetPhysicalDeviceFeatures2(ctx->device->physical_device, &device_features2);
 
-    ctx->device.lock()->fp16 = ctx->device.lock()->fp16 && vk12_features.shaderFloat16;
+        ctx->device->fp16 = ctx->device->fp16 && vk12_features.shaderFloat16;
 
-    if (!vk11_features.storageBuffer16BitAccess) {
-        std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
-        throw std::runtime_error("Unsupported device");
-    }
+        if (!vk11_features.storageBuffer16BitAccess) {
+            std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
+            throw std::runtime_error("Unsupported device");
+        }
 
-    device_extensions.push_back("VK_KHR_16bit_storage");
+        device_extensions.push_back("VK_KHR_16bit_storage");
 
 #ifdef GGML_VULKAN_VALIDATE
-    device_extensions.push_back("VK_KHR_shader_non_semantic_info");
+        device_extensions.push_back("VK_KHR_shader_non_semantic_info");
 #endif
 
-    if (ctx->device.lock()->fp16) {
-        device_extensions.push_back("VK_KHR_shader_float16_int8");
-    }
-    ctx->device.lock()->name = ctx->device.lock()->properties.deviceName.data();
+        if (ctx->device->fp16) {
+            device_extensions.push_back("VK_KHR_shader_float16_int8");
+        }
+        ctx->device->name = ctx->device->properties.deviceName.data();
 
-    device_create_info = {
-        vk::DeviceCreateFlags(),
-        device_queue_create_infos,
-        {},
-        device_extensions
-    };
-    device_create_info.setPNext(&device_features2);
-    ctx->device.lock()->device = ctx->device.lock()->physical_device.createDevice(device_create_info);
+        device_create_info = {
+            vk::DeviceCreateFlags(),
+            device_queue_create_infos,
+            {},
+            device_extensions
+        };
+        device_create_info.setPNext(&device_features2);
+        ctx->device->device = ctx->device->physical_device.createDevice(device_create_info);
 
-    ctx->device.lock()->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
+        ctx->device->descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN;
 
-    // Shaders
-    ggml_vk_load_shaders(ctx);
+        // Queues
+        ggml_vk_create_queue(ctx, ctx->device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
 
-    // Queues
-    ggml_vk_create_queue(ctx, ctx->device.lock()->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer });
-    if (!ctx->device.lock()->single_queue) {
-        const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
-        ggml_vk_create_queue(ctx, ctx->device.lock()->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
-    } else {
-        // TODO: Use pointer or reference to avoid copy
-        ctx->device.lock()->transfer_queue = ctx->device.lock()->compute_queue;
+        // Shaders
+        ggml_vk_load_shaders(ctx);
+
+        if (!ctx->device->single_queue) {
+            const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
+            ggml_vk_create_queue(ctx, ctx->device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer });
+        } else {
+            // TODO: Use pointer or reference to avoid copy
+            ctx->device->transfer_queue = ctx->device->compute_queue;
+        }
+
+        ctx->device->idx = dev_num;
+        ctx->device->initialized = true;
+    } else if (ctx->device->idx != dev_num) {
+        std::cerr << "ggml_vulkan: Device " << ctx->device->name << " already initialized with index " << ctx->device->idx << ", but trying to reinitialize with index " << dev_num << std::endl;
+        throw std::runtime_error("Device already initialized");
     }
 
-    ctx->fence = ctx->device.lock()->device.createFence({});
+    ctx->fence = ctx->device->device.createFence({});
 
     ctx->compute_ctx = nullptr;
     ctx->transfer_ctx = nullptr;
@@ -1339,7 +1519,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
 #endif
 }
 
-static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_get_to_fp16()" << std::endl;
 #endif
@@ -1360,10 +1540,36 @@ static vk_pipeline* ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
             return nullptr;
     }
 
-    return &ctx->pipeline_dequant[type];
+    return ctx->device->pipeline_dequant[type];
 }
 
-static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "ggml_vk_get_mul_mat_mat_pipeline()" << std::endl;
+#endif
+    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+        return ctx->device->pipeline_matmul_f32;
+    }
+    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+        return ctx->device->pipeline_matmul_f16_f32;
+    }
+    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+        return ctx->device->pipeline_matmul_f16;
+    }
+
+    GGML_ASSERT(src1_type == GGML_TYPE_F32);
+
+    switch (src0_type) {
+        case GGML_TYPE_Q4_0:
+            break;
+        default:
+            return nullptr;
+    }
+
+    return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
+}
+
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type type) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
 #endif
@@ -1384,7 +1590,7 @@ static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
             return nullptr;
     }
 
-    return &ctx->pipeline_dequant_mul_mat_vec_f32[type];
+    return ctx->device->pipeline_dequant_mul_mat_vec_f32[type];
 }
 
 static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
@@ -1463,8 +1669,8 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
     if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
         fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
             size/1024.0/1024.0);
-        ctx->device.lock()->device.freeMemory(buf->device_memory);
-        ctx->device.lock()->device.destroyBuffer(buf->buffer);
+        ctx->device->device.freeMemory(buf->device_memory);
+        ctx->device->device.destroyBuffer(buf->buffer);
         return nullptr;
     }
 
@@ -1528,30 +1734,30 @@ static vk_submission ggml_vk_begin_submission(ggml_backend_vk_context * ctx, vk_
 }
 
 static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, std::vector<vk_subbuffer>&& buffers, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
-    const uint32_t wg0 = CEIL_DIV(elements[0], pipeline.wg_denoms[0]);
-    const uint32_t wg1 = CEIL_DIV(elements[1], pipeline.wg_denoms[1]);
-    const uint32_t wg2 = CEIL_DIV(elements[2], pipeline.wg_denoms[2]);
+    const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
+    const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
+    const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
 #ifdef GGML_VULKAN_DEBUG
-    std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline.name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
+    std::cerr << "ggml_vk_dispatch_pipeline(" << pipeline->name << ", (" << wg0 << "," << wg1 << "," << wg2 << "))" << std::endl;
 #endif
     std::vector<vk::DescriptorBufferInfo> descriptor_buffer_infos;
     std::vector<vk::WriteDescriptorSet> write_descriptor_sets;
-    GGML_ASSERT(pipeline.descriptor_set_idx < pipeline.descriptor_sets.size());
-    GGML_ASSERT(buffers.size() == pipeline.parameter_count);
-    vk::DescriptorSet& descriptor_set = pipeline.descriptor_sets[pipeline.descriptor_set_idx++];
-    for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
+    GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size());
+    GGML_ASSERT(buffers.size() == pipeline->parameter_count);
+    vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++];
+    for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
         descriptor_buffer_infos.push_back({buffers[i].buffer->buffer, buffers[i].offset, buffers[i].size});
     }
-    for (uint32_t i = 0; i < pipeline.parameter_count; i++) {
+    for (uint32_t i = 0; i < pipeline->parameter_count; i++) {
         write_descriptor_sets.push_back({descriptor_set, i, 0, 1, vk::DescriptorType::eStorageBuffer, nullptr, &descriptor_buffer_infos[i]});
     }
 
-    ctx->device.lock()->device.updateDescriptorSets(write_descriptor_sets, {});
+    ctx->device->device.updateDescriptorSets(write_descriptor_sets, {});
 
-    subctx->s->buffer.pushConstants(pipeline.layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
-    subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline.pipeline);
+    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants);
+    subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
     subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
-                                pipeline.layout,
+                                pipeline->layout,
                                 0,
                                 { descriptor_set },
                                 {});
@@ -1810,7 +2016,7 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
             memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
         }
     } else {
-        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
+        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
         ggml_vk_ctx_begin(ctx, subctx);
         ggml_vk_buffer_write_2d_async(ctx, subctx, dst, offset, src, spitch, width, height, true);
         ggml_vk_ctx_end(subctx);
@@ -1820,8 +2026,9 @@ static void ggml_vk_buffer_write_2d(ggml_backend_vk_context * ctx, vk_buffer& ds
         }
 
         ggml_vk_submit(subctx, ctx->fence);
-        VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
-        ctx->device.lock()->device.resetFences({ ctx->fence });
+        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
+        ctx->device->device.resetFences({ ctx->fence });
+        ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
     }
 }
 
@@ -1906,18 +2113,19 @@ static void ggml_vk_buffer_read(ggml_backend_vk_context * ctx, vk_buffer& src, s
 
         memcpy(dst, (uint8_t *) src->ptr + offset, size);
     } else {
-        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
+        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
         ggml_vk_ctx_begin(ctx, subctx);
         ggml_vk_buffer_read_async(ctx, subctx, src, offset, dst, size, true);
         ggml_vk_ctx_end(subctx);
 
         ggml_vk_submit(subctx, ctx->fence);
-        VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
-        ctx->device.lock()->device.resetFences({ ctx->fence });
+        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
+        ctx->device->device.resetFences({ ctx->fence });
 
         for (auto& cpy : subctx->out_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
+        ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
     }
 }
 
@@ -1941,15 +2149,13 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
         // Copy within the device
         ggml_backend_vk_context * ctx = src->ctx;
 
-        VkBufferCopy bc{ src_offset, dst_offset, size };
-
-        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
+        vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
         ggml_vk_ctx_begin(ctx, subctx);
         ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
         ggml_vk_ctx_end(subctx);
         ggml_vk_submit(subctx, ctx->fence);
-        VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
-        ctx->device.lock()->device.resetFences({ ctx->fence });
+        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
+        ctx->device->device.resetFences({ ctx->fence });
     } else {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")" << std::endl;
@@ -1977,14 +2183,14 @@ static void ggml_vk_buffer_memset(ggml_backend_vk_context * ctx, vk_buffer& dst,
     // Make sure ctx owns the buffer
     GGML_ASSERT(dst->ctx == ctx);
 
-    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
+    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
     ggml_vk_ctx_begin(ctx, subctx);
     subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
     ggml_vk_ctx_end(subctx);
 
     ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
-    ctx->device.lock()->device.resetFences({ ctx->fence });
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "vk_memset waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
 }
 
 static void ggml_vk_h2d_tensor_2d(ggml_backend_vk_context * ctx, vk_context * subctx, vk_buffer& dst, size_t offset, const ggml_tensor * src, uint64_t i3, uint64_t i2, uint64_t i1) {
@@ -2045,176 +2251,63 @@ static void ggml_vk_d2h_tensor_2d(ggml_backend_vk_context * ctx, vk_context * su
 
 static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
 #ifdef GGML_VULKAN_DEBUG
-    std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")";
+    std::cerr << "ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")" << std::endl;
 #endif
     if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " = 4" << std::endl;
-#endif
         return 4;
     }
 
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " = 1" << std::endl;
-#endif
     return 1;
 }
 
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, int m, int n) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
-#endif
+static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
     if (m <= 32 || n <= 32) {
-        return ctx->pipeline_matmul_f32_aligned_s.align;
+        return aligned ? mmp->a_s : mmp->s;
     }
-    if (ctx->device.lock()->subgroup_size == 64 || m <= 64 || n <= 64) {
-        return ctx->pipeline_matmul_f32_aligned_m.align;
-    }
-    return ctx->pipeline_matmul_f32_aligned_l.align;
+    return aligned ? mmp->a_m : mmp->m;
+
+    GGML_UNUSED(ctx);
 }
 
-static vk_pipeline* ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
-    if (bit16_x && bit16_y) {
-        if (m <= 32 || n <= 32) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " S" << std::endl;
-#endif
-            return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
-        }
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " M" << std::endl;
-#endif
-        return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
-    }
-    if (bit16_x && !bit16_y) {
-        if (m <= 32 || n <= 32) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " S" << std::endl;
-#endif
-            return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
-        }
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " M" << std::endl;
-#endif
-        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
-    }
-    if (!bit16_x && bit16_y) {
-        GGML_ASSERT(false);
-    }
+static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
+    return aligned ? mmp->a_m : mmp->m;
 
-    if (m <= 32 || n <= 32) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " S" << std::endl;
-#endif
-        return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
-    }
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " M" << std::endl;
-#endif
-    return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
+    GGML_UNUSED(ctx);
 }
 
-static vk_pipeline* ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " M" << std::endl;
-#endif
-    if (bit16_x && bit16_y) {
-        return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
-    }
-    if (bit16_x && !bit16_y) {
-        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
-    }
-    if (!bit16_x && bit16_y) {
-        GGML_ASSERT(false);
-    }
-    return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
-}
+static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
+    return aligned ? mmp->a_s : mmp->s;
 
-static vk_pipeline* ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, bool aligned) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " S" << std::endl;
-#endif
-    if (bit16_x && bit16_y) {
-        return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
-    }
-    if (bit16_x && !bit16_y) {
-        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
-    }
-    if (!bit16_x && bit16_y) {
-        GGML_ASSERT(false);
-    }
-    return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
+    GGML_UNUSED(ctx);
 }
 
-static vk_pipeline* ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, bool bit16_x, bool bit16_y, int m, int n, bool aligned) {
+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
 #ifdef GGML_VULKAN_DEBUG
-    std::cerr << "ggml_vk_guess_matmul_pipeline(" << bit16_x << ", " << bit16_y << ", " << m << ", " << n << ", " << aligned << ")";
+    std::cerr << "ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")" << std::endl;
 #endif
-    switch (ctx->device.lock()->vendor_id) {
+    switch (ctx->device->vendor_id) {
     case VK_VENDOR_ID_AMD:
-        return ggml_vk_guess_matmul_pipeline_amd(ctx, bit16_x, bit16_y, m, n, aligned);
+        return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
     case VK_VENDOR_ID_APPLE:
-        return ggml_vk_guess_matmul_pipeline_apple(ctx, bit16_x, bit16_y, aligned);
+        return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
     case VK_VENDOR_ID_INTEL:
-        return ggml_vk_guess_matmul_pipeline_intel(ctx, bit16_x, bit16_y, aligned);
-    }
-
-    if (bit16_x && bit16_y) {
-        if (m <= 32 || n <= 32) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " S" << std::endl;
-#endif
-            return aligned ? &ctx->pipeline_matmul_f16_aligned_s : &ctx->pipeline_matmul_f16_s;
-        }
-        if (m <= 64 || n <= 64) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " M" << std::endl;
-#endif
-            return aligned ? &ctx->pipeline_matmul_f16_aligned_m : &ctx->pipeline_matmul_f16_m;
-        }
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " L" << std::endl;
-#endif
-        return aligned ? &ctx->pipeline_matmul_f16_aligned_l : &ctx->pipeline_matmul_f16_l;
-    }
-    if (bit16_x && !bit16_y) {
-        if (m <= 32 || n <= 32) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " S" << std::endl;
-#endif
-            return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_s : &ctx->pipeline_matmul_f16_f32_s;
-        }
-        if (m <= 64 || n <= 64) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " M" << std::endl;
-#endif
-            return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_m : &ctx->pipeline_matmul_f16_f32_m;
-        }
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " L" << std::endl;
-#endif
-        return aligned ? &ctx->pipeline_matmul_f16_f32_aligned_l : &ctx->pipeline_matmul_f16_f32_l;
-    }
-    if (!bit16_x && bit16_y) {
-        GGML_ASSERT(false);
+        return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
     }
 
     if (m <= 32 || n <= 32) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " S" << std::endl;
-#endif
-        return aligned ? &ctx->pipeline_matmul_f32_aligned_s : &ctx->pipeline_matmul_f32_s;
+        return aligned ? mmp->a_s : mmp->s;
     }
     if (m <= 64 || n <= 64) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << " M" << std::endl;
-#endif
-        return aligned ? &ctx->pipeline_matmul_f32_aligned_m : &ctx->pipeline_matmul_f32_m;
+        return aligned ? mmp->a_m : mmp->m;
     }
+    return aligned ? mmp->a_l : mmp->l;
+}
+
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
 #ifdef GGML_VULKAN_DEBUG
-    std::cerr << " L" << std::endl;
+    std::cerr << "ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")" << std::endl;
 #endif
-    return aligned ? &ctx->pipeline_matmul_f32_aligned_l : &ctx->pipeline_matmul_f32_l;
+    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, false)->align;
 }
 
 static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d) {
@@ -2232,10 +2325,10 @@ static void ggml_vk_matmul(ggml_backend_vk_context * ctx, vk_context * subctx, v
 
     const std::array<uint32_t, 14> pc1 = { m, n, k, stride_a, stride_b, stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, batch_stride_a, batch_stride_b, batch_stride_d };
     // Make sure enough workgroups get assigned for split k to work
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline.wg_denoms[0]) * pipeline.wg_denoms[0]) * split_k, n, batch });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1.size() * sizeof(uint32_t), pc1.data(), { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
     const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
 }
 
 static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
@@ -2245,41 +2338,39 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
+static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
     if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
-        return &ctx->pipeline_cpy_f32_f32;
+        return ctx->device->pipeline_cpy_f32_f32;
     }
     if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
-        return &ctx->pipeline_cpy_f32_f16;
+        return ctx->device->pipeline_cpy_f32_f16;
     }
     if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
-        return &ctx->pipeline_cpy_f16_f16;
+        return ctx->device->pipeline_cpy_f16_f16;
     }
 
     std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl;
     GGML_ASSERT(false);
 }
 
-static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline * pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out, ggml_type buffer_type, bool aligned=true) {
+static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context * subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", backend=" << tensor->backend << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
     std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")" << std::endl;
 #endif
     const int tensor_type_size = ggml_type_size(tensor->type);
-    const int dst_type_size = ggml_type_size(buffer_type);
-
-    const uint32_t ne = tensor->ne[0] * tensor->ne[1] * tensor->ne[2];
 
-    const uint32_t nb2 = aligned ? ggml_vk_align_size(dst_type_size * tensor->ne[0] * tensor->ne[1], ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size : tensor->ne[0] * tensor->ne[1];
+    const uint32_t ne = ggml_nelements(tensor);
 
-    const vk_op_cpy_push_constants pc = {
+    const vk_op_unary_push_constants pc = {
         (uint32_t)ne,
-        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size,
-        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1],                       1                   , (uint32_t)tensor->ne[0]                   , nb2,
+        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
+        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3],                       1                   , (uint32_t)tensor->ne[0]                   , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
         0,
+        0.0f, 0.0f,
     };
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { in, out }, sizeof(vk_op_cpy_push_constants), &pc, { ne, 1, 1 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
 }
 
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2319,7 +2410,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     bool src0_uma = false;
     bool src1_uma = false;
 
-    if (ctx->device.lock()->uma) {
+    if (ctx->device->uma) {
         ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
         ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
         src0_uma = d_Qx != nullptr;
@@ -2332,10 +2423,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     const bool x_non_contig = !load_x && !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = !load_y && !ggml_vk_dim01_contiguous(src1);
 
-    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
+    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    const bool qx_needs_dequant = src0->type != GGML_TYPE_F16 || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
+    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
+
+    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
+    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+
+    if (mmp == nullptr) {
+        // Fall back to dequant + f16 mulmat
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
+    }
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
@@ -2344,17 +2442,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, ne01, ne11));
+    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
     const bool aligned = ne10 == kpad;
 
     const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
 
-    vk_pipeline * pipeline = ggml_vk_guess_matmul_pipeline(ctx, true, !f16_f32_kernel, ne01, ne11, aligned);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
+    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
     vk_buffer d_D = extra->buffer_gpu.lock();
@@ -2385,7 +2483,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     } else {
         d_X = d_Qx;
         x_buf_offset = qx_buf_offset;
-        GGML_ASSERT(qx_sz == x_sz);  // NOLINT
+        GGML_ASSERT(qx_sz == x_sz);
     }
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
@@ -2396,8 +2494,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
         GGML_ASSERT(qy_sz == y_sz);
     }
 
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline to_fp16_vk_1 = nullptr;
+    vk_pipeline to_fp16_vk_0 = nullptr;
+    vk_pipeline to_fp16_vk_1 = nullptr;
 
     if (x_non_contig) {
         to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
@@ -2413,19 +2511,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
     // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne12 * ne13);
+    ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
     if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, x_non_contig ? 1 : ne12 * ne13);
+        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
     }
     if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
+        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, 1);
     }
     if (split_k > 1) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, ne12 * ne13);
+        ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
     }
 
     if (x_non_contig) {
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, dst->type, false);
+        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     } else if (load_x || qx_needs_dequant) {
         if (load_x) {
             // copy data to device
@@ -2434,13 +2532,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
         }
 
         if (qx_needs_dequant) {
-            const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
+            const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
             ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
+            ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { { d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, { d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
         }
     }
     if (y_non_contig) {
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, dst->type);
+        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     } else if (load_y) {
         ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
     }
@@ -2457,7 +2555,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context * su
     }
 
     // compute
-    ggml_vk_matmul(ctx, subctx, *pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21);  // NOLINT
+    ggml_vk_matmul(ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, split_k, ne12*ne13, ne02, ne12, r2, r3, stride_batch_x, stride_batch_y, ne20*ne21);  // NOLINT
 
     if (dst->backend == GGML_BACKEND_TYPE_CPU) {
         // copy dst to host
@@ -2505,7 +2603,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
     bool src0_uma = false;
     bool src1_uma = false;
 
-    if (ctx->device.lock()->uma) {
+    if (ctx->device->uma) {
         ggml_vk_host_get(ctx, src0->data, d_Qx, qx_buf_offset);
         ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
         src0_uma = d_Qx != nullptr;
@@ -2527,9 +2625,9 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
     const uint64_t y_ne = ne11 * ne10;
     const uint64_t d_ne = ne11 * ne01;
 
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
     const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
@@ -2569,8 +2667,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
         GGML_ASSERT(qy_sz == y_sz);
     }
 
-    vk_pipeline to_fp16_vk_0 = nullptr;
-    vk_pipeline* to_fp16_vk_1 = nullptr;
+    vk_pipeline to_fp16_vk_0 = nullptr;
+    vk_pipeline to_fp16_vk_1 = nullptr;
     if (x_non_contig) {
         to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
     }
@@ -2579,30 +2677,30 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
-    vk_pipeline* dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
+    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type);
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
     GGML_ASSERT(dmmv != nullptr);
 
     // Allocate descriptor sets
     if (qx_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_0, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_0, 1);
     }
     if (qy_needs_dequant) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, *to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
+        ggml_pipeline_allocate_descriptor_sets(ctx, to_fp16_vk_1, y_non_contig ? 1 : ne12 * ne13);
     }
-    ggml_pipeline_allocate_descriptor_sets(ctx, *dmmv, ne12 * ne13);
+    ggml_pipeline_allocate_descriptor_sets(ctx, dmmv, ne12 * ne13);
 
     if (x_non_contig) {
-        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment));
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE }, src0->type);
+        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
+        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
     } else if (load_x) {
         // copy data to device
         ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qx, 0, src0, 0, 0, ggml_nrows(src0));
     }
     if (y_non_contig) {
         GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
-        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, src1->type);
+        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     } else if (load_y) {
         ggml_vk_h2d_tensor_2d(ctx, subctx, d_Qy, 0, src1, 0, 0, ggml_nrows(src1));
     }
@@ -2619,22 +2717,22 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context
             const uint64_t y_offset = y_buf_offset + y_sz * it_idx1;
             const uint64_t d_offset = d_buf_offset + d_sz * it_idx1;
 
-            const uint64_t y_buffer_offset = (y_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+            const uint64_t y_buffer_offset = (y_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
             const uint64_t y_shader_offset = y_offset - y_buffer_offset;
 
-            const uint64_t d_buffer_offset = (d_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+            const uint64_t d_buffer_offset = (d_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
             const uint64_t d_shader_offset = d_offset - d_buffer_offset;
 
             if (!y_non_contig && qy_needs_dequant) {
-                const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
+                const std::vector<uint32_t> pc = { (uint32_t)ne11, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(y_ne / 32) };
                 ggml_vk_sync_buffers(subctx);
-                ggml_vk_dispatch_pipeline(ctx, subctx, *to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});
+                ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_1, { { d_Qy, qy_offset, qy_sz }, { d_Y, y_offset, y_sz } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)y_ne, 1, 1});
             }
 
             // compute
-            const std::array<int, 3> pc = { (int)ne00, (int)(y_shader_offset / ggml_type_size(src1->type)), (int)(d_shader_offset / ggml_type_size(dst->type))};
+            const std::array<uint32_t, 3> pc = { (uint32_t)ne00, (uint32_t)(y_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type))};
             ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, *dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
+            ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { { d_X, x_offset, x_sz }, { d_Y, y_buffer_offset, y_sz + y_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 3 * sizeof(int), &pc, { (uint32_t)ne01, 1, 1});
 
             if (dst->backend == GGML_BACKEND_TYPE_CPU) {
                 // copy dst to host
@@ -2680,7 +2778,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
 
     bool src1_uma = false;
 
-    if (ctx->device.lock()->uma) {
+    if (ctx->device->uma) {
         ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
         src1_uma = d_Qy != nullptr;
     }
@@ -2691,7 +2789,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     const uint64_t y_ne = ne10 * ne11 * ne12;
     const uint64_t d_ne = ne01 * ne11 * ne12;
 
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
@@ -2710,12 +2808,12 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     }
 
     // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, 1);
+    ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
 
-    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
 
-    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
 
     if (load_y) {
@@ -2725,7 +2823,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     // compute
     const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 
     if (dst->backend == GGML_BACKEND_TYPE_CPU) {
         // copy dst to host
@@ -2772,7 +2870,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 
     bool src1_uma = false;
 
-    if (ctx->device.lock()->uma) {
+    if (ctx->device->uma) {
         ggml_vk_host_get(ctx, src1->data, d_Qy, qy_buf_offset);
         src1_uma = d_Qy != nullptr;
     }
@@ -2803,12 +2901,12 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     }
 
     // Allocate descriptor sets
-    ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, 1);
+    ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
 
-    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
 
-    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
 
     if (load_y) {
@@ -2818,7 +2916,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     // compute
     const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { { d_Qx, qx_buf_offset, qx_sz }, { d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, { d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
 
     if (dst->backend == GGML_BACKEND_TYPE_CPU) {
         // copy dst to host
@@ -2856,6 +2954,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context * subctx,
     }
 }
 
+// static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context * subctx, const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+//
+// }
+
 static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     // guaranteed to be an integer due to the check in ggml_can_repeat
     const uint64_t ne0 = dst->ne[0];
@@ -2927,40 +3029,40 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
 }
 
 
-static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op) {
+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_ADD:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_add_f32;
+            return ctx->device->pipeline_add_f32;
         }
         return nullptr;
     case GGML_OP_GET_ROWS:
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
         if (dst->type == GGML_TYPE_F16) {
-            return &ctx->pipeline_get_rows[src0->type];
+            return ctx->device->pipeline_get_rows[src0->type];
         }
         if (dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_get_rows_f32[src0->type];
+            return ctx->device->pipeline_get_rows_f32[src0->type];
         }
         return nullptr;
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_mul_f32;
+            return ctx->device->pipeline_mul_f32;
         }
         return nullptr;
     case GGML_OP_SCALE:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_scale_f32;
+            return ctx->device->pipeline_scale_f32;
         }
         return nullptr;
     case GGML_OP_SQR:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_sqr_f32;
+            return ctx->device->pipeline_sqr_f32;
         }
         return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_clamp_f32;
+            return ctx->device->pipeline_clamp_f32;
         }
         return nullptr;
     case GGML_OP_CPY:
@@ -2969,29 +3071,29 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type);
     case GGML_OP_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_norm_f32;
+            return ctx->device->pipeline_norm_f32;
         }
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_rms_norm_f32;
+            return ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return &ctx->pipeline_silu_f32;
+                    return ctx->device->pipeline_silu_f32;
                 }
                 break;
             case GGML_UNARY_OP_GELU:
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return &ctx->pipeline_gelu_f32;
+                    return ctx->device->pipeline_gelu_f32;
                 }
                 break;
             case GGML_UNARY_OP_RELU:
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return &ctx->pipeline_relu_f32;
+                    return ctx->device->pipeline_relu_f32;
                 }
                 break;
             default:
@@ -3000,12 +3102,12 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_DIAG_MASK_INF:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_diag_mask_inf_f32;
+            return ctx->device->pipeline_diag_mask_inf_f32;
         }
         return nullptr;
     case GGML_OP_SOFT_MAX:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return &ctx->pipeline_soft_max_f32;
+        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_soft_max_f32;
         }
         return nullptr;
     case GGML_OP_ROPE:
@@ -3020,21 +3122,26 @@ static vk_pipeline* ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
 
             if (is_neox) {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return &ctx->pipeline_rope_neox_f32;
+                    return ctx->device->pipeline_rope_neox_f32;
                 }
                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-                    return &ctx->pipeline_rope_neox_f16;
+                    return ctx->device->pipeline_rope_neox_f16;
                 }
             } else {
                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return &ctx->pipeline_rope_f32;
+                    return ctx->device->pipeline_rope_f32;
                 }
                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-                    return &ctx->pipeline_rope_f16;
+                    return ctx->device->pipeline_rope_f16;
                 }
             }
             return nullptr;
         }
+    case GGML_OP_ARGSORT:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
+            return ctx->device->pipeline_argsort_f32;
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -3050,17 +3157,19 @@ static ggml_vk_func_t ggml_vk_op_get_func(ggml_op op) {
 }
 
 template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, const PC&& pc) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     if (src1 != nullptr) {
         std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     }
+    if (src2 != nullptr) {
+        std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", backend=" << src2->backend << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
+    }
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
 #endif
     GGML_ASSERT(!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)));  // NOLINT
     GGML_ASSERT(op == GGML_OP_CPY || ggml_vk_dim01_contiguous(src0));  // NOLINT
-    GGML_ASSERT(src1 == nullptr || ggml_vk_dim01_contiguous(src1));  // NOLINT
     GGML_ASSERT(dst->extra != nullptr);
     const uint64_t ne00 = src0->ne[0];
     const uint64_t ne01 = src0->ne[1];
@@ -3077,7 +3186,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     const uint64_t nb2  = dst->nb[2];
     const uint64_t nb3  = dst->nb[3];
 
-    vk_pipeline * pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, dst, op);
+    const bool use_src2 = src2 != nullptr;
+    const uint64_t ne2 = use_src2 ? src2->ne[0] * src2->ne[1] : 0;
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
     ggml_vk_func_t op_func;
 
     if (pipeline == nullptr) {
@@ -3098,29 +3210,39 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
     ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
     ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
 
     vk_buffer d_X = nullptr;
     size_t x_buf_offset = 0;
     vk_buffer d_Y = nullptr;
     size_t y_buf_offset = 0;
+    vk_buffer d_Z = nullptr;
+    size_t z_buf_offset = 0;
 
     bool src0_uma = false;
     bool src1_uma = false;
+    bool src2_uma = false;
 
-    if (ctx->device.lock()->uma) {
+    if (ctx->device->uma) {
         ggml_vk_host_get(ctx, src0->data, d_X, x_buf_offset);
         src0_uma = d_X != nullptr;
         if (use_src1) {
             ggml_vk_host_get(ctx, src1->data, d_Y, y_buf_offset);
             src1_uma = d_Y != nullptr;
         }
+        if (use_src2) {
+            ggml_vk_host_get(ctx, src1->data, d_Z, z_buf_offset);
+            src2_uma = d_Z != nullptr;
+        }
     }
 
     const bool transfer_src0 = src0->backend != GGML_BACKEND_TYPE_GPU && !src0_uma;
     const bool transfer_src1 = use_src1 && src1->backend != GGML_BACKEND_TYPE_GPU && !src1_uma;
+    const bool transfer_src2 = use_src2 && src2->backend != GGML_BACKEND_TYPE_GPU && !src2_uma;
 
-    uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment);
-    uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) : 0;
+    uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type) * ne0, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
+    uint64_t z_sz = use_src2 ? ggml_vk_align_size(ggml_type_size(src2->type) * ne2, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
     uint64_t d_sz = ggml_type_size(dst->type) * ne0;
 
     vk_buffer d_D = extra->buffer_gpu.lock();
@@ -3131,7 +3253,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     }
 
     GGML_ASSERT(d_D != nullptr);
-    uint64_t d_buf_offset = (extra->offset / ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+    uint64_t d_buf_offset = (extra->offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
     GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY);  // NOLINT
     if (transfer_src0) {
         d_X = ctx->prealloc_qx;
@@ -3148,6 +3270,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         GGML_ASSERT(d_Y != nullptr);
     }
 
+    GGML_ASSERT(!transfer_src2);
+    if (use_src2 && !src2_uma) {
+        d_Z = extra_src2->buffer_gpu.lock();
+        z_buf_offset = extra_src2->offset;
+        GGML_ASSERT(d_Z != nullptr);
+    }
+
     if (op == GGML_OP_CPY) {
         GGML_ASSERT(!transfer_src0);
         GGML_ASSERT(!transfer_src1);
@@ -3175,7 +3304,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
 
     // Single call if dimension 2 is contiguous
     if (op == GGML_OP_CPY || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, 1);
+        ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, 1);
 
         switch (dst->op) {
         case GGML_OP_NORM:
@@ -3204,16 +3333,30 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
             }
         }
 
-        if (!use_src1 && op == GGML_OP_SOFT_MAX) {
-            // Empty src1 is possible on soft_max, but the shader needs a buffer
+        if (op == GGML_OP_SOFT_MAX) {
+            // Empty src1 and src2 are possible on soft_max, but the shader needs buffers
+            vk_subbuffer subbuf_y;
+            if (use_src1) {
+                subbuf_y = { d_Y, y_buf_offset, y_sz };
+            } else {
+                subbuf_y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
+            }
+
+            vk_subbuffer subbuf_z;
+            if (use_src2) {
+                subbuf_z = { d_Z, z_buf_offset, z_sz };
+            } else {
+                subbuf_z = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
+            }
+
             ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
         } else if (use_src1) {
             ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
         } else {
             ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
         }
         if (dst->backend == GGML_BACKEND_TYPE_CPU && op == GGML_OP_CPY) {
             ggml_vk_d2h_tensor_2d(ctx, subctx, d_D, 0, dst);
@@ -3223,7 +3366,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
             ggml_vk_buffer_read_async(ctx, subctx, d_D, 0, d, d_sz);
         }
     } else {
-        ggml_pipeline_allocate_descriptor_sets(ctx, *pipeline, ne02 * ne03);
+        GGML_ASSERT(op != GGML_OP_SOFT_MAX);
+
+        ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, ne02 * ne03);
 
         switch (dst->op) {
         case GGML_OP_NORM:
@@ -3248,16 +3393,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
                 const uint32_t y_offset = y_sz * it_idx1;
                 const uint32_t d_offset = d_sz * it_idx0;
 
-                if (!use_src1 && op == GGML_OP_SOFT_MAX) {
-                    // Empty src1 is possible on soft_max, but the shader needs a buffer
-                    ggml_vk_sync_buffers(subctx);
-                    ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset, x_sz }, { ctx->prealloc_y, 0, ctx->prealloc_y->size }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
-                } else if (use_src1) {
+                if (use_src1) {
                     ggml_vk_sync_buffers(subctx);
-                    ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
+                    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_Y, y_buf_offset + y_offset, y_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
                 } else {
                     ggml_vk_sync_buffers(subctx);
-                    ggml_vk_dispatch_pipeline(ctx, subctx, *pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
+                    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset + x_offset, x_sz }, { d_D, d_buf_offset + d_offset, d_sz } }, sizeof(PC), &pc, elements);
                 }
                 if (dst->backend == GGML_BACKEND_TYPE_CPU) {
                     // copy dst to host
@@ -3269,69 +3410,141 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
 }
 
 static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
 }
 
 static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
 }
 
 static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ADD, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
 }
 
 static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_MUL, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
 }
 
 static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SCALE, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        op_params[0], 0.0f
+    });
 }
 
 static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SQR, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
 }
 
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CLAMP, { (uint32_t)ggml_nelements(src0), 0, op_params[0], op_params[1] });
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        op_params[0], op_params[1],
+    });
 }
 
 static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
-    const int src0_type_size = ggml_type_size(src0->type);
-    const int dst_type_size = ggml_type_size(dst->type);
-    const uint32_t d_offset = (extra->offset % ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
-    ggml_vk_op_f32<vk_op_cpy_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CPY, {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+    const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         d_offset,
+        0.0f, 0.0f,
     });
 }
 
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
 }
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
 }
 
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
 }
 
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     int32_t * op_params = (int32_t *)dst->op_params;
-    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
+    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
 }
 
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_SOFT_MAX, { (uint32_t)src0->ne[0], (uint32_t)(src1 != nullptr ? ggml_nrows(src1) : 0), op_params[0], 0.0f });
+
+    float scale = op_params[0];
+    float max_bias = op_params[1];
+
+    const uint32_t ncols =   (uint32_t)src0->ne[0];
+    const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
+    const uint32_t nrows_y = (uint32_t)src0->ne[1];
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
+        ncols,
+        nrows_y,
+        src2 != nullptr ? (uint32_t)1 : (uint32_t)0,
+        scale, max_bias,
+        m0, m1,
+        n_head_log2,
+    });
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3357,12 +3570,17 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     if (is_neox) {
         const float theta_scale = powf(freq_base, -2.0f/n_dims);
         const float inv_ndims = -1.0f / n_dims;
-        ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
+        ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f, theta_scale, inv_ndims });
     } else {
-        ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
+        ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, corr_dims[0], corr_dims[1], 0.0f, 0.0f });
     }
 }
 
+static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    int32_t * op_params = (int32_t *)dst->op_params;
+    ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ORDER_ASC });
+}
+
 static void ggml_vk_nop(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     // If backend is CPU, data from src0 has to be copied off the device
     if (dst->backend == GGML_BACKEND_TYPE_CPU) {
@@ -3414,43 +3632,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     const size_t y_ne = k * n * batch;
     const size_t d_ne = m * n * batch;
 
-    vk_pipeline p;
+    vk_pipeline p;
     std::string shname;
     if (shader_size == 0) {
         if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f32_aligned_s;
+            p = ctx->device->pipeline_matmul_f32->a_s;
             shname = "F32_ALIGNED_S";
         } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f16_f32_aligned_s;
+            p = ctx->device->pipeline_matmul_f16_f32->a_s;
             shname = "F16_F32_ALIGNED_S";
         } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f16_aligned_s;
+            p = ctx->device->pipeline_matmul_f16->a_s;
             shname = "F16_ALIGNED_S";
         } else {
             GGML_ASSERT(false);
         }
     } else if (shader_size == 1) {
         if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f32_aligned_m;
+            p = ctx->device->pipeline_matmul_f32->a_m;
             shname = "F32_ALIGNED_M";
         } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f16_f32_aligned_m;
+            p = ctx->device->pipeline_matmul_f16_f32->a_m;
             shname = "F16_F32_ALIGNED_M";
         } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f16_aligned_m;
+            p = ctx->device->pipeline_matmul_f16->a_m;
             shname = "F16_ALIGNED_M";
         } else {
             GGML_ASSERT(false);
         }
     } else if (shader_size == 2) {
         if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f32_aligned_l;
+            p = ctx->device->pipeline_matmul_f32->a_l;
             shname = "F32_ALIGNED_L";
         } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f16_f32_aligned_l;
+            p = ctx->device->pipeline_matmul_f16_f32->a_l;
             shname = "F16_F32_ALIGNED_L";
         } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            p = &ctx->pipeline_matmul_f16_aligned_l;
+            p = ctx->device->pipeline_matmul_f16->a_l;
             shname = "F16_ALIGNED_L";
         } else {
             GGML_ASSERT(false);
@@ -3464,43 +3682,43 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     if (k != kpad) {
         if (shader_size == 0) {
             if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f32_s;
+                p = ctx->device->pipeline_matmul_f32->s;
                 shname = "F32_S";
             } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f16_f32_s;
+                p = ctx->device->pipeline_matmul_f16_f32->s;
                 shname = "F16_F32_S";
             } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f16_s;
+                p = ctx->device->pipeline_matmul_f16->s;
                 shname = "F16_S";
             }
         } else if (shader_size == 1) {
             if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f32_m;
+                p = ctx->device->pipeline_matmul_f32->m;
                 shname = "F32_M";
             } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f16_f32_m;
+                p = ctx->device->pipeline_matmul_f16_f32->m;
                 shname = "F16_F32_M";
             } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f16_m;
+                p = ctx->device->pipeline_matmul_f16->m;
                 shname = "F16_M";
             }
         } else if (shader_size == 2) {
             if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f32_l;
+                p = ctx->device->pipeline_matmul_f32->l;
                 shname = "F32_L";
             } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f16_f32_l;
+                p = ctx->device->pipeline_matmul_f16_f32->l;
                 shname = "F16_F32_L";
             } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
-                p = &ctx->pipeline_matmul_f16_l;
+                p = ctx->device->pipeline_matmul_f16->l;
                 shname = "F16_L";
             }
         }
     }
 
-    ggml_pipeline_allocate_descriptor_sets(ctx, *p, num_it);
+    ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
-        ggml_pipeline_allocate_descriptor_sets(ctx, ctx->pipeline_matmul_split_k_reduce, num_it);
+        ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
         if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
             // Resize buffer
@@ -3530,9 +3748,11 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     }
     for (size_t i = 0; i < y_ne; i++) {
         if (std::is_same<float, Y_TYPE>()) {
-            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            y[i] = (i % k == i / k) ? 1.0f : 0.0f;
         } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
         } else {
             GGML_ASSERT(false);
         }
@@ -3541,17 +3761,17 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     ggml_vk_buffer_write(ctx, d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
     ggml_vk_buffer_write(ctx, d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
 
-    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
+    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     for (size_t i = 0; i < num_it; i++) {
         ggml_vk_ctx_begin(ctx, subctx);
-        ggml_vk_matmul(ctx, subctx, *p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
+        ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
         ggml_vk_ctx_end(subctx);
     }
 
     auto begin = std::chrono::high_resolution_clock::now();
     ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
-    ctx->device.lock()->device.resetFences({ ctx->fence });
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
 
     auto end = std::chrono::high_resolution_clock::now();
     double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
@@ -3630,6 +3850,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
         std::cerr << "Actual result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+        std::cerr << std::endl;
+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
 
@@ -3655,15 +3877,15 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
 
     free(d_chk);
 
-    ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
-    ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
+    ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
+    ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
 
     ggml_vk_destroy_buffer(d_X);
     ggml_vk_destroy_buffer(d_Y);
     ggml_vk_destroy_buffer(d_D);
 
-    ggml_pipeline_cleanup(*p);
-    ggml_pipeline_cleanup(ctx->pipeline_matmul_split_k_reduce);
+    ggml_pipeline_cleanup(p);
+    ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce);
 
     free(x);
     free(y);
@@ -3736,7 +3958,7 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
         data[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
     }
 
-    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
+    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx, subctx);
 
     vk_buffer buffer = ggml_vk_create_buffer_check(ctx, ggml_nbytes(tensor), vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -3745,8 +3967,8 @@ static void ggml_vk_test_h2d_nc(ggml_backend_vk_context * ctx, size_t ne0, size_
 
     ggml_vk_ctx_end(subctx);
     ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
-    ctx->device.lock()->device.resetFences({ ctx->fence });
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_h2d_nc waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
 
     ggml_vk_buffer_read(ctx, buffer, 0, result_data, ggml_nbytes(tensor));
 
@@ -3818,7 +4040,7 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
         x[i] = rand() / (float)RAND_MAX;
     }
 
-    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
+    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx, subctx);
 
     auto begin = std::chrono::high_resolution_clock::now();
@@ -3832,8 +4054,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
 
     ggml_vk_ctx_end(subctx);
     ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
-    ctx->device.lock()->device.resetFences({ ctx->fence });
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
 
     auto end = std::chrono::high_resolution_clock::now();
 
@@ -3847,8 +4069,8 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
 
     ggml_vk_ctx_end(subctx);
     ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
-    ctx->device.lock()->device.resetFences({ ctx->fence });
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_transfer waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
 
     for (auto& cpy : subctx->out_memcpys) {
         memcpy(cpy.dst, cpy.src, cpy.n);
@@ -3879,89 +4101,118 @@ static void ggml_vk_test_transfer(ggml_backend_vk_context * ctx, size_t ne, bool
     }
 }
 
-static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
-#ifdef GGML_VULKAN_DEBUG
-    std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
-#endif
-    const size_t x_sz = sizeof(float) * ne;
-    const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
-    const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
-    float * x = (float *) malloc(x_sz);
-    void * qx = malloc(qx_sz);
-    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    vk_buffer x_buf = ggml_vk_create_buffer_check(ctx, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
-    ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
-
-    for (size_t i = 0; i < ne; i++) {
-        x[i] = rand() / (float)RAND_MAX;
-    }
-
+static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
     std::vector<int64_t> hist_cur(1 << 4, 0);
 
-    vk_pipeline& p = ctx->pipeline_dequant[quant];
-
     switch(quant) {
+    case GGML_TYPE_F32:
+        memcpy(to, from, sizeof(float) * ne);
+        break;
     case GGML_TYPE_Q4_0:
-        ggml_quantize_q4_0(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q4_0(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q4_1:
-        ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q4_1(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q5_0:
-        ggml_quantize_q5_0(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q5_0(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q5_1:
-        ggml_quantize_q4_1(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q5_1(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q8_0:
-        ggml_quantize_q8_0(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q8_0(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q2_K:
-        ggml_quantize_q2_K(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q2_K(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q3_K:
-        ggml_quantize_q3_K(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q3_K(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q4_K:
-        ggml_quantize_q4_K(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q4_K(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q5_K:
-        ggml_quantize_q5_K(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q5_K(from, to, ne, ne, hist_cur.data());
         break;
     case GGML_TYPE_Q6_K:
-        ggml_quantize_q6_K(x, qx, ne, ne, hist_cur.data());
+        ggml_quantize_q6_K(from, to, ne, ne, hist_cur.data());
         break;
     default:
         GGML_ASSERT(false);
     }
+}
+
+static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "ggml_vk_test_dequant(" << ne << ")" << std::endl;
+#endif
+    const size_t x_sz = sizeof(float) * ne;
+    const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
+    const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
+    float * x = (float *) malloc(x_sz);
+    void * qx = malloc(qx_sz);
+    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    vk_buffer x_buf = ggml_vk_create_buffer_check(ctx, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
+
+    for (size_t i = 0; i < ne; i++) {
+        x[i] = rand() / (float)RAND_MAX;
+    }
+
+    vk_pipeline p = ctx->device->pipeline_dequant[quant];
+
+    ggml_vk_quantize_data(x, qx, ne, quant);
 
     ggml_pipeline_allocate_descriptor_sets(ctx, p, 1);
 
     ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
 
-    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
+    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx, subctx);
-    const std::vector<int> pc = { 1, (int)ne, (int)ne, (int)ne };
+    const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
     ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1});
     ggml_vk_ctx_end(subctx);
 
     auto begin = std::chrono::high_resolution_clock::now();
 
     ggml_vk_submit(subctx, ctx->fence);
-    VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
-    ctx->device.lock()->device.resetFences({ ctx->fence });
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
 
     auto end = std::chrono::high_resolution_clock::now();
 
     double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
     ggml_vk_buffer_read(ctx, x_buf, 0, x_chk, x_sz_f16);
 
+    int first_err = -1;
+
     double avg_err = 0.0;
     for (size_t i = 0; i < ne; i++) {
-        avg_err += std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
+        double error = std::fabs(x[i] - ggml_fp16_to_fp32(x_chk[i]));
+        avg_err += error;
+
+        if (first_err < 0 && error > 0.05) {
+            first_err = i;
+        }
     }
 
-    std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err / ne << std::endl;
+    avg_err /= ne;
+
+    std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
+
+    if (avg_err > 0.1) {
+        std::cerr << "first_error = " << first_err << std::endl;
+        std::cerr << "Actual result: " << std::endl << std::endl;
+        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
+            std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
+        }
+        std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
+        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
+            std::cerr << x[i] << ", ";
+        }
+        std::cerr << std::endl;
+    }
 
     ggml_vk_destroy_buffer(x_buf);
     ggml_vk_destroy_buffer(qx_buf);
@@ -3970,6 +4221,190 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
     free(qx);
     free(x_chk);
 }
+
+static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
+#ifdef GGML_VULKAN_DEBUG
+    std::cerr << "ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")" << std::endl;
+#endif
+    const size_t x_ne = m * k * batch;
+    const size_t y_ne = k * n * batch;
+    const size_t d_ne = m * n * batch;
+
+    vk_pipeline p;
+    std::string shname;
+    if (shader_size == 0) {
+        p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
+        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
+    } else if (shader_size == 1) {
+        p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
+        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
+    } else if (shader_size == 2) {
+        p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
+        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
+    } else {
+        GGML_ASSERT(0);
+    }
+
+    const size_t kpad = ggml_vk_align_size(k, p->align);
+
+    if (k != kpad) {
+        if (shader_size == 0) {
+            p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
+            shname = std::string(ggml_type_name(quant)) + "_S";
+        } else if (shader_size == 1) {
+            p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
+            shname = std::string(ggml_type_name(quant)) + "_M";
+        } else if (shader_size == 2) {
+            p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
+            shname = std::string(ggml_type_name(quant)) + "_L";
+        } else {
+            GGML_ASSERT(0);
+        }
+    }
+
+    const size_t x_sz = sizeof(float) * x_ne;
+    const size_t y_sz = sizeof(float) * y_ne;
+    const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
+    const size_t d_sz = sizeof(float) * d_ne;
+    float * x = (float *) malloc(x_sz);
+    float * y = (float *) malloc(y_sz);
+    void * qx = malloc(qx_sz);
+    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    vk_buffer y_buf = ggml_vk_create_buffer_check(ctx, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    vk_buffer d_buf = ggml_vk_create_buffer_check(ctx, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    float * d = (float *) malloc(d_sz);
+    float * d_chk = (float *) malloc(d_sz);
+
+    for (size_t i = 0; i < x_ne; i++) {
+        x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+    }
+
+    ggml_vk_quantize_data(x, qx, x_ne, quant);
+
+    for (size_t i = 0; i < y_ne; i++) {
+        // y[i] = rand() / (float)RAND_MAX;
+        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+    }
+
+    ggml_pipeline_allocate_descriptor_sets(ctx, p, num_it);
+    if (split_k > 1) {
+        ggml_pipeline_allocate_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
+
+        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
+            // Resize buffer
+            if (ctx->prealloc_split_k != nullptr) {
+                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
+            }
+            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
+        }
+    }
+
+    ggml_vk_buffer_write(ctx, qx_buf, 0, qx, qx_sz);
+    ggml_vk_buffer_write(ctx, y_buf, 0, y, y_sz);
+
+    vk_context * subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+    for (size_t i = 0; i < num_it; i++) {
+        ggml_vk_ctx_begin(ctx, subctx);
+        ggml_vk_matmul(ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, split_k, batch, batch, batch, 1, 1, k*m, k*n, m*n);
+        ggml_vk_ctx_end(subctx);
+    }
+
+    auto begin = std::chrono::high_resolution_clock::now();
+
+    ggml_vk_submit(subctx, ctx->fence);
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
+
+    auto end = std::chrono::high_resolution_clock::now();
+
+    double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+    ggml_vk_buffer_read(ctx, d_buf, 0, d, d_sz);
+
+    ggml_init_params iparams = {
+        /*.mem_size   =*/ 1024*1024*1024,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context * ggml_ctx = ggml_init(iparams);
+
+    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
+    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
+    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
+
+    src0_ggml->data = qx;
+    src1_ggml->data = y;
+    tensor_ggml->data = d_chk;
+
+    ctx->disable = true;
+
+    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
+    ggml_build_forward_expand(cgraph, tensor_ggml);
+
+    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
+
+    ctx->disable = false;
+
+    ggml_free(ggml_ctx);
+
+    double avg_err = 0.0;
+    int first_err_n = -1;
+    int first_err_m = -1;
+    int first_err_b = -1;
+
+    for (size_t i = 0; i < m*n*batch; i++) {
+        double err = std::fabs(d[i] - d_chk[i]);
+        avg_err += err;
+
+        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
+            first_err_b = i / (m * n);
+            first_err_n = (i % (m * n)) / m;
+            first_err_m = (i % (m * n)) % m;
+        }
+    }
+
+    avg_err /= m * n;
+
+    std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl;
+
+    if (avg_err > 0.1 || std::isnan(avg_err)) {
+        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
+        std::cerr << "Actual result: " << std::endl << std::endl;
+        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+        std::cerr << std::endl;
+        std::cerr << "Expected result: " << std::endl << std::endl;
+        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+        if (split_k > 1) {
+            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
+            ggml_vk_buffer_read(ctx, ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
+
+            std::cerr << "d_buf0: " << std::endl << std::endl;
+            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+            std::cerr << "d_buf1: " << std::endl << std::endl;
+            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+            std::cerr << "d_buf2: " << std::endl << std::endl;
+            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+            std::cerr << "d_buf3: " << std::endl << std::endl;
+            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
+
+            free(split_k_buf);
+        }
+    }
+
+    ggml_vk_destroy_buffer(qx_buf);
+    ggml_vk_destroy_buffer(y_buf);
+    ggml_vk_destroy_buffer(d_buf);
+
+    free(x);
+    free(qx);
+    free(y);
+    free(d);
+    free(d_chk);
+}
 #endif
 
 static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) {
@@ -3982,18 +4417,8 @@ static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor)
     return extra;
 }
 
-static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph * graph) {
-    GGML_ASSERT(node != nullptr);
-
-    for (int i = graph->n_nodes - 1; i >= 0; i--) {
-        for (int j = 0; j < GGML_MAX_SRC; j++) {
-            if (graph->nodes[i]->src[j] == node) {
-                return graph->nodes[i];
-            }
-        }
-    }
-
-    return nullptr;
+static bool ggml_vk_cpu_assist_op(const ggml_tensor * node) {
+    return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
 }
 
 static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){
@@ -4004,7 +4429,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
         || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
         || (node->src[1] != nullptr && (node->src[1]->backend == GGML_BACKEND_TYPE_GPU));
 
-    if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT)) {
+    if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node))) {
         return;
     }
 
@@ -4035,7 +4460,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
     const bool f16_f32_kernel = use_src1 && src1->type == GGML_TYPE_F32;
 
     int split_k;
-    if (node->op == GGML_OP_MUL_MAT) {
+    if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
         split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
     } else {
         split_k = 1;
@@ -4044,11 +4469,11 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
     const uint32_t y_ne = ne10 * ne11;
     const uint32_t d_ne = ne20 * ne21;
 
-    const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
-    const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
-    const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
-    const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
-    uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
+    const uint64_t qx_sz = use_src0 ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
+    const uint64_t qy_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
+    const uint64_t x_sz = use_src0 ? ggml_vk_align_size(sizeof(ggml_fp16_t) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne02 * ne03 : 0;
+    const uint64_t y_sz = use_src1 ? ggml_vk_align_size(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne12 * ne13 : 0;
+    uint64_t d_sz = ggml_vk_align_size(ggml_type_size(node->type) * d_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ne22 * ne23;
     const uint64_t split_k_size = split_k > 1 ? d_sz * 4 : 0;
 
     if (extra->buffer_gpu.expired()) {
@@ -4076,6 +4501,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
+    case GGML_OP_ARGSORT:
         break;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(node)) {
@@ -4088,6 +4514,7 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
         }
         break;
     case GGML_OP_MUL_MAT:
+    case GGML_OP_MUL_MAT_ID:
         if (ctx->prealloc_size_qx < qx_sz) {
             ctx->prealloc_size_qx = qx_sz;
         }
@@ -4121,21 +4548,66 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
 #endif
 #if defined(GGML_VULKAN_RUN_TESTS)
     ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul,
-        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
+        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
         vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
     ggml_vk_test_transfer(ctx, 8192 * 1000, false);
     ggml_vk_test_transfer(ctx, 8192 * 1000, true);
 
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_0);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_1);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_0);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_1);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q8_0);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q2_K);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q3_K);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q4_K);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q5_K);
-    ggml_vk_test_dequant(ctx, 2560 * 7680, GGML_TYPE_Q6_K);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
+    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
+
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
+    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
+
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
+
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
+
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
+
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
+
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
+
+    std::cerr << std::endl;
 
     const std::vector<size_t> vals {
         8, 8, 8,
@@ -4225,7 +4697,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_TYPE_GPU || node->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
         || (node->src[1] != nullptr && node->src[1]->backend == GGML_BACKEND_TYPE_GPU);
 
-    if (ctx->disable || (!any_on_device && node->op != GGML_OP_MUL_MAT) || (node->op == GGML_OP_MUL_MAT && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
+    if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(node)) || (ggml_vk_cpu_assist_op(node) && !any_on_device && !ggml_vk_can_mul_mat(node->src[0], node->src[1], node))) {
         return;
     }
 
@@ -4237,6 +4709,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
     const ggml_tensor * src0 = node->src[0];
     const ggml_tensor * src1 = node->src[1];
+    const ggml_tensor * src2 = node->src[2];
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
 
@@ -4271,7 +4744,9 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SOFT_MAX:
     case GGML_OP_ROPE:
     case GGML_OP_MUL_MAT:
+    case GGML_OP_MUL_MAT_ID:
     case GGML_OP_NONE:
+    case GGML_OP_ARGSORT:
         break;
     default:
         if (any_on_device) {
@@ -4282,7 +4757,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     }
 
     if (ctx->compute_ctx == nullptr) {
-        ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->compute_queue);
+        ctx->compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
         ggml_vk_ctx_begin(ctx, ctx->compute_ctx);
     }
 
@@ -4353,16 +4828,25 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
+        ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, src2, node);
 
         break;
     case GGML_OP_ROPE:
         ggml_vk_rope(ctx, ctx->compute_ctx, src0, src1, node);
 
+        break;
+    case GGML_OP_ARGSORT:
+        ggml_vk_argsort(ctx, ctx->compute_ctx, src0, node);
         break;
     case GGML_OP_MUL_MAT:
         ggml_vk_mul_mat(ctx, ctx->compute_ctx, src0, src1, node);
 
+        break;
+    case GGML_OP_MUL_MAT_ID:
+        //ggml_vk_mul_mat_id(ctx, ctx->compute_ctx, src0, src1, node);
+        std::cerr << "ggml_vulkan: GGML_OP_MUL_MAT_ID not implemented yet." << std::endl;
+        GGML_ASSERT(false);
+
         break;
     default:
         return;
@@ -4389,7 +4873,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU || tensor->src[0]->backend == GGML_BACKEND_TYPE_GPU_SPLIT))
         || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_TYPE_GPU);
 
-    if (ctx->disable || (!any_on_device && tensor->op != GGML_OP_MUL_MAT)) {
+    if (ctx->disable || (!any_on_device && !ggml_vk_cpu_assist_op(tensor))) {
         return false;
     }
 
@@ -4415,6 +4899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
     case GGML_OP_PERMUTE:
     case GGML_OP_TRANSPOSE:
     case GGML_OP_NONE:
+    case GGML_OP_ARGSORT:
         extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
         break;
@@ -4430,6 +4915,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
         }
         break;
     case GGML_OP_MUL_MAT:
+    case GGML_OP_MUL_MAT_ID:
         if (!any_on_device && !ggml_vk_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
             return false;
         }
@@ -4475,8 +4961,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_compute_
     }
 
     if (tensor == subctx.exit_tensor) {
-        VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
-        ctx->device.lock()->device.resetFences({ ctx->fence });
+        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
+        ctx->device->device.resetFences({ ctx->fence });
 
         // Do staging buffer copies
         for (auto& cpy : subctx.out_memcpys) {
@@ -4504,20 +4990,25 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
     }
     ctx->gc.temp_buffers.clear();
 
-    for (auto * pipeline : ctx->gc.pipelines) {
-        ggml_pipeline_cleanup(*pipeline);
+    for (auto& pipeline : ctx->device->pipelines) {
+        if (pipeline.expired()) {
+            continue;
+        }
+
+        vk_pipeline pl = pipeline.lock();
+        ggml_pipeline_cleanup(pl);
     }
 
-    ggml_vk_queue_cleanup(ctx, ctx->device.lock()->compute_queue);
-    ggml_vk_queue_cleanup(ctx, ctx->device.lock()->transfer_queue);
+    ggml_vk_queue_cleanup(ctx, ctx->device->compute_queue);
+    ggml_vk_queue_cleanup(ctx, ctx->device->transfer_queue);
 
     for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
-        ctx->device.lock()->device.destroySemaphore({ ctx->gc.semaphores[i].s });
+        ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
     }
     ctx->gc.semaphores.clear();
 
     for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
-        ctx->device.lock()->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
+        ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
     }
     ctx->gc.tl_semaphores.clear();
     ctx->semaphore_idx = 0;
@@ -4525,7 +5016,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
     ctx->event_idx = 0;
 
     for (auto& event : ctx->gc.events) {
-        ctx->device.lock()->device.resetEvent(event);
+        ctx->device->device.resetEvent(event);
     }
 
     ctx->staging_offset = 0;
@@ -4562,21 +5053,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ctx->staging_size = 0;
 
     for (auto& event : ctx->gc.events) {
-        ctx->device.lock()->device.destroyEvent(event);
+        ctx->device->device.destroyEvent(event);
     }
     ctx->gc.events.clear();
 
-    for (auto* pipeline : ctx->gc.pipelines) {
-        ggml_vk_destroy_pipeline(ctx, pipeline);
-    }
-    ctx->gc.pipelines.clear();
-
-    ctx->device.lock()->device.destroyFence(ctx->fence);
-
-    ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->compute_queue.pool);
-    if (!ctx->device.lock()->single_queue) {
-        ctx->device.lock()->device.destroyCommandPool(ctx->device.lock()->transfer_queue.pool);
-    }
+    ctx->device->device.destroyFence(ctx->fence);
 }
 
 GGML_CALL static int ggml_vk_get_device_count() {
@@ -4787,7 +5268,6 @@ GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t bu
 
 GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
     if (ggml_backend_buffer_is_vk(src->buffer)) {
-        ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
         ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra;
         ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
 
@@ -4799,6 +5279,8 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu
         return true;
     }
     return false;
+
+    UNUSED(buffer);
 }
 
 GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -4845,12 +5327,12 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(
 
 GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-    return ctx->ctx->device.lock()->properties.limits.minStorageBufferOffsetAlignment;
+    return ctx->ctx->device->properties.limits.minStorageBufferOffsetAlignment;
 }
 
 GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
     ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
-    return ctx->ctx->device.lock()->max_memory_allocation_size;
+    return ctx->ctx->device->max_memory_allocation_size;
 }
 
 GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
@@ -4936,7 +5418,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_bu
 }
 
 GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return vk_instance.contexts[0].device.lock()->properties.limits.minMemoryMapAlignment;
+    return vk_instance.contexts[0].device->properties.limits.minMemoryMapAlignment;
 
     UNUSED(buft);
 }
@@ -4981,8 +5463,7 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) {
 
     ggml_vk_cleanup(ctx);
 
-    // Release device
-    vk_instance.devices[ctx->idx].reset();
+    ctx->device.reset();
     ctx->initialized = false;
 
     vk_instance.initialized[idx] = false;
@@ -5011,7 +5492,7 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g
 
     if (ctx->transfer_ctx == nullptr) {
         // Initialize new transfer context
-        ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
+        ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
         ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
     }
 
@@ -5032,7 +5513,7 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c
 
     if (ctx->transfer_ctx == nullptr) {
         // Initialize new transfer context
-        ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
+        ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
         ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
     }
 
@@ -5052,7 +5533,7 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c
 
         if (ctx->transfer_ctx == nullptr) {
             // Initialize new transfer context
-            ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device.lock()->transfer_queue);
+            ctx->transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue);
             ggml_vk_ctx_begin(ctx, ctx->transfer_ctx);
         }
 
@@ -5082,8 +5563,8 @@ GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
     }
 
     ggml_vk_submit(ctx->transfer_ctx, ctx->fence);
-    VK_CHECK(ctx->device.lock()->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
-    ctx->device.lock()->device.resetFences({ ctx->fence });
+    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
+    ctx->device->device.resetFences({ ctx->fence });
 
     for (auto& cpy : ctx->transfer_ctx->out_memcpys) {
         memcpy(cpy.dst, cpy.src, cpy.n);
@@ -5153,6 +5634,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
             }
             break;
         case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
             {
                 struct ggml_tensor * a;
                 struct ggml_tensor * b;
@@ -5226,6 +5708,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_ARGSORT:
             return true;
         default:
             return false;
@@ -5505,6 +5988,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
 
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
+    ggml_tensor * src2 = tensor->src[2];
 
     struct ggml_init_params iparams = {
         /*.mem_size   =*/ 1024*1024*1024,
@@ -5516,13 +6000,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
 
     struct ggml_tensor * src0_clone = nullptr;
     struct ggml_tensor * src1_clone = nullptr;
+    struct ggml_tensor * src2_clone = nullptr;
     struct ggml_tensor * tensor_clone = nullptr;
 
     size_t src0_size;
     size_t src1_size;
+    size_t src2_size;
 
     void * src0_buffer;
     void * src1_buffer;
+    void * src2_buffer;
 
     if (src0 != nullptr) {
         src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -5536,12 +6023,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
             memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
         } else if (src0->backend == GGML_BACKEND_TYPE_GPU) {
             ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra;
+            vk_buffer buffer_gpu = extra->buffer_gpu.lock();
             uint64_t offset = extra->offset;
             if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) {
                 for (int i3 = 0; i3 < src0->ne[3]; i3++) {
                     for (int i2 = 0; i2 < src0->ne[2]; i2++) {
                         const int idx = i3*src0->ne[2] + i2;
-                        vk_buffer buffer_gpu = extra->buffer_gpu.lock();
                         ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
                     }
                 }
@@ -5552,7 +6039,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
                     src0_clone->nb[i] = src0_clone->nb[i - 1]*src0_clone->ne[i - 1];
                 }
             } else {
-                vk_buffer buffer_gpu = extra->buffer_gpu.lock();
                 if (offset + src0_size >= buffer_gpu->size) {
                     src0_size = buffer_gpu->size - offset;
                 }
@@ -5581,12 +6067,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
             memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
         } else if (src1->backend == GGML_BACKEND_TYPE_GPU) {
             ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra;
+            vk_buffer buffer_gpu = extra->buffer_gpu.lock();
             uint64_t offset = extra->offset;
             if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) {
                 for (int i3 = 0; i3 < src1->ne[3]; i3++) {
                     for (int i2 = 0; i2 < src1->ne[2]; i2++) {
                         const int idx = i3*src1->ne[2] + i2;
-                        vk_buffer buffer_gpu = extra->buffer_gpu.lock();
                         ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
                     }
                 }
@@ -5597,7 +6083,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
                     src1_clone->nb[i] = src1_clone->nb[i - 1]*src1_clone->ne[i - 1];
                 }
             } else {
-                vk_buffer buffer_gpu = extra->buffer_gpu.lock();
                 if (offset + src1_size >= buffer_gpu->size) {
                     src1_size = buffer_gpu->size - offset;
                 }
@@ -5630,6 +6115,66 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
 
         ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone);
     }
+    if (src2 != nullptr) {
+        src2_clone = ggml_dup_tensor(ggml_ctx, src2);
+
+        src2_size = ggml_nbytes(src2);
+
+        src2_buffer = malloc(src2_size);
+        src2_clone->data = src2_buffer;
+        if (src2->backend == GGML_BACKEND_TYPE_CPU) {
+            memcpy(src2_clone->data, src2->data, src2_size);
+            memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
+        } else if (src2->backend == GGML_BACKEND_TYPE_GPU) {
+            ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
+            vk_buffer buf = extra->buffer_gpu.lock();
+            uint64_t offset = extra->offset;
+            if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
+                for (int i3 = 0; i3 < src2->ne[3]; i3++) {
+                    for (int i2 = 0; i2 < src2->ne[2]; i2++) {
+                        const int idx = i3*src2->ne[2] + i2;
+                        ggml_vk_buffer_read(ctx, buf, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
+                    }
+                }
+
+                src2_clone->nb[0] = src2->nb[0];
+                src2_clone->nb[1] = src2->nb[1];
+                for (int i = 2; i < GGML_MAX_DIMS; i++) {
+                    src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
+                }
+            } else {
+                if (offset + src2_size >= buf->size) {
+                    src2_size = buf->size - offset;
+                }
+                ggml_vk_buffer_read(ctx, buf, offset, src2_clone->data, src2_size);
+                memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
+            }
+        } else {
+            GGML_ASSERT(false);
+        }
+
+        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
+            ggml_vk_print_tensor(ctx, src2, "src2");
+            std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
+            std::cerr << "src2_clone=" << tensor << " src2_clone->backend: " << src2_clone->backend << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
+            if (src2->src[0] != nullptr) {
+                std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " backend=" << src2->src[0]->backend << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
+            }
+            if (src2->src[1] != nullptr) {
+                std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " backend=" << src2->src[1]->backend << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
+            }
+            std::cerr << std::endl << "Result:" << std::endl;
+            ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
+            std::cerr << std::endl;
+            std::cerr << std::endl << "Result:" << std::endl;
+            ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
+            std::cerr << std::endl;
+            std::vector<const ggml_tensor *> done;
+            ggml_vk_print_graph_origin(src2_clone, done);
+        }
+
+        ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src2", src2_clone);
+    }
 
     if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
@@ -5648,7 +6193,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
+        if (src1 != nullptr) {
+            tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+        } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
+        }
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src0_clone, *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_ROPE) {
@@ -5728,6 +6277,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     if (src1 != nullptr) {
         free(src1_buffer);
     }
+    if (src2 != nullptr) {
+        free(src1_buffer);
+    }
 
     ggml_free(ggml_ctx);
 }
index 9645126b4f4a56087c655bb76dd29f3a69bb24ff..e4317c3e03481c7a93f98a01916b6092e1afe55d 100644 (file)
@@ -10,6 +10,7 @@ extern "C" {
 #define GGML_VK_NAME "Vulkan"
 #define GGML_VK_MAX_DEVICES 16
 
+GGML_API void ggml_vk_instance_init(void);
 GGML_API void ggml_vk_init_cpu_assist(void);
 
 GGML_API void ggml_vk_preallocate_buffers_graph_cpu_assist(struct ggml_tensor * node);